Skip to content

Commit 0f43ccf

Browse files
committed
Improve the archive API with fileobj parameters
The archive API only allows unpacking an archive from file, and writing an archive to file. However the zipfile, and tarfile modules this API is based on, both accept fileobj parameters to read from, and write to memory. This change adds a fileobj parameter to create_archive, and unpack_archive, in order to allow creating archives to, and unpacking archives from memory.
1 parent c847023 commit 0f43ccf

File tree

2 files changed

+135
-15
lines changed

2 files changed

+135
-15
lines changed

src/e3/archive.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import tempfile
1111
import zipfile
1212
from contextlib import closing
13-
from typing import TYPE_CHECKING
13+
from typing import TYPE_CHECKING, cast
1414

1515
import e3
1616
import e3.error
@@ -20,7 +20,7 @@
2020

2121

2222
if TYPE_CHECKING:
23-
from typing import Literal, Text, Union
23+
from typing import Literal, Text, Union, IO, Any
2424
from collections.abc import Callable, Sequence
2525
from os import PathLike
2626
from e3.mypy import assert_never
@@ -146,6 +146,7 @@ def check_type(
146146
def unpack_archive(
147147
filename: str,
148148
dest: str,
149+
fileobj: IO[bytes] | None = None,
149150
selected_files: Sequence[str] | None = None,
150151
remove_root_dir: RemoveRootDirType = False,
151152
unpack_cmd: Callable[..., None] | None = None,
@@ -159,6 +160,10 @@ def unpack_archive(
159160
160161
:param filename: archive to unpack
161162
:param dest: destination directory (should exist)
163+
:param fileobj: if specified, the archive is read from this file object
164+
instead of opening a file. The file object must be opened in binary
165+
mode. In this case filename is the name of the archive contained
166+
in the file object.
162167
:param selected_files: list of files to unpack (partial extraction). If
163168
None all files are unpacked
164169
:param remove_root_dir: if True then the root dir of the archive is
@@ -192,7 +197,7 @@ def unpack_archive(
192197
logger.debug("unpack %s in %s", filename, dest)
193198
# First do some checks such as archive existence or destination directory
194199
# existence.
195-
if not os.path.isfile(filename):
200+
if fileobj is None and not os.path.isfile(filename):
196201
raise ArchiveError(origin="unpack_archive", message=f"cannot find {filename}")
197202

198203
if not os.path.isdir(dest):
@@ -205,14 +210,19 @@ def unpack_archive(
205210

206211
# We need to resolve to an absolute path as the extraction related
207212
# processes will be run in the destination directory
208-
filename = os.path.abspath(filename)
213+
if fileobj is None:
214+
filename = os.path.abspath(filename)
209215

210216
if unpack_cmd is not None:
211217
# Use user defined unpack command
212-
if not selected_files:
213-
return unpack_cmd(filename, dest)
214-
else:
215-
return unpack_cmd(filename, dest, selected_files=selected_files)
218+
kwargs: dict[str, Any] = {}
219+
if selected_files:
220+
kwargs["selected_files"] = selected_files
221+
222+
if fileobj is not None:
223+
kwargs["fileobj"] = fileobj
224+
225+
return unpack_cmd(filename, dest, **kwargs)
216226

217227
ext = check_type(filename, force_extension=force_extension)
218228

@@ -237,7 +247,13 @@ def unpack_archive(
237247
elif ext.endswith("xz"):
238248
mode += "xz"
239249
# Extract tar files
240-
with closing(tarfile.open(filename, mode=mode)) as fd:
250+
with closing(
251+
tarfile.open(
252+
filename if fileobj is None else None,
253+
fileobj=fileobj,
254+
mode=mode,
255+
)
256+
) as fd:
241257
check_selected = set(selected_files)
242258

243259
def is_match(name: str, files: Sequence[str]) -> bool:
@@ -291,7 +307,9 @@ def is_match(name: str, files: Sequence[str]) -> bool:
291307

292308
elif ext == "zip":
293309
try:
294-
with closing(E3ZipFile(filename, mode="r")) as zip_fd:
310+
with closing(
311+
E3ZipFile(fileobj if fileobj is not None else filename, mode="r")
312+
) as zip_fd:
295313
zip_fd.extractall(
296314
tmp_dest, selected_files if selected_files else None
297315
)
@@ -358,7 +376,8 @@ def is_match(name: str, files: Sequence[str]) -> bool:
358376
def create_archive(
359377
filename: str,
360378
from_dir: str,
361-
dest: str,
379+
dest: str | None = None,
380+
fileobj: IO[bytes] | None = None,
362381
force_extension: str | None = None,
363382
from_dir_rename: str | None = None,
364383
no_root_dir: bool = False,
@@ -372,26 +391,45 @@ def create_archive(
372391
373392
:param filename: archive to create
374393
:param from_dir: directory to pack (full path)
375-
:param dest: destination directory (should exist)
394+
:param dest: destination directory (should exist). If not specified,
395+
the archive is written to the file object passed with fileobj.
396+
:param fileobj: if specified, the archive is written to this file object
397+
instead of opening a file. The file object must be opened in binary
398+
mode. In this case filename is the name of the archive contained
399+
in the file object.
376400
:param force_extension: specify the archive extension if not in the
377401
filename. If filename has no extension and force_extension is None
378402
create_archive will fail.
379403
:param from_dir_rename: name of root directory in the archive.
380404
:param no_root_dir: create archive without the root dir (zip only)
381405
406+
:raise ValueError: neither dest nor fileobj is provided
382407
:raise ArchiveError: if an error occurs
383408
"""
409+
if dest is None and fileobj is None:
410+
raise ValueError("no destination provided")
411+
384412
# Check extension
385413
from_dir = from_dir.rstrip("/")
386-
filepath = os.path.abspath(os.path.join(dest, filename))
414+
415+
# If fileobj is None, dest is not None
416+
filepath = (
417+
os.path.abspath(os.path.join(cast(str, dest), filename))
418+
if fileobj is None
419+
else None
420+
)
387421

388422
ext = check_type(filename, force_extension=force_extension)
389423

390424
if from_dir_rename is None:
391425
from_dir_rename = os.path.basename(from_dir)
392426

393427
if ext == "zip":
394-
zip_archive = zipfile.ZipFile(filepath, "w", zipfile.ZIP_DEFLATED)
428+
zip_archive = zipfile.ZipFile(
429+
cast(str, filepath) if fileobj is None else fileobj,
430+
"w",
431+
zipfile.ZIP_DEFLATED,
432+
)
395433
for root, _, files in os.walk(from_dir):
396434
relative_root = os.path.relpath(
397435
os.path.abspath(root), os.path.abspath(from_dir)
@@ -413,5 +451,7 @@ def create_archive(
413451
tar_format = "w:xz"
414452
else:
415453
assert_never()
416-
with closing(tarfile.open(filepath, tar_format)) as tar_archive:
454+
with closing(
455+
tarfile.open(filepath, fileobj=fileobj, mode=tar_format)
456+
) as tar_archive:
417457
tar_archive.add(name=from_dir, arcname=from_dir_rename, recursive=True)

tests/tests_e3/archive/main_test.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import tempfile
4+
import io
45

56
import e3.archive
67
import e3.fs
@@ -91,6 +92,43 @@ def test_unpack(ext):
9192
e3.fs.rm(dest, True)
9293

9394

95+
@pytest.mark.parametrize("ext", (".tar.gz", ".zip"))
96+
def test_unpack_fileobj(ext):
97+
dir_to_pack = os.path.dirname(__file__)
98+
99+
test_dir = os.path.basename(dir_to_pack)
100+
101+
dest = "dest"
102+
e3.fs.mkdir(dest)
103+
104+
archive_name = "e3-core" + ext
105+
106+
try:
107+
fo = io.BytesIO()
108+
e3.archive.create_archive(
109+
filename=archive_name,
110+
from_dir=os.path.abspath(dir_to_pack),
111+
fileobj=fo,
112+
)
113+
114+
fo.seek(0)
115+
e3.fs.mkdir(os.path.join(dest, "dest2"))
116+
e3.archive.unpack_archive(
117+
filename=archive_name,
118+
dest=os.path.join(dest, "dest2"),
119+
fileobj=fo,
120+
selected_files=(
121+
e3.os.fs.unixpath(os.path.join(test_dir, os.path.basename(__file__))),
122+
),
123+
remove_root_dir=True,
124+
)
125+
126+
assert os.path.exists(os.path.join(dest, "dest2", os.path.basename(__file__)))
127+
128+
finally:
129+
e3.fs.rm(dest, True)
130+
131+
94132
def test_unsupported():
95133
"""Test unsupported archive format."""
96134
with pytest.raises(e3.archive.ArchiveError) as err:
@@ -147,6 +185,48 @@ def custom_unpack(filename, dest, selected_files):
147185
assert t.kwargs["s"] == ["bar"]
148186

149187

188+
def test_unpack_cmd_fileobj():
189+
"""Test custom unpack_cmd with fileobj."""
190+
dir_to_pack = os.path.dirname(__file__)
191+
192+
dest = "dest"
193+
e3.fs.mkdir(dest)
194+
195+
archive_name = "e3-core.tar"
196+
197+
fo = io.BytesIO()
198+
e3.archive.create_archive(
199+
filename=archive_name,
200+
from_dir=os.path.abspath(dir_to_pack),
201+
fileobj=fo,
202+
)
203+
204+
all_dest = "all_dest"
205+
e3.fs.mkdir(all_dest)
206+
207+
# Use a custom unpack function and verify that it is called with
208+
# the expected arguments
209+
class TestResult:
210+
def store_result(self, **kwargs):
211+
self.kwargs = kwargs
212+
213+
t = TestResult()
214+
215+
def custom_unpack(filename, dest, fileobj):
216+
t.store_result(f=filename, d=dest, fo=fileobj)
217+
218+
fo.seek(0)
219+
e3.archive.unpack_archive(
220+
filename=archive_name,
221+
fileobj=fo,
222+
dest=all_dest,
223+
unpack_cmd=custom_unpack,
224+
)
225+
assert t.kwargs["f"] == archive_name
226+
assert t.kwargs["d"] == all_dest
227+
assert t.kwargs["fo"] == fo
228+
229+
150230
def test_unpack_files():
151231
"""Test unpack_archive with selected_files."""
152232
e3.fs.mkdir("d")

0 commit comments

Comments
 (0)