Skip to content

Commit f13d354

Browse files
mattmess1221Kludex
andauthored
fix(gzip): Make sure Vary header is always added if a response can be compressed (#2865)
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
1 parent abe3554 commit f13d354

File tree

2 files changed

+98
-39
lines changed

2 files changed

+98
-39
lines changed

starlette/middleware/gzip.py

Lines changed: 68 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,37 @@ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
1515
self.compresslevel = compresslevel
1616

1717
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
18-
if scope["type"] == "http": # pragma: no branch
19-
headers = Headers(scope=scope)
20-
if "gzip" in headers.get("Accept-Encoding", ""):
21-
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
22-
await responder(scope, receive, send)
23-
return
24-
await self.app(scope, receive, send)
18+
if scope["type"] != "http": # pragma: no cover
19+
await self.app(scope, receive, send)
20+
return
2521

22+
headers = Headers(scope=scope)
23+
responder: ASGIApp
24+
if "gzip" in headers.get("Accept-Encoding", ""):
25+
responder = GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
26+
else:
27+
responder = IdentityResponder(self.app, self.minimum_size)
2628

27-
class GZipResponder:
28-
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
29+
await responder(scope, receive, send)
30+
31+
32+
class IdentityResponder:
33+
content_encoding: str
34+
35+
def __init__(self, app: ASGIApp, minimum_size: int) -> None:
2936
self.app = app
3037
self.minimum_size = minimum_size
3138
self.send: Send = unattached_send
3239
self.initial_message: Message = {}
3340
self.started = False
3441
self.content_encoding_set = False
3542
self.content_type_is_excluded = False
36-
self.gzip_buffer = io.BytesIO()
37-
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
3843

3944
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
4045
self.send = send
41-
with self.gzip_buffer, self.gzip_file:
42-
await self.app(scope, receive, self.send_with_gzip)
46+
await self.app(scope, receive, self.send_with_compression)
4347

44-
async def send_with_gzip(self, message: Message) -> None:
48+
async def send_with_compression(self, message: Message) -> None:
4549
message_type = message["type"]
4650
if message_type == "http.response.start":
4751
# Don't send the initial message until we've determined how to
@@ -60,53 +64,78 @@ async def send_with_gzip(self, message: Message) -> None:
6064
body = message.get("body", b"")
6165
more_body = message.get("more_body", False)
6266
if len(body) < self.minimum_size and not more_body:
63-
# Don't apply GZip to small outgoing responses.
67+
# Don't apply compression to small outgoing responses.
6468
await self.send(self.initial_message)
6569
await self.send(message)
6670
elif not more_body:
67-
# Standard GZip response.
68-
self.gzip_file.write(body)
69-
self.gzip_file.close()
70-
body = self.gzip_buffer.getvalue()
71+
# Standard response.
72+
body = self.apply_compression(body, more_body=False)
7173

7274
headers = MutableHeaders(raw=self.initial_message["headers"])
73-
headers["Content-Encoding"] = "gzip"
74-
headers["Content-Length"] = str(len(body))
7575
headers.add_vary_header("Accept-Encoding")
76-
message["body"] = body
76+
if body != message["body"]:
77+
headers["Content-Encoding"] = self.content_encoding
78+
headers["Content-Length"] = str(len(body))
79+
message["body"] = body
7780

7881
await self.send(self.initial_message)
7982
await self.send(message)
8083
else:
81-
# Initial body in streaming GZip response.
84+
# Initial body in streaming response.
85+
body = self.apply_compression(body, more_body=True)
86+
8287
headers = MutableHeaders(raw=self.initial_message["headers"])
83-
headers["Content-Encoding"] = "gzip"
8488
headers.add_vary_header("Accept-Encoding")
85-
del headers["Content-Length"]
86-
87-
self.gzip_file.write(body)
88-
message["body"] = self.gzip_buffer.getvalue()
89-
self.gzip_buffer.seek(0)
90-
self.gzip_buffer.truncate()
89+
if body != message["body"]:
90+
headers["Content-Encoding"] = self.content_encoding
91+
del headers["Content-Length"]
92+
message["body"] = body
9193

9294
await self.send(self.initial_message)
9395
await self.send(message)
94-
9596
elif message_type == "http.response.body": # pragma: no branch
96-
# Remaining body in streaming GZip response.
97+
# Remaining body in streaming response.
9798
body = message.get("body", b"")
9899
more_body = message.get("more_body", False)
99100

100-
self.gzip_file.write(body)
101-
if not more_body:
102-
self.gzip_file.close()
103-
104-
message["body"] = self.gzip_buffer.getvalue()
105-
self.gzip_buffer.seek(0)
106-
self.gzip_buffer.truncate()
101+
message["body"] = self.apply_compression(body, more_body=more_body)
107102

108103
await self.send(message)
109104

105+
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
106+
"""Apply compression on the response body.
107+
108+
If more_body is False, any compression file should be closed. If it
109+
isn't, it won't be closed automatically until all background tasks
110+
complete.
111+
"""
112+
return body
113+
114+
115+
class GZipResponder(IdentityResponder):
116+
content_encoding = "gzip"
117+
118+
def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
119+
super().__init__(app, minimum_size)
120+
121+
self.gzip_buffer = io.BytesIO()
122+
self.gzip_file = gzip.GzipFile(mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel)
123+
124+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
125+
with self.gzip_buffer, self.gzip_file:
126+
await super().__call__(scope, receive, send)
127+
128+
def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
129+
self.gzip_file.write(body)
130+
if not more_body:
131+
self.gzip_file.close()
132+
133+
body = self.gzip_buffer.getvalue()
134+
self.gzip_buffer.seek(0)
135+
self.gzip_buffer.truncate()
136+
137+
return body
138+
110139

111140
async def unattached_send(message: Message) -> typing.NoReturn:
112141
raise RuntimeError("send awaitable not set") # pragma: no cover

tests/middleware/test_gzip.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
from starlette.applications import Starlette
24
from starlette.middleware import Middleware
35
from starlette.middleware.gzip import GZipMiddleware
@@ -21,6 +23,7 @@ def homepage(request: Request) -> PlainTextResponse:
2123
assert response.status_code == 200
2224
assert response.text == "x" * 4000
2325
assert response.headers["Content-Encoding"] == "gzip"
26+
assert response.headers["Vary"] == "Accept-Encoding"
2427
assert int(response.headers["Content-Length"]) < 4000
2528

2629

@@ -38,6 +41,7 @@ def homepage(request: Request) -> PlainTextResponse:
3841
assert response.status_code == 200
3942
assert response.text == "x" * 4000
4043
assert "Content-Encoding" not in response.headers
44+
assert response.headers["Vary"] == "Accept-Encoding"
4145
assert int(response.headers["Content-Length"]) == 4000
4246

4347

@@ -57,6 +61,7 @@ def homepage(request: Request) -> PlainTextResponse:
5761
assert response.status_code == 200
5862
assert response.text == "OK"
5963
assert "Content-Encoding" not in response.headers
64+
assert "Vary" not in response.headers
6065
assert int(response.headers["Content-Length"]) == 2
6166

6267

@@ -79,6 +84,30 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
7984
assert response.status_code == 200
8085
assert response.text == "x" * 4000
8186
assert response.headers["Content-Encoding"] == "gzip"
87+
assert response.headers["Vary"] == "Accept-Encoding"
88+
assert "Content-Length" not in response.headers
89+
90+
91+
def test_gzip_streaming_response_identity(test_client_factory: TestClientFactory) -> None:
92+
def homepage(request: Request) -> StreamingResponse:
93+
async def generator(bytes: bytes, count: int) -> ContentStream:
94+
for index in range(count):
95+
yield bytes
96+
97+
streaming = generator(bytes=b"x" * 400, count=10)
98+
return StreamingResponse(streaming, status_code=200)
99+
100+
app = Starlette(
101+
routes=[Route("/", endpoint=homepage)],
102+
middleware=[Middleware(GZipMiddleware)],
103+
)
104+
105+
client = test_client_factory(app)
106+
response = client.get("/", headers={"accept-encoding": "identity"})
107+
assert response.status_code == 200
108+
assert response.text == "x" * 4000
109+
assert "Content-Encoding" not in response.headers
110+
assert response.headers["Vary"] == "Accept-Encoding"
82111
assert "Content-Length" not in response.headers
83112

84113

@@ -103,6 +132,7 @@ async def generator(bytes: bytes, count: int) -> ContentStream:
103132
assert response.status_code == 200
104133
assert response.text == "x" * 4000
105134
assert response.headers["Content-Encoding"] == "text"
135+
assert "Vary" not in response.headers
106136
assert "Content-Length" not in response.headers
107137

108138

0 commit comments

Comments
 (0)