Skip to content

Commit c847023

Browse files
authored
Merge pull request #593 from adanaja/features-download-file
Improve download_file API with a fileobj parameter
2 parents fe3dc0c + 8fc060a commit c847023

File tree

2 files changed

+62
-19
lines changed

2 files changed

+62
-19
lines changed

src/e3/net/http.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,15 @@
2323

2424
if TYPE_CHECKING:
2525
from types import TracebackType
26-
from typing import Any, Deque
26+
from typing import Any, Deque, Protocol
2727
from collections.abc import Callable
2828
from requests.auth import AuthBase
2929
from requests.models import Response
3030

31+
class _Fileobj(Protocol):
32+
def write(self, __b: bytes) -> object:
33+
...
34+
3135

3236
logger = e3.log.getLogger("net.http")
3337

@@ -236,30 +240,41 @@ def request(
236240
def download_file(
237241
self,
238242
url: str,
239-
dest: str,
243+
dest: str | None = None,
240244
filename: str | None = None,
245+
fileobj: _Fileobj | None = None,
241246
validate: Callable[[str], bool] | None = None,
242247
exception_on_error: bool = False,
243248
**kwargs: Any,
244249
) -> str | None:
245250
"""Download a file.
246251
247252
:param url: the url to GET
248-
:param dest: local directory path for the downloaded file
249-
:param filename: the local path where to store this resource, by
250-
default uses the name provided in the ``Content-Disposition``
253+
:param dest: local directory path for the downloaded file. If
254+
None, a file object must be specified.
255+
:param filename: the local path whether to store this resource, by
256+
default use the name provided in the ``Content-Disposition``
251257
header.
258+
:param fileobj: if specified, the downloaded file is written to this
259+
file object instead of opening a file. The file object must be
260+
opened in binary mode.
252261
:param validate: function to call once the download is complete for
253262
detecting invalid / corrupted download. Takes the local path as
254-
parameter and returns a boolean.
263+
parameter and returns a boolean. The function is not called
264+
when a file object is specified.
255265
:param exception_on_error: if True raises an exception in case download
256266
fails instead of returning None.
257267
:param kwargs: additional parameters for the request
258-
:return: the name of the file or None if there is an error
268+
:return: the name of the file, or None if there is an error or a file
269+
object is passed and the filename could not be deduced from the
270+
request.
271+
:raises ValueError: if neither dest nor fileobj is provided
259272
"""
260273
# When using stream=True, Requests cannot release the connection back
261274
# to the pool unless all the data is consumed or Response.close called.
262275
# Force Response.close by wrapping the code with contextlib.closing
276+
if dest is None and fileobj is None:
277+
raise ValueError("no destination provided")
263278

264279
path = None
265280
try:
@@ -271,6 +286,24 @@ def download_file(
271286
if filename is None:
272287
if "content-disposition" in response.headers:
273288
filename = get_filename(response.headers["content-disposition"])
289+
290+
expected_size = content_length // self.CHUNK_SIZE
291+
292+
chunks = e3.log.progress_bar(
293+
response.iter_content(self.CHUNK_SIZE), total=expected_size
294+
)
295+
296+
if fileobj is not None:
297+
# Write to file object if provided
298+
logger.info("downloading %s size=%s", path, content_length)
299+
for chunk in chunks:
300+
fileobj.write(chunk)
301+
return filename
302+
else:
303+
# Dest can't be None here according to condition at the top
304+
assert dest is not None
305+
306+
# Fallback to local file otherwise
274307
if filename is None:
275308
# Generate a temporary name
276309
tmpf = tempfile.NamedTemporaryFile(
@@ -279,19 +312,18 @@ def download_file(
279312
tmpf.close()
280313
filename = tmpf.name
281314

282-
path = os.path.join(dest, filename)
283-
logger.info("downloading %s size=%s", path, content_length)
315+
path = os.path.join(dest, filename)
284316

285-
expected_size = content_length // self.CHUNK_SIZE
286-
with open(path, "wb") as fd:
287-
for chunk in e3.log.progress_bar(
288-
response.iter_content(self.CHUNK_SIZE), total=expected_size
289-
):
290-
fd.write(chunk)
291-
if validate is None or validate(path):
292-
return path
293-
else:
294-
rm(path)
317+
logger.info("downloading %s size=%s", path, content_length)
318+
319+
with open(path, "wb") as fd:
320+
for chunk in chunks:
321+
fd.write(chunk)
322+
323+
if validate is None or validate(path):
324+
return path
325+
else:
326+
rm(path)
295327
except (requests.exceptions.RequestException, HTTPError) as e:
296328
# An error (timeout?) occurred while downloading the file
297329
logger.warning("download failed")

tests/tests_e3/net/http_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import threading
55
import time
6+
from io import BytesIO
67

78
import requests_toolbelt.multipart
89
from e3.net.http import HTTPSession, HTTPError
@@ -146,6 +147,16 @@ def func(server, base_url):
146147

147148
run_server(ContentDispoHandler, func)
148149

150+
def test_content_dispo_fileobj(self, socket_enabled):
151+
def func(server, base_url):
152+
with HTTPSession() as session:
153+
fo = BytesIO()
154+
result = session.download_file(base_url + "dummy", fileobj=fo)
155+
assert fo.getvalue() == b"Dummy!"
156+
assert os.path.basename(result) == "dummy.txt"
157+
158+
run_server(ContentDispoHandler, func)
159+
149160
def test_content_validation(self, socket_enabled):
150161
def validate(path):
151162
return False

0 commit comments

Comments
 (0)