Skip to content

Commit 743966d

Browse files
committed
Add client and server support for SFTP copy-data extension
This commit adds client and server support for the SFTP "copy-data" extension, and a new remote_copy() method on SFTPClient wihch allows you to make a request to copy bytes between two files on the remote server without needing to download and re-upload the data, if the server supports it. Thanks go to Ali Khosravi for suggesting this addition.
1 parent 4917d8d commit 743966d

File tree

3 files changed

+211
-19
lines changed

3 files changed

+211
-19
lines changed

asyncssh/sftp.py

Lines changed: 134 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@
161161
MAX_SFTP_WRITE_LEN = 4*1024*1024 # 4 MiB
162162
MAX_SFTP_PACKET_LEN = MAX_SFTP_WRITE_LEN + 1024
163163

164+
_COPY_DATA_BLOCK_SIZE = 256*1024 # 256 KiB
165+
164166
_MAX_SFTP_REQUESTS = 128
165167
_MAX_READDIR_NAMES = 128
166168

@@ -806,6 +808,24 @@ async def run(self) -> None:
806808
if self._progress_handler and self._total_bytes == 0:
807809
self._progress_handler(self._srcpath, self._dstpath, 0, 0)
808810

811+
if self._srcfs == self._dstfs and \
812+
isinstance(self._srcfs, SFTPClient):
813+
try:
814+
await self._srcfs.remote_copy(
815+
cast(SFTPClientFile, self._src),
816+
cast(SFTPClientFile, self._dst))
817+
except SFTPOpUnsupported:
818+
pass
819+
else:
820+
self._bytes_copied = self._total_bytes
821+
822+
if self._progress_handler:
823+
self._progress_handler(self._srcpath, self._dstpath,
824+
self._bytes_copied,
825+
self._total_bytes)
826+
827+
return
828+
809829
async for _, datalen in self.iter():
810830
if datalen:
811831
self._bytes_copied += datalen
@@ -822,8 +842,6 @@ async def run(self) -> None:
822842
setattr(exc, 'offset', self._bytes_copied)
823843

824844
raise exc
825-
826-
827845
finally:
828846
if self._src: # pragma: no branch
829847
await self._src.close()
@@ -2472,6 +2490,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop,
24722490
self._supports_fsync = False
24732491
self._supports_lsetstat = False
24742492
self._supports_limits = False
2493+
self._supports_copy_data = False
24752494

24762495
@property
24772496
def version(self) -> int:
@@ -2692,6 +2711,8 @@ async def start(self) -> None:
26922711
self._supports_lsetstat = True
26932712
elif name == b'limits@openssh.com' and data == b'1':
26942713
self._supports_limits = True
2714+
elif name == b'copy-data' and data == b'1':
2715+
self._supports_copy_data = True
26952716

26962717
if version == 3:
26972718
# Check if the server has a buggy SYMLINK implementation
@@ -3090,6 +3111,26 @@ async def fsync(self, handle: bytes) -> None:
30903111
else:
30913112
raise SFTPOpUnsupported('fsync not supported')
30923113

3114+
async def copy_data(self, read_from_handle: bytes, read_from_offset: int,
3115+
read_from_length: int, write_to_handle: bytes,
3116+
write_to_offset: int) -> None:
3117+
"""Make an SFTP copy data request"""
3118+
3119+
if self._supports_copy_data:
3120+
self.logger.debug1('Sending copy-data from handle %s, '
3121+
'offset %d, length %d to handle %s, '
3122+
'offset %d', read_from_handle.hex(),
3123+
read_from_offset, read_from_length,
3124+
write_to_handle.hex(), write_to_offset)
3125+
3126+
await self._make_request(b'copy-data', String(read_from_handle),
3127+
UInt64(read_from_offset),
3128+
UInt64(read_from_length),
3129+
String(write_to_handle),
3130+
UInt64(write_to_offset))
3131+
else:
3132+
raise SFTPOpUnsupported('copy-data not supported')
3133+
30933134
def exit(self) -> None:
30943135
"""Handle a request to close the SFTP session"""
30953136

@@ -3142,6 +3183,15 @@ async def __aexit__(self, _exc_type: Optional[Type[BaseException]],
31423183
await self.close()
31433184
return False
31443185

3186+
@property
3187+
def handle(self) -> bytes:
3188+
"""Return handle or raise an error if clsoed"""
3189+
3190+
if self._handle is None:
3191+
raise ValueError('I/O operation on closed file')
3192+
3193+
return self._handle
3194+
31453195
async def _end(self) -> int:
31463196
"""Return the offset of the end of the file"""
31473197

@@ -4233,6 +4283,35 @@ async def mcopy(self, srcpaths: _SFTPPaths,
42334283
block_size, max_requests, progress_handler,
42344284
error_handler)
42354285

4286+
async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
4287+
src_offset: int = 0, src_length: int = 0,
4288+
dst_offset: int = 0) -> None:
4289+
"""Copy data between remote files
4290+
4291+
:param src:
4292+
The remote file object to read data from
4293+
:param dst:
4294+
The remote file object to write data to
4295+
:param src_offset: (optional)
4296+
The offset to begin reading data from
4297+
:param src_length: (optional)
4298+
The number of bytes to attempt to copy
4299+
:param dst_offset: (optional)
4300+
The offset to begin writing data to
4301+
:type src: :class:`SSHClientFile`
4302+
:type dst: :class:`SSHClientFile`
4303+
:type src_offset: `int`
4304+
:type src_length: `int`
4305+
:type dst_offset: `int`
4306+
4307+
:raises: :exc:`SFTPError` if the server doesn't support this
4308+
extension or returns an error
4309+
4310+
"""
4311+
4312+
await self._handler.copy_data(src.handle, src_offset, src_length,
4313+
dst.handle, dst_offset)
4314+
42364315
async def glob(self, patterns: _SFTPPaths,
42374316
error_handler: SFTPErrorHandler = None) -> \
42384317
Sequence[BytesOrStr]:
@@ -5583,7 +5662,8 @@ class SFTPServerHandler(SFTPHandler):
55835662
(b'hardlink@openssh.com', b'1'),
55845663
(b'fsync@openssh.com', b'1'),
55855664
(b'lsetstat@openssh.com', b'1'),
5586-
(b'limits@openssh.com', b'1')]
5665+
(b'limits@openssh.com', b'1'),
5666+
(b'copy-data', b'1')]
55875667

55885668
_attrib_extensions: List[bytes] = []
55895669

@@ -6437,6 +6517,55 @@ async def _process_limits(self, packet: SSHPacket) -> SFTPLimits:
64376517
return SFTPLimits(MAX_SFTP_PACKET_LEN, MAX_SFTP_READ_LEN,
64386518
MAX_SFTP_WRITE_LEN, nfiles)
64396519

6520+
async def _process_copy_data(self, packet: SSHPacket) -> None:
6521+
"""Process an incoming copy data request"""
6522+
6523+
read_from_handle = packet.get_string()
6524+
read_from_offset = packet.get_uint64()
6525+
read_from_length = packet.get_uint64()
6526+
write_to_handle = packet.get_string()
6527+
write_to_offset = packet.get_uint64()
6528+
packet.check_end()
6529+
6530+
self.logger.debug1('Received copy-data from handle %s, '
6531+
'offset %d, length %d to handle %s, '
6532+
'offset %d', read_from_handle.hex(),
6533+
read_from_offset, read_from_length,
6534+
write_to_handle.hex(), write_to_offset)
6535+
6536+
src = self._file_handles.get(read_from_handle)
6537+
dst = self._file_handles.get(write_to_handle)
6538+
6539+
if src and dst:
6540+
read_to_end = read_from_length == 0
6541+
6542+
while read_to_end or read_from_length:
6543+
if read_to_end:
6544+
size = _COPY_DATA_BLOCK_SIZE
6545+
else:
6546+
size = min(read_from_length, _COPY_DATA_BLOCK_SIZE)
6547+
6548+
data = self._server.read(src, read_from_offset, size)
6549+
6550+
if inspect.isawaitable(data):
6551+
data = await cast(Awaitable[bytes], data)
6552+
6553+
result = self._server.write(dst, write_to_offset, data)
6554+
6555+
if inspect.isawaitable(result):
6556+
await result
6557+
6558+
if len(data) < size:
6559+
break
6560+
6561+
read_from_offset += size
6562+
write_to_offset += size
6563+
6564+
if not read_to_end:
6565+
read_from_length -= size
6566+
else:
6567+
raise SFTPInvalidHandle('Invalid file handle')
6568+
64406569
_packet_handlers: Dict[Union[int, bytes], _SFTPPacketHandler] = {
64416570
FXP_OPEN: _process_open,
64426571
FXP_CLOSE: _process_close,
@@ -6465,7 +6594,8 @@ async def _process_limits(self, packet: SSHPacket) -> SFTPLimits:
64656594
b'hardlink@openssh.com': _process_openssh_link,
64666595
b'fsync@openssh.com': _process_fsync,
64676596
b'lsetstat@openssh.com': _process_lsetstat,
6468-
b'limits@openssh.com': _process_limits
6597+
b'limits@openssh.com': _process_limits,
6598+
b'copy-data': _process_copy_data
64696599
}
64706600

64716601
async def run(self) -> None:

docs/api.rst

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,16 +1077,17 @@ SFTP Support
10771077
.. autoattribute:: limits
10781078
======================================================================= =
10791079

1080-
===================== =
1080+
=========================== =
10811081
File transfer methods
1082-
===================== =
1082+
=========================== =
10831083
.. automethod:: get
10841084
.. automethod:: put
10851085
.. automethod:: copy
10861086
.. automethod:: mget
10871087
.. automethod:: mput
10881088
.. automethod:: mcopy
1089-
===================== =
1089+
.. automethod:: remote_copy
1090+
=========================== =
10901091

10911092
============================================================================================================================================================================================================================== =
10921093
File access methods

tests/test_sftp.py

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,26 @@ async def test_copy(self, sftp):
748748
finally:
749749
remove('src dst')
750750

751+
def test_copy_non_remote(self):
752+
"""Test copying without using remote_copy function"""
753+
754+
@sftp_test
755+
async def _test_copy_non_remote(self, sftp):
756+
"""Test copying without using remote_copy function"""
757+
758+
for src in ('src', b'src', Path('src')):
759+
with self.subTest(src=type(src)):
760+
try:
761+
self._create_file('src')
762+
await sftp.copy(src, 'dst')
763+
self._check_file('src', 'dst')
764+
finally:
765+
remove('src dst')
766+
767+
with patch('asyncssh.sftp.SFTPServerHandler._extensions', []):
768+
# pylint: disable=no-value-for-parameter
769+
_test_copy_non_remote(self)
770+
751771
@sftp_test
752772
async def test_copy_progress(self, sftp):
753773
"""Test copying a file over SFTP with progress reporting"""
@@ -769,7 +789,9 @@ def _report_progress(_srcpath, _dstpath, bytes_copied, _total_bytes):
769789
progress_handler=_report_progress)
770790
self._check_file('src', 'dst')
771791

772-
self.assertEqual(len(reports), (size // 8192) + 1)
792+
if method != 'copy':
793+
self.assertEqual(len(reports), (size // 8192) + 1)
794+
773795
self.assertEqual(reports[-1], size)
774796
finally:
775797
remove('src dst')
@@ -1130,6 +1152,37 @@ def err_handler(exc):
11301152
finally:
11311153
remove('src1 src2 dst')
11321154

1155+
@sftp_test
1156+
async def test_remote_copy_arguments(self, sftp):
1157+
"""Test remote copy arguments"""
1158+
1159+
try:
1160+
self._create_file('src', os.urandom(2*1024*1024))
1161+
1162+
async with sftp.open('src', 'rb') as src:
1163+
async with sftp.open('dst', 'wb') as dst:
1164+
await sftp.remote_copy(src, dst, 0, 1024*1024, 0)
1165+
await sftp.remote_copy(src, dst, 1024*1024, 0, 1024*1024)
1166+
1167+
self._check_file('src', 'dst')
1168+
finally:
1169+
remove('src dst')
1170+
1171+
@sftp_test
1172+
async def test_remote_copy_closed_file(self, sftp):
1173+
"""Test remote copy of a closed file"""
1174+
1175+
try:
1176+
self._create_file('file')
1177+
1178+
async with sftp.open('file', 'rb') as f:
1179+
await f.close()
1180+
1181+
with self.assertRaises(ValueError):
1182+
await sftp.remote_copy(f, f)
1183+
finally:
1184+
remove('file')
1185+
11331186
@sftp_test
11341187
async def test_glob(self, sftp):
11351188
"""Test a glob pattern match over SFTP"""
@@ -3173,6 +3226,9 @@ async def _return_invalid_handle(self, path, pflags, attrs):
31733226
with self.assertRaises(SFTPFailure):
31743227
await f.fsync()
31753228

3229+
with self.assertRaises(SFTPFailure):
3230+
await sftp.remote_copy(f, f)
3231+
31763232
with self.assertRaises(SFTPFailure):
31773233
await f.close()
31783234

@@ -4300,19 +4356,24 @@ async def start_server(cls):
43004356

43014357
return await cls.create_server(sftp_factory=_IOErrorSFTPServer)
43024358

4303-
@sftp_test
4304-
async def test_put_error(self, sftp):
4305-
"""Test error when putting a file to an SFTP server"""
4359+
def test_copy_error(self):
4360+
"""Test error when copying a file on an SFTP server"""
43064361

4307-
for method in ('get', 'put', 'copy'):
4308-
with self.subTest(method=method):
4309-
try:
4310-
self._create_file('src', 8*1024*1024*'\0')
4362+
@sftp_test
4363+
async def _test_copy_error(self, sftp):
4364+
"""Test error when copying a file on an SFTP server"""
43114365

4312-
with self.assertRaises(SFTPFailure):
4313-
await getattr(sftp, method)('src', 'dst')
4314-
finally:
4315-
remove('src dst')
4366+
try:
4367+
self._create_file('src', 8*1024*1024*'\0')
4368+
4369+
with self.assertRaises(SFTPFailure):
4370+
await sftp.copy('src', 'dst')
4371+
finally:
4372+
remove('src dst')
4373+
4374+
with patch('asyncssh.sftp.SFTPServerHandler._extensions', []):
4375+
# pylint: disable=no-value-for-parameter
4376+
_test_copy_error(self)
43164377

43174378
@sftp_test
43184379
async def test_read_error(self, sftp):

0 commit comments

Comments
 (0)