@@ -15,33 +15,37 @@ def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9
15
15
self .compresslevel = compresslevel
16
16
17
17
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
18
- if scope ["type" ] == "http" : # pragma: no branch
19
- headers = Headers (scope = scope )
20
- if "gzip" in headers .get ("Accept-Encoding" , "" ):
21
- responder = GZipResponder (self .app , self .minimum_size , compresslevel = self .compresslevel )
22
- await responder (scope , receive , send )
23
- return
24
- await self .app (scope , receive , send )
18
+ if scope ["type" ] != "http" : # pragma: no cover
19
+ await self .app (scope , receive , send )
20
+ return
25
21
22
+ headers = Headers (scope = scope )
23
+ responder : ASGIApp
24
+ if "gzip" in headers .get ("Accept-Encoding" , "" ):
25
+ responder = GZipResponder (self .app , self .minimum_size , compresslevel = self .compresslevel )
26
+ else :
27
+ responder = IdentityResponder (self .app , self .minimum_size )
26
28
27
- class GZipResponder :
28
- def __init__ (self , app : ASGIApp , minimum_size : int , compresslevel : int = 9 ) -> None :
29
+ await responder (scope , receive , send )
30
+
31
+
32
+ class IdentityResponder :
33
+ content_encoding : str
34
+
35
+ def __init__ (self , app : ASGIApp , minimum_size : int ) -> None :
29
36
self .app = app
30
37
self .minimum_size = minimum_size
31
38
self .send : Send = unattached_send
32
39
self .initial_message : Message = {}
33
40
self .started = False
34
41
self .content_encoding_set = False
35
42
self .content_type_is_excluded = False
36
- self .gzip_buffer = io .BytesIO ()
37
- self .gzip_file = gzip .GzipFile (mode = "wb" , fileobj = self .gzip_buffer , compresslevel = compresslevel )
38
43
39
44
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
40
45
self .send = send
41
- with self .gzip_buffer , self .gzip_file :
42
- await self .app (scope , receive , self .send_with_gzip )
46
+ await self .app (scope , receive , self .send_with_compression )
43
47
44
- async def send_with_gzip (self , message : Message ) -> None :
48
+ async def send_with_compression (self , message : Message ) -> None :
45
49
message_type = message ["type" ]
46
50
if message_type == "http.response.start" :
47
51
# Don't send the initial message until we've determined how to
@@ -60,53 +64,78 @@ async def send_with_gzip(self, message: Message) -> None:
60
64
body = message .get ("body" , b"" )
61
65
more_body = message .get ("more_body" , False )
62
66
if len (body ) < self .minimum_size and not more_body :
63
- # Don't apply GZip to small outgoing responses.
67
+ # Don't apply compression to small outgoing responses.
64
68
await self .send (self .initial_message )
65
69
await self .send (message )
66
70
elif not more_body :
67
- # Standard GZip response.
68
- self .gzip_file .write (body )
69
- self .gzip_file .close ()
70
- body = self .gzip_buffer .getvalue ()
71
+ # Standard response.
72
+ body = self .apply_compression (body , more_body = False )
71
73
72
74
headers = MutableHeaders (raw = self .initial_message ["headers" ])
73
- headers ["Content-Encoding" ] = "gzip"
74
- headers ["Content-Length" ] = str (len (body ))
75
75
headers .add_vary_header ("Accept-Encoding" )
76
- message ["body" ] = body
76
+ if body != message ["body" ]:
77
+ headers ["Content-Encoding" ] = self .content_encoding
78
+ headers ["Content-Length" ] = str (len (body ))
79
+ message ["body" ] = body
77
80
78
81
await self .send (self .initial_message )
79
82
await self .send (message )
80
83
else :
81
- # Initial body in streaming GZip response.
84
+ # Initial body in streaming response.
85
+ body = self .apply_compression (body , more_body = True )
86
+
82
87
headers = MutableHeaders (raw = self .initial_message ["headers" ])
83
- headers ["Content-Encoding" ] = "gzip"
84
88
headers .add_vary_header ("Accept-Encoding" )
85
- del headers ["Content-Length" ]
86
-
87
- self .gzip_file .write (body )
88
- message ["body" ] = self .gzip_buffer .getvalue ()
89
- self .gzip_buffer .seek (0 )
90
- self .gzip_buffer .truncate ()
89
+ if body != message ["body" ]:
90
+ headers ["Content-Encoding" ] = self .content_encoding
91
+ del headers ["Content-Length" ]
92
+ message ["body" ] = body
91
93
92
94
await self .send (self .initial_message )
93
95
await self .send (message )
94
-
95
96
elif message_type == "http.response.body" : # pragma: no branch
96
- # Remaining body in streaming GZip response.
97
+ # Remaining body in streaming response.
97
98
body = message .get ("body" , b"" )
98
99
more_body = message .get ("more_body" , False )
99
100
100
- self .gzip_file .write (body )
101
- if not more_body :
102
- self .gzip_file .close ()
103
-
104
- message ["body" ] = self .gzip_buffer .getvalue ()
105
- self .gzip_buffer .seek (0 )
106
- self .gzip_buffer .truncate ()
101
+ message ["body" ] = self .apply_compression (body , more_body = more_body )
107
102
108
103
await self .send (message )
109
104
105
+ def apply_compression (self , body : bytes , * , more_body : bool ) -> bytes :
106
+ """Apply compression on the response body.
107
+
108
+ If more_body is False, any compression file should be closed. If it
109
+ isn't, it won't be closed automatically until all background tasks
110
+ complete.
111
+ """
112
+ return body
113
+
114
+
115
+ class GZipResponder (IdentityResponder ):
116
+ content_encoding = "gzip"
117
+
118
+ def __init__ (self , app : ASGIApp , minimum_size : int , compresslevel : int = 9 ) -> None :
119
+ super ().__init__ (app , minimum_size )
120
+
121
+ self .gzip_buffer = io .BytesIO ()
122
+ self .gzip_file = gzip .GzipFile (mode = "wb" , fileobj = self .gzip_buffer , compresslevel = compresslevel )
123
+
124
+ async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
125
+ with self .gzip_buffer , self .gzip_file :
126
+ await super ().__call__ (scope , receive , send )
127
+
128
+ def apply_compression (self , body : bytes , * , more_body : bool ) -> bytes :
129
+ self .gzip_file .write (body )
130
+ if not more_body :
131
+ self .gzip_file .close ()
132
+
133
+ body = self .gzip_buffer .getvalue ()
134
+ self .gzip_buffer .seek (0 )
135
+ self .gzip_buffer .truncate ()
136
+
137
+ return body
138
+
110
139
111
140
async def unattached_send (message : Message ) -> typing .NoReturn :
112
141
raise RuntimeError ("send awaitable not set" ) # pragma: no cover
0 commit comments