Skip to content

Commit a6d6fb1

Browse files
authored
Merge pull request #595 from adanaja/features-archive
Improve the archive API with fileobj parameters
2 parents c847023 + 0f43ccf commit a6d6fb1

File tree

2 files changed

+135
-15
lines changed

2 files changed

+135
-15
lines changed

src/e3/archive.py

Lines changed: 55 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
import tempfile
1111
import zipfile
1212
from contextlib import closing
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, cast
1414

1515
import e3
1616
import e3.error
@@ -20,7 +20,7 @@
2020

2121

2222
if TYPE_CHECKING:
23-
from typing import Literal, Text, Union
23+
from typing import Literal, Text, Union, IO, Any
2424
from collections.abc import Callable, Sequence
2525
from os import PathLike
2626
from e3.mypy import assert_never
@@ -146,6 +146,7 @@ def check_type(
146146
def unpack_archive(
147147
filename: str,
148148
dest: str,
149+
fileobj: IO[bytes] | None = None,
149150
selected_files: Sequence[str] | None = None,
150151
remove_root_dir: RemoveRootDirType = False,
151152
unpack_cmd: Callable[..., None] | None = None,
@@ -159,6 +160,10 @@ def unpack_archive(
159160
160161
:param filename: archive to unpack
161162
:param dest: destination directory (should exist)
163+
:param fileobj: if specified, the archive is read from this file object
164+
instead of opening a file. The file object must be opened in binary
165+
mode. In this case filename is the name of the archive contained
166+
in the file object.
162167
:param selected_files: list of files to unpack (partial extraction). If
163168
None all files are unpacked
164169
:param remove_root_dir: if True then the root dir of the archive is
@@ -192,7 +197,7 @@ def unpack_archive(
192197
logger.debug("unpack %s in %s", filename, dest)
193198
# First do some checks such as archive existence or destination directory
194199
# existence.
195-
if not os.path.isfile(filename):
200+
if fileobj is None and not os.path.isfile(filename):
196201
raise ArchiveError(origin="unpack_archive", message=f"cannot find {filename}")
197202

198203
if not os.path.isdir(dest):
@@ -205,14 +210,19 @@ def unpack_archive(
205210

206211
# We need to resolve to an absolute path as the extraction related
207212
# processes will be run in the destination directory
208-
filename = os.path.abspath(filename)
213+
if fileobj is None:
214+
filename = os.path.abspath(filename)
209215

210216
if unpack_cmd is not None:
211217
# Use user defined unpack command
212-
if not selected_files:
213-
return unpack_cmd(filename, dest)
214-
else:
215-
return unpack_cmd(filename, dest, selected_files=selected_files)
218+
kwargs: dict[str, Any] = {}
219+
if selected_files:
220+
kwargs["selected_files"] = selected_files
221+
222+
if fileobj is not None:
223+
kwargs["fileobj"] = fileobj
224+
225+
return unpack_cmd(filename, dest, **kwargs)
216226

217227
ext = check_type(filename, force_extension=force_extension)
218228

@@ -237,7 +247,13 @@ def unpack_archive(
237247
elif ext.endswith("xz"):
238248
mode += "xz"
239249
# Extract tar files
240-
with closing(tarfile.open(filename, mode=mode)) as fd:
250+
with closing(
251+
tarfile.open(
252+
filename if fileobj is None else None,
253+
fileobj=fileobj,
254+
mode=mode,
255+
)
256+
) as fd:
241257
check_selected = set(selected_files)
242258

243259
def is_match(name: str, files: Sequence[str]) -> bool:
@@ -291,7 +307,9 @@ def is_match(name: str, files: Sequence[str]) -> bool:
291307

292308
elif ext == "zip":
293309
try:
294-
with closing(E3ZipFile(filename, mode="r")) as zip_fd:
310+
with closing(
311+
E3ZipFile(fileobj if fileobj is not None else filename, mode="r")
312+
) as zip_fd:
295313
zip_fd.extractall(
296314
tmp_dest, selected_files if selected_files else None
297315
)
@@ -358,7 +376,8 @@ def is_match(name: str, files: Sequence[str]) -> bool:
358376
def create_archive(
359377
filename: str,
360378
from_dir: str,
361-
dest: str,
379+
dest: str | None = None,
380+
fileobj: IO[bytes] | None = None,
362381
force_extension: str | None = None,
363382
from_dir_rename: str | None = None,
364383
no_root_dir: bool = False,
@@ -372,26 +391,45 @@ def create_archive(
372391
373392
:param filename: archive to create
374393
:param from_dir: directory to pack (full path)
375-
:param dest: destination directory (should exist)
394+
:param dest: destination directory (should exist). If not specified,
395+
the archive is written to the file object passed with fileobj.
396+
:param fileobj: if specified, the archive is written to this file object
397+
instead of opening a file. The file object must be opened in binary
398+
mode. In this case filename is the name of the archive contained
399+
in the file object.
376400
:param force_extension: specify the archive extension if not in the
377401
filename. If filename has no extension and force_extension is None
378402
create_archive will fail.
379403
:param from_dir_rename: name of root directory in the archive.
380404
:param no_root_dir: create archive without the root dir (zip only)
381405
406+
:raise ValueError: neither dest nor fileobj is provided
382407
:raise ArchiveError: if an error occurs
383408
"""
409+
if dest is None and fileobj is None:
410+
raise ValueError("no destination provided")
411+
384412
# Check extension
385413
from_dir = from_dir.rstrip("/")
386-
filepath = os.path.abspath(os.path.join(dest, filename))
414+
415+
# If fileobj is None, dest is not None
416+
filepath = (
417+
os.path.abspath(os.path.join(cast(str, dest), filename))
418+
if fileobj is None
419+
else None
420+
)
387421

388422
ext = check_type(filename, force_extension=force_extension)
389423

390424
if from_dir_rename is None:
391425
from_dir_rename = os.path.basename(from_dir)
392426

393427
if ext == "zip":
394-
zip_archive = zipfile.ZipFile(filepath, "w", zipfile.ZIP_DEFLATED)
428+
zip_archive = zipfile.ZipFile(
429+
cast(str, filepath) if fileobj is None else fileobj,
430+
"w",
431+
zipfile.ZIP_DEFLATED,
432+
)
395433
for root, _, files in os.walk(from_dir):
396434
relative_root = os.path.relpath(
397435
os.path.abspath(root), os.path.abspath(from_dir)
@@ -413,5 +451,7 @@ def create_archive(
413451
tar_format = "w:xz"
414452
else:
415453
assert_never()
416-
with closing(tarfile.open(filepath, tar_format)) as tar_archive:
454+
with closing(
455+
tarfile.open(filepath, fileobj=fileobj, mode=tar_format)
456+
) as tar_archive:
417457
tar_archive.add(name=from_dir, arcname=from_dir_rename, recursive=True)

tests/tests_e3/archive/main_test.py

Lines changed: 80 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import tempfile
4+
import io
45

56
import e3.archive
67
import e3.fs
@@ -91,6 +92,43 @@ def test_unpack(ext):
9192
e3.fs.rm(dest, True)
9293

9394

95+
@pytest.mark.parametrize("ext", (".tar.gz", ".zip"))
def test_unpack_fileobj(ext):
    """Round-trip an archive through an in-memory file object.

    The directory holding this test is packed into a BytesIO buffer, then
    only this test file is extracted from that buffer (with
    remove_root_dir=True) and its presence in the destination is checked.
    """
    this_file = os.path.basename(__file__)
    packed_dir = os.path.dirname(__file__)
    root_name = os.path.basename(packed_dir)

    dest = "dest"
    e3.fs.mkdir(dest)

    unpack_dir = os.path.join(dest, "dest2")
    archive_name = f"e3-core{ext}"

    try:
        buffer = io.BytesIO()
        e3.archive.create_archive(
            filename=archive_name,
            from_dir=os.path.abspath(packed_dir),
            fileobj=buffer,
        )

        # Rewind so that the archive is read back from its first byte
        buffer.seek(0)
        e3.fs.mkdir(unpack_dir)

        selected = e3.os.fs.unixpath(os.path.join(root_name, this_file))
        e3.archive.unpack_archive(
            filename=archive_name,
            dest=unpack_dir,
            fileobj=buffer,
            selected_files=(selected,),
            remove_root_dir=True,
        )

        assert os.path.exists(os.path.join(unpack_dir, this_file))
    finally:
        e3.fs.rm(dest, True)
130+
131+
94132
def test_unsupported():
95133
"""Test unsupported archive format."""
96134
with pytest.raises(e3.archive.ArchiveError) as err:
@@ -147,6 +185,48 @@ def custom_unpack(filename, dest, selected_files):
147185
assert t.kwargs["s"] == ["bar"]
148186

149187

188+
def test_unpack_cmd_fileobj():
    """Test custom unpack_cmd with fileobj.

    An archive is created in an in-memory buffer, then unpack_archive is
    called with a user-defined unpack command. The command must receive the
    archive name, the destination and the very same file object.
    """
    dir_to_pack = os.path.dirname(__file__)
    archive_name = "e3-core.tar"

    fo = io.BytesIO()
    e3.archive.create_archive(
        filename=archive_name,
        from_dir=os.path.abspath(dir_to_pack),
        fileobj=fo,
    )

    all_dest = "all_dest"
    e3.fs.mkdir(all_dest)

    # Arguments received by the custom unpack function, recorded so that
    # they can be checked once unpack_archive returns
    received = {}

    def custom_unpack(filename, dest, fileobj):
        received.update(f=filename, d=dest, fo=fileobj)

    try:
        # Rewind so that the unpack command gets a buffer positioned at the
        # start of the archive
        fo.seek(0)
        e3.archive.unpack_archive(
            filename=archive_name,
            fileobj=fo,
            dest=all_dest,
            unpack_cmd=custom_unpack,
        )

        assert received["f"] == archive_name
        assert received["d"] == all_dest
        # The file object must be forwarded untouched, not copied or wrapped
        assert received["fo"] is fo
    finally:
        # Same cleanup policy as the other fileobj test: leave nothing behind
        e3.fs.rm(all_dest, True)
228+
229+
150230
def test_unpack_files():
151231
"""Test unpack_archive with selected_files."""
152232
e3.fs.mkdir("d")

0 commit comments

Comments
 (0)