Skip to content

Commit 8fc060a

Browse files
committed
Improve download_file API with a fileobj parameter
HTTPSession.download_file provides an easy way to download a file from an URL and save it to disk. However it may be sometimes interesting to keep the file in memory, if you don't want to, or don't need to save it temporarily to disk. This change makes download_file more look like the API of tarfile.open, that allows you to either pass the name of a file, or a file object to read from. Here the typing of the file object is done with a Protocol, which is the convention from typeshed (for the record python/typing#829)
1 parent fe3dc0c commit 8fc060a

File tree

2 files changed

+62
-19
lines changed

2 files changed

+62
-19
lines changed

src/e3/net/http.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@
2323

2424
if TYPE_CHECKING:
2525
from types import TracebackType
26-
from typing import Any, Deque
26+
from typing import Any, Deque, Protocol
2727
from collections.abc import Callable
2828
from requests.auth import AuthBase
2929
from requests.models import Response
3030

31+
class _Fileobj(Protocol):
32+
def write(self, __b: bytes) -> object:
33+
...
34+
3135

3236
logger = e3.log.getLogger("net.http")
3337

@@ -236,30 +240,41 @@ def request(
236240
def download_file(
237241
self,
238242
url: str,
239-
dest: str,
243+
dest: str | None = None,
240244
filename: str | None = None,
245+
fileobj: _Fileobj | None = None,
241246
validate: Callable[[str], bool] | None = None,
242247
exception_on_error: bool = False,
243248
**kwargs: Any,
244249
) -> str | None:
245250
"""Download a file.
246251
247252
:param url: the url to GET
248-
:param dest: local directory path for the downloaded file
249-
:param filename: the local path where to store this resource, by
250-
default uses the name provided in the ``Content-Disposition``
253+
:param dest: local directory path for the downloaded file. If
254+
None, a file object must be specified.
255+
:param filename: the local path whether to store this resource, by
256+
default use the name provided in the ``Content-Disposition``
251257
header.
258+
:param fileobj: if specified, the downloaded file is written to this
259+
file object instead of opening a file. The file object must be
260+
opened in binary mode.
252261
:param validate: function to call once the download is complete for
253262
detecting invalid / corrupted download. Takes the local path as
254-
parameter and returns a boolean.
263+
parameter and returns a boolean. The function is not called
264+
when a file object is specified.
255265
:param exception_on_error: if True raises an exception in case download
256266
fails instead of returning None.
257267
:param kwargs: additional parameters for the request
258-
:return: the name of the file or None if there is an error
268+
:return: the name of the file, or None if there is an error or a file
269+
object is passed and the filename could not be deduced from the
270+
request.
271+
:raises ValueError: if neither dest nor fileobj is provided
259272
"""
260273
# When using stream=True, Requests cannot release the connection back
261274
# to the pool unless all the data is consumed or Response.close called.
262275
# Force Response.close by wrapping the code with contextlib.closing
276+
if dest is None and fileobj is None:
277+
raise ValueError("no destination provided")
263278

264279
path = None
265280
try:
@@ -271,6 +286,24 @@ def download_file(
271286
if filename is None:
272287
if "content-disposition" in response.headers:
273288
filename = get_filename(response.headers["content-disposition"])
289+
290+
expected_size = content_length // self.CHUNK_SIZE
291+
292+
chunks = e3.log.progress_bar(
293+
response.iter_content(self.CHUNK_SIZE), total=expected_size
294+
)
295+
296+
if fileobj is not None:
297+
# Write to file object if provided
298+
logger.info("downloading %s size=%s", path, content_length)
299+
for chunk in chunks:
300+
fileobj.write(chunk)
301+
return filename
302+
else:
303+
# Dest can't be None here according to condition at the top
304+
assert dest is not None
305+
306+
# Fallback to local file otherwise
274307
if filename is None:
275308
# Generate a temporary name
276309
tmpf = tempfile.NamedTemporaryFile(
@@ -279,19 +312,18 @@ def download_file(
279312
tmpf.close()
280313
filename = tmpf.name
281314

282-
path = os.path.join(dest, filename)
283-
logger.info("downloading %s size=%s", path, content_length)
315+
path = os.path.join(dest, filename)
284316

285-
expected_size = content_length // self.CHUNK_SIZE
286-
with open(path, "wb") as fd:
287-
for chunk in e3.log.progress_bar(
288-
response.iter_content(self.CHUNK_SIZE), total=expected_size
289-
):
290-
fd.write(chunk)
291-
if validate is None or validate(path):
292-
return path
293-
else:
294-
rm(path)
317+
logger.info("downloading %s size=%s", path, content_length)
318+
319+
with open(path, "wb") as fd:
320+
for chunk in chunks:
321+
fd.write(chunk)
322+
323+
if validate is None or validate(path):
324+
return path
325+
else:
326+
rm(path)
295327
except (requests.exceptions.RequestException, HTTPError) as e:
296328
# An error (timeout?) occurred while downloading the file
297329
logger.warning("download failed")

tests/tests_e3/net/http_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import threading
55
import time
6+
from io import BytesIO
67

78
import requests_toolbelt.multipart
89
from e3.net.http import HTTPSession, HTTPError
@@ -146,6 +147,16 @@ def func(server, base_url):
146147

147148
run_server(ContentDispoHandler, func)
148149

150+
def test_content_dispo_fileobj(self, socket_enabled):
151+
def func(server, base_url):
152+
with HTTPSession() as session:
153+
fo = BytesIO()
154+
result = session.download_file(base_url + "dummy", fileobj=fo)
155+
assert fo.getvalue() == b"Dummy!"
156+
assert os.path.basename(result) == "dummy.txt"
157+
158+
run_server(ContentDispoHandler, func)
159+
149160
def test_content_validation(self, socket_enabled):
150161
def validate(path):
151162
return False

0 commit comments

Comments
 (0)