Skip to content

Commit 07d1759

Browse files
[PR #10101/678993a4 backport][3.11] Fix race in FileResponse if file is replaced during prepare (#10105)
Co-authored-by: J. Nick Koston <nick@koston.org> fixes #8013
1 parent 23a4b31 commit 07d1759

File tree

3 files changed

+63
-24
lines changed

3 files changed

+63
-24
lines changed

CHANGES/10101.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed race condition in :class:`aiohttp.web.FileResponse` that could have resulted in an incorrect response if the file was replaced on the file system during ``prepare`` -- by :user:`bdraco`.

aiohttp/web_fileresponse.py

Lines changed: 58 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import io
23
import os
34
import pathlib
45
import sys
@@ -16,6 +17,7 @@
1617
Iterator,
1718
List,
1819
Optional,
20+
Set,
1921
Tuple,
2022
Union,
2123
cast,
@@ -73,6 +75,9 @@
7375
CONTENT_TYPES.add_type(content_type, extension) # type: ignore[attr-defined]
7476

7577

78+
_CLOSE_FUTURES: Set[asyncio.Future[None]] = set()
79+
80+
7681
class FileResponse(StreamResponse):
7782
"""A response object can be used to send files."""
7883

@@ -161,10 +166,10 @@ async def _precondition_failed(
161166
self.content_length = 0
162167
return await super().prepare(request)
163168

164-
def _get_file_path_stat_encoding(
169+
def _open_file_path_stat_encoding(
165170
self, accept_encoding: str
166-
) -> Tuple[pathlib.Path, os.stat_result, Optional[str]]:
167-
"""Return the file path, stat result, and encoding.
171+
) -> Tuple[Optional[io.BufferedReader], os.stat_result, Optional[str]]:
172+
"""Return the io object, stat result, and encoding.
168173
169174
If an uncompressed file is returned, the encoding is set to
170175
:py:data:`None`.
@@ -182,31 +187,72 @@ def _get_file_path_stat_encoding(
182187
# Do not follow symlinks and ignore any non-regular files.
183188
st = compressed_path.lstat()
184189
if S_ISREG(st.st_mode):
185-
return compressed_path, st, file_encoding
190+
fobj = compressed_path.open("rb")
191+
with suppress(OSError):
192+
# fstat() may not be available on all platforms
193+
# Once we open the file, we want the fstat() to ensure
194+
# the file has not changed between the first stat()
195+
# and the open().
196+
st = os.stat(fobj.fileno())
197+
return fobj, st, file_encoding
186198

187199
# Fallback to the uncompressed file
188-
return file_path, file_path.stat(), None
200+
st = file_path.stat()
201+
if not S_ISREG(st.st_mode):
202+
return None, st, None
203+
fobj = file_path.open("rb")
204+
with suppress(OSError):
205+
# fstat() may not be available on all platforms
206+
# Once we open the file, we want the fstat() to ensure
207+
# the file has not changed between the first stat()
208+
# and the open().
209+
st = os.stat(fobj.fileno())
210+
return fobj, st, None
189211

190212
async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]:
191213
loop = asyncio.get_running_loop()
192214
# Encoding comparisons should be case-insensitive
193215
# https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1
194216
accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower()
195217
try:
196-
file_path, st, file_encoding = await loop.run_in_executor(
197-
None, self._get_file_path_stat_encoding, accept_encoding
218+
fobj, st, file_encoding = await loop.run_in_executor(
219+
None, self._open_file_path_stat_encoding, accept_encoding
198220
)
221+
except PermissionError:
222+
self.set_status(HTTPForbidden.status_code)
223+
return await super().prepare(request)
199224
except OSError:
200225
# Most likely to be FileNotFoundError or OSError for circular
201226
# symlinks in python >= 3.13, so respond with 404.
202227
self.set_status(HTTPNotFound.status_code)
203228
return await super().prepare(request)
204229

205-
# Forbid special files like sockets, pipes, devices, etc.
206-
if not S_ISREG(st.st_mode):
207-
self.set_status(HTTPForbidden.status_code)
208-
return await super().prepare(request)
230+
try:
231+
# Forbid special files like sockets, pipes, devices, etc.
232+
if not fobj or not S_ISREG(st.st_mode):
233+
self.set_status(HTTPForbidden.status_code)
234+
return await super().prepare(request)
209235

236+
return await self._prepare_open_file(request, fobj, st, file_encoding)
237+
finally:
238+
if fobj:
239+
# We do not await here because we do not want to wait
240+
# for the executor to finish before returning the response
241+
# so the connection can begin servicing another request
242+
# as soon as possible.
243+
close_future = loop.run_in_executor(None, fobj.close)
244+
# Hold a strong reference to the future to prevent it from being
245+
# garbage collected before it completes.
246+
_CLOSE_FUTURES.add(close_future)
247+
close_future.add_done_callback(_CLOSE_FUTURES.remove)
248+
249+
async def _prepare_open_file(
250+
self,
251+
request: "BaseRequest",
252+
fobj: io.BufferedReader,
253+
st: os.stat_result,
254+
file_encoding: Optional[str],
255+
) -> Optional[AbstractStreamWriter]:
210256
etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}"
211257
last_modified = st.st_mtime
212258

@@ -349,18 +395,9 @@ async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter
349395
if count == 0 or must_be_empty_body(request.method, self.status):
350396
return await super().prepare(request)
351397

352-
try:
353-
fobj = await loop.run_in_executor(None, file_path.open, "rb")
354-
except PermissionError:
355-
self.set_status(HTTPForbidden.status_code)
356-
return await super().prepare(request)
357-
358398
if start: # be aware that start could be None or int=0 here.
359399
offset = start
360400
else:
361401
offset = 0
362402

363-
try:
364-
return await self._sendfile(request, fobj, offset, count)
365-
finally:
366-
await asyncio.shield(loop.run_in_executor(None, fobj.close))
403+
return await self._sendfile(request, fobj, offset, count)

tests/test_web_urldispatcher.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,16 +585,17 @@ async def test_access_mock_special_resource(
585585
my_special.touch()
586586

587587
real_result = my_special.stat()
588-
real_stat = pathlib.Path.stat
588+
real_stat = os.stat
589589

590-
def mock_stat(self: pathlib.Path, **kwargs: Any) -> os.stat_result:
591-
s = real_stat(self, **kwargs)
590+
def mock_stat(path: Any, **kwargs: Any) -> os.stat_result:
591+
s = real_stat(path, **kwargs)
592592
if os.path.samestat(s, real_result):
593593
mock_mode = S_IFIFO | S_IMODE(s.st_mode)
594594
s = os.stat_result([mock_mode] + list(s)[1:])
595595
return s
596596

597597
monkeypatch.setattr("pathlib.Path.stat", mock_stat)
598+
monkeypatch.setattr("os.stat", mock_stat)
598599

599600
app = web.Application()
600601
app.router.add_static("/", str(tmp_path))

0 commit comments

Comments
 (0)