23
23
24
24
if TYPE_CHECKING :
25
25
from types import TracebackType
26
- from typing import Any , Deque
26
+ from typing import Any , Deque , Protocol
27
27
from collections .abc import Callable
28
28
from requests .auth import AuthBase
29
29
from requests .models import Response
30
30
31
+ class _Fileobj (Protocol ):
32
+ def write (self , __b : bytes ) -> object :
33
+ ...
34
+
31
35
32
36
logger = e3 .log .getLogger ("net.http" )
33
37
@@ -236,30 +240,41 @@ def request(
236
240
def download_file (
237
241
self ,
238
242
url : str ,
239
- dest : str ,
243
+ dest : str | None = None ,
240
244
filename : str | None = None ,
245
+ fileobj : _Fileobj | None = None ,
241
246
validate : Callable [[str ], bool ] | None = None ,
242
247
exception_on_error : bool = False ,
243
248
** kwargs : Any ,
244
249
) -> str | None :
245
250
"""Download a file.
246
251
247
252
:param url: the url to GET
248
- :param dest: local directory path for the downloaded file
249
- :param filename: the local path where to store this resource, by
250
- default uses the name provided in the ``Content-Disposition``
253
+ :param dest: local directory path for the downloaded file. If
254
+ None, a file object must be specified.
255
+ :param filename: the local path whether to store this resource, by
256
+ default use the name provided in the ``Content-Disposition``
251
257
header.
258
+ :param fileobj: if specified, the downloaded file is written to this
259
+ file object instead of opening a file. The file object must be
260
+ opened in binary mode.
252
261
:param validate: function to call once the download is complete for
253
262
detecting invalid / corrupted download. Takes the local path as
254
- parameter and returns a boolean.
263
+ parameter and returns a boolean. The function is not called
264
+ when a file object is specified.
255
265
:param exception_on_error: if True raises an exception in case download
256
266
fails instead of returning None.
257
267
:param kwargs: additional parameters for the request
258
- :return: the name of the file or None if there is an error
268
+ :return: the name of the file, or None if there is an error or a file
269
+ object is passed and the filename could not be deduced from the
270
+ request.
271
+ :raises ValueError: if neither dest nor fileobj is provided
259
272
"""
260
273
# When using stream=True, Requests cannot release the connection back
261
274
# to the pool unless all the data is consumed or Response.close called.
262
275
# Force Response.close by wrapping the code with contextlib.closing
276
+ if dest is None and fileobj is None :
277
+ raise ValueError ("no destination provided" )
263
278
264
279
path = None
265
280
try :
@@ -271,6 +286,24 @@ def download_file(
271
286
if filename is None :
272
287
if "content-disposition" in response .headers :
273
288
filename = get_filename (response .headers ["content-disposition" ])
289
+
290
+ expected_size = content_length // self .CHUNK_SIZE
291
+
292
+ chunks = e3 .log .progress_bar (
293
+ response .iter_content (self .CHUNK_SIZE ), total = expected_size
294
+ )
295
+
296
+ if fileobj is not None :
297
+ # Write to file object if provided
298
+ logger .info ("downloading %s size=%s" , path , content_length )
299
+ for chunk in chunks :
300
+ fileobj .write (chunk )
301
+ return filename
302
+ else :
303
+ # Dest can't be None here according to condition at the top
304
+ assert dest is not None
305
+
306
+ # Fallback to local file otherwise
274
307
if filename is None :
275
308
# Generate a temporary name
276
309
tmpf = tempfile .NamedTemporaryFile (
@@ -279,19 +312,18 @@ def download_file(
279
312
tmpf .close ()
280
313
filename = tmpf .name
281
314
282
- path = os .path .join (dest , filename )
283
- logger .info ("downloading %s size=%s" , path , content_length )
315
+ path = os .path .join (dest , filename )
284
316
285
- expected_size = content_length // self . CHUNK_SIZE
286
- with open ( path , "wb" ) as fd :
287
- for chunk in e3 . log . progress_bar (
288
- response . iter_content ( self . CHUNK_SIZE ), total = expected_size
289
- ):
290
- fd . write ( chunk )
291
- if validate is None or validate (path ):
292
- return path
293
- else :
294
- rm (path )
317
+ logger . info ( "downloading %s size=%s" , path , content_length )
318
+
319
+ with open ( path , "wb" ) as fd :
320
+ for chunk in chunks :
321
+ fd . write ( chunk )
322
+
323
+ if validate is None or validate (path ):
324
+ return path
325
+ else :
326
+ rm (path )
295
327
except (requests .exceptions .RequestException , HTTPError ) as e :
296
328
# An error (timeout?) occurred while downloading the file
297
329
logger .warning ("download failed" )
0 commit comments