Skip to content

Commit fe3dc0c

Browse files
authored
Merge pull request #592 from adanaja/features
Add kwargs parameter to HTTPSession.download_file
2 parents 7f08cb2 + 5e32290 commit fe3dc0c

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

src/e3/net/http.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def download_file(
240240
filename: str | None = None,
241241
validate: Callable[[str], bool] | None = None,
242242
exception_on_error: bool = False,
243+
**kwargs: Any,
243244
) -> str | None:
244245
"""Download a file.
245246
@@ -253,6 +254,7 @@ def download_file(
253254
parameter and returns a boolean.
254255
:param exception_on_error: if True raises an exception in case download
255256
fails instead of returning None.
257+
:param kwargs: additional parameters for the request
256258
:return: the name of the file or None if there is an error
257259
"""
258260
# When using stream=True, Requests cannot release the connection back
@@ -262,7 +264,7 @@ def download_file(
262264
path = None
263265
try:
264266
with contextlib.closing(
265-
self.request(method="GET", url=url, stream=True)
267+
self.request(method="GET", url=url, stream=True, **kwargs)
266268
) as response:
267269
content_length = int(response.headers.get("content-length", 0))
268270
e3.log.debug(response.headers)

tests/tests_e3/net/http_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,17 @@ def do_POST(self):
9393
logging.debug("POST finish")
9494

9595

96+
class AuthorizationHeaderHandler(ContentDispoHandler):
97+
def do_GET(self):
98+
if self.headers.get("Authorization") != "Bearer toto":
99+
self.send_response(403)
100+
self.end_headers()
101+
self.wfile.close()
102+
return
103+
104+
super(AuthorizationHeaderHandler, self).do_GET()
105+
106+
96107
def run_server(handler, func):
97108
server = HTTPServer(("localhost", 0), handler)
98109
try:
@@ -220,3 +231,28 @@ def func(server, url):
220231
run_server(MultiPartPostHandler, func)
221232

222233
run_server(ServerErrorHandler, outter_func)
234+
235+
def test_authorization_header(self, socket_enabled):
236+
def func(server, base_url):
237+
with HTTPSession() as session:
238+
# first test with no authorization header
239+
try:
240+
result = session.download_file(
241+
base_url + "dummy", dest=".", exception_on_error=True
242+
)
243+
raise AssertionError("exception not raised")
244+
except HTTPError as e:
245+
assert e.status == 403
246+
# second test with authorization header
247+
result = session.download_file(
248+
base_url + "dummy",
249+
dest=".",
250+
exception_on_error=True,
251+
headers={"Authorization": "Bearer toto"},
252+
)
253+
with open(result, "rb") as fd:
254+
content = fd.read()
255+
assert content == b"Dummy!"
256+
assert os.path.basename(result) == "dummy.txt"
257+
258+
run_server(AuthorizationHeaderHandler, func)

0 commit comments

Comments
 (0)