Skip to content

Commit 3888862

Browse files
authored
Merge pull request #621 from Jakuje/large-scp
2 parents c43aad2 + 985f44d commit 3888862

File tree

3 files changed

+62
-8
lines changed

3 files changed

+62
-8
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Downloading files larger than 64kB over SCP no longer fails -- by :user:`Jakuje`.

src/pylibsshext/scp.pyx

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ from pylibsshext.errors cimport LibsshSCPException
2424
from pylibsshext.session cimport get_libssh_session
2525

2626

27+
SCP_MAX_CHUNK = 65536
28+
29+
2730
cdef class SCP:
2831
def __cinit__(self, session):
2932
self.session = session
@@ -122,7 +125,9 @@ cdef class SCP:
122125
size = libssh.ssh_scp_request_get_size(scp)
123126
mode = libssh.ssh_scp_request_get_permissions(scp)
124127

125-
read_buffer = <char *>PyMem_Malloc(size)
128+
# cap the buffer size to a reasonable number -- libssh will not return the whole data at once anyway
129+
read_buffer_size = min(size, SCP_MAX_CHUNK)
130+
read_buffer = <char *>PyMem_Malloc(read_buffer_size)
126131
if read_buffer is NULL:
127132
raise LibsshSCPException("Memory allocation error")
128133

@@ -131,14 +136,17 @@ cdef class SCP:
131136
if rc == libssh.SSH_ERROR:
132137
raise LibsshSCPException("Failed to start read request: %s" % self._get_ssh_error_str())
133138

134-
# Read the file
135-
rc = libssh.ssh_scp_read(scp, read_buffer, size)
136-
if rc == libssh.SSH_ERROR:
137-
raise LibsshSCPException("Error receiving file data: %s" % self._get_ssh_error_str())
138-
139-
py_file_bytes = read_buffer[:size]
139+
remaining_bytes_to_read = size
140140
with open(local_file, "wb") as f:
141-
f.write(py_file_bytes)
141+
while remaining_bytes_to_read > 0:
142+
requested_read_bytes = min(remaining_bytes_to_read, read_buffer_size)
143+
read_bytes = libssh.ssh_scp_read(scp, read_buffer, requested_read_bytes)
144+
if read_bytes == libssh.SSH_ERROR:
145+
raise LibsshSCPException("Error receiving file data: %s" % self._get_ssh_error_str())
146+
147+
py_file_bytes = read_buffer[:read_bytes]
148+
f.write(py_file_bytes)
149+
remaining_bytes_to_read -= read_bytes
142150
if mode >= 0:
143151
os.chmod(local_file, mode)
144152

tests/unit/scp_test.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
"""Tests suite for scp."""
44

55
import os
6+
import random
7+
import string
68
import uuid
79

810
import pytest
@@ -75,3 +77,46 @@ def test_copy_from_non_existent_remote_path(path_to_non_existent_src_file, ssh_s
7577
error_msg = '^Error receiving information about file:'
7678
with pytest.raises(LibsshSCPException, match=error_msg):
7779
ssh_scp.get(str(path_to_non_existent_src_file), os.devnull)
80+
81+
82+
@pytest.fixture
def pre_existing_file_path(tmp_path):
    """Create a local file with placeholder content and return its path."""
    existing_file = tmp_path / 'pre-existing-file.txt'
    existing_file.write_bytes(b'whatever')
    return existing_file
88+
89+
90+
def test_get_existing_local(pre_existing_file_path, src_path, ssh_scp, transmit_payload):
    """Check that an SCP download overwrites a local file that already exists."""
    remote_source = str(src_path)
    local_target = str(pre_existing_file_path)
    ssh_scp.get(remote_source, local_target)
    downloaded = pre_existing_file_path.read_bytes()
    assert downloaded == transmit_payload
94+
95+
96+
@pytest.fixture
def large_payload():
    """Build a 65537-byte (64kB + 1B) payload of random printable characters."""
    kilobyte_count = 64
    # One kilobyte of random printable character codes, repeated to fill 64kB.
    kilobyte_of_codes = [ord(random.choice(string.printable)) for _ in range(1024)]
    bulk = bytes(kilobyte_of_codes * kilobyte_count)
    # A single extra byte pushes the payload just past the 64kB mark.
    trailing_byte = bytes([random.choice(kilobyte_of_codes)])
    return bulk + trailing_byte
104+
105+
106+
@pytest.fixture
def src_path_large(tmp_path, large_payload):
    """Return a remote path that points to a 65537 byte-sized file.

    A single read in ``libssh`` typically yields at most 64kB, so
    the file has to be larger than that for the download to
    exercise the read loop.
    """
    large_file = tmp_path / 'large.txt'
    large_file.write_bytes(large_payload)
    return large_file
117+
118+
119+
def test_get_large(dst_path, src_path_large, ssh_scp, large_payload):
    """Check that an SCP download transfers more than 64kB of data intact."""
    remote_source = str(src_path_large)
    local_target = str(dst_path)
    ssh_scp.get(remote_source, local_target)
    downloaded = dst_path.read_bytes()
    assert downloaded == large_payload

0 commit comments

Comments
 (0)