Skip to content

Commit 4cf2a39

Browse files
committed
Add support for str, bytes, or PurePath pathnames in remote_copy
This commit adds support for passing in pathnames of type str, bytes, or PurePath to the new remote_copy() function, in addition to passing in already-open SFTPClientFile objects.
1 parent 6ecb91e commit 4cf2a39

File tree

1 file changed

+33
-23
lines changed

1 file changed

+33
-23
lines changed

asyncssh/sftp.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@
137137
_SFTPPatList = List[Union[bytes, List[bytes]]]
138138
_SFTPStatFunc = Callable[[_SFTPPath], Awaitable['SFTPAttrs']]
139139

140+
_SFTPClientFileOrPath = Union['SFTPClientFile', _SFTPPath]
141+
140142
_SFTPNames = Tuple[Sequence['SFTPName'], bool]
141143
_SFTPOSAttrs = Union[os.stat_result, 'SFTPAttrs']
142144
_SFTPOSVFSAttrs = Union[os.statvfs_result, 'SFTPVFSAttrs']
@@ -799,6 +801,22 @@ async def run_task(self, offset: int, size: int) -> Tuple[int, int]:
799801
async def run(self) -> None:
800802
"""Perform parallel file copy"""
801803

804+
if self._srcfs == self._dstfs and \
805+
isinstance(self._srcfs, SFTPClient):
806+
try:
807+
await self._srcfs.remote_copy(self._srcpath, self._dstpath)
808+
except SFTPOpUnsupported:
809+
pass
810+
else:
811+
self._bytes_copied = self._total_bytes
812+
813+
if self._progress_handler:
814+
self._progress_handler(self._srcpath, self._dstpath,
815+
self._bytes_copied,
816+
self._total_bytes)
817+
818+
return
819+
802820
try:
803821
self._src = await self._srcfs.open(self._srcpath, 'rb',
804822
block_size=0)
@@ -808,24 +826,6 @@ async def run(self) -> None:
808826
if self._progress_handler and self._total_bytes == 0:
809827
self._progress_handler(self._srcpath, self._dstpath, 0, 0)
810828

811-
if self._srcfs == self._dstfs and \
812-
isinstance(self._srcfs, SFTPClient):
813-
try:
814-
await self._srcfs.remote_copy(
815-
cast(SFTPClientFile, self._src),
816-
cast(SFTPClientFile, self._dst))
817-
except SFTPOpUnsupported:
818-
pass
819-
else:
820-
self._bytes_copied = self._total_bytes
821-
822-
if self._progress_handler:
823-
self._progress_handler(self._srcpath, self._dstpath,
824-
self._bytes_copied,
825-
self._total_bytes)
826-
827-
return
828-
829829
async for _, datalen in self.iter():
830830
if datalen:
831831
self._bytes_copied += datalen
@@ -4283,9 +4283,9 @@ async def mcopy(self, srcpaths: _SFTPPaths,
42834283
block_size, max_requests, progress_handler,
42844284
error_handler)
42854285

4286-
async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
4287-
src_offset: int = 0, src_length: int = 0,
4288-
dst_offset: int = 0) -> None:
4286+
async def remote_copy(self, src: _SFTPClientFileOrPath,
4287+
dst: _SFTPClientFileOrPath, src_offset: int = 0,
4288+
src_length: int = 0, dst_offset: int = 0) -> None:
42894289
"""Copy data between remote files
42904290
42914291
:param src:
@@ -4298,8 +4298,12 @@ async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
42984298
The number of bytes to attempt to copy
42994299
:param dst_offset: (optional)
43004300
The offset to begin writing data to
4301-
:type src: :class:`SSHClientFile`
4302-
:type dst: :class:`SSHClientFile`
4301+
:type src:
4302+
:class:`SSHClientFile`, :class:`PurePath <pathlib.PurePath>`,
4303+
`str`, or `bytes`
4304+
:type dst:
4305+
:class:`SSHClientFile`, :class:`PurePath <pathlib.PurePath>`,
4306+
`str`, or `bytes`
43034307
:type src_offset: `int`
43044308
:type src_length: `int`
43054309
:type dst_offset: `int`
@@ -4309,6 +4313,12 @@ async def remote_copy(self, src: SFTPClientFile, dst: SFTPClientFile,
43094313
43104314
"""
43114315

4316+
if isinstance(src, (bytes, str, PurePath)):
4317+
src = await self.open(src, 'rb', block_size=0)
4318+
4319+
if isinstance(dst, (bytes, str, PurePath)):
4320+
dst = await self.open(dst, 'wb', block_size=0)
4321+
43124322
await self._handler.copy_data(src.handle, src_offset, src_length,
43134323
dst.handle, dst_offset)
43144324

0 commit comments

Comments
 (0)