1import gzip
2import io
3from typing import NoReturn
4
5from starlette.datastructures import Headers, MutableHeaders
6from starlette.types import ASGIApp, Message, Receive, Scope, Send
7
# Content-type prefixes whose responses are passed through without compression.
# Kept as a tuple because it is handed directly to ``str.startswith``.
DEFAULT_EXCLUDED_CONTENT_TYPES = ("text/event-stream",)
9
10
class GZipMiddleware:
    """ASGI middleware that routes HTTP responses through a (possibly compressing) responder.

    Args:
        app: The wrapped ASGI application.
        minimum_size: Handed to the responder; single-message bodies shorter
            than this many bytes are sent unmodified.
        compresslevel: Compression level forwarded to the gzip responder.
    """

    def __init__(self, app: ASGIApp, minimum_size: int = 500, compresslevel: int = 9) -> None:
        self.app = app
        self.minimum_size = minimum_size
        self.compresslevel = compresslevel

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Dispatch the request to a fresh responder chosen from ``Accept-Encoding``."""
        # Anything that is not plain HTTP goes straight to the wrapped app.
        if scope["type"] != "http":  # pragma: no cover
            await self.app(scope, receive, send)
            return

        accepted_encodings = Headers(scope=scope).get("Accept-Encoding", "")
        # NOTE: this is a plain substring test; q-values in the header are not
        # interpreted.
        wants_gzip = "gzip" in accepted_encodings

        # A new responder is built per request because responders hold
        # per-response state.
        handler: ASGIApp = (
            GZipResponder(self.app, self.minimum_size, compresslevel=self.compresslevel)
            if wants_gzip
            else IdentityResponder(self.app, self.minimum_size)
        )
        await handler(scope, receive, send)
30
31
class IdentityResponder:
    """Per-response ASGI wrapper that defers the response start and filters body messages.

    The ``http.response.start`` message is held back until the first body (or
    pathsend) message arrives, so that headers can still be adjusted based on
    what ``apply_compression`` does to the body. In this base class
    ``apply_compression`` returns the body unchanged; subclasses override it
    and define ``content_encoding``.
    """

    # Value written to the Content-Encoding header when the body was changed.
    # Only declared here; subclasses that alter the body must set it.
    content_encoding: str

    def __init__(self, app: ASGIApp, minimum_size: int) -> None:
        self.app = app
        self.minimum_size = minimum_size
        self.send: Send = unattached_send
        # Buffered "http.response.start" message, sent lazily.
        self.initial_message: Message = {}
        # True once the buffered start message has been handled for a body message.
        self.started = False
        # True when the app already set a Content-Encoding header itself.
        self.content_encoding_set = False
        # True when the Content-Type matches DEFAULT_EXCLUDED_CONTENT_TYPES.
        self.content_type_is_excluded = False

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app with ``send`` replaced by ``send_with_compression``."""
        self.send = send
        await self.app(scope, receive, self.send_with_compression)

    async def send_with_compression(self, message: Message) -> None:
        """Intercept one outgoing ASGI message and forward it, possibly rewritten.

        Message types other than ``http.response.start``,
        ``http.response.body`` and ``http.response.pathsend`` are not
        forwarded.
        """
        message_type = message["type"]
        if message_type == "http.response.start":
            # Don't send the initial message until we've determined how to
            # modify the outgoing headers correctly.
            self.initial_message = message
            headers = Headers(raw=self.initial_message["headers"])
            self.content_encoding_set = "content-encoding" in headers
            self.content_type_is_excluded = headers.get("content-type", "").startswith(DEFAULT_EXCLUDED_CONTENT_TYPES)
        elif message_type == "http.response.body" and (self.content_encoding_set or self.content_type_is_excluded):
            # Already encoded or excluded content type: pass everything through untouched.
            if not self.started:
                self.started = True
                await self.send(self.initial_message)
            await self.send(message)
        elif message_type == "http.response.body" and not self.started:
            # First body message: decide whether to rewrite headers and body.
            self.started = True
            # "body" and "more_body" are optional keys, hence the defaults.
            body = message.get("body", b"")
            more_body = message.get("more_body", False)

            # Small single-message responses are sent as-is; everything else
            # (large single message, or the first chunk of a stream) goes
            # through apply_compression.
            if more_body or len(body) >= self.minimum_size:
                compressed = self.apply_compression(body, more_body=more_body)

                # The code relies on MutableHeaders editing the raw header
                # list of the buffered start message in place.
                headers = MutableHeaders(raw=self.initial_message["headers"])
                headers.add_vary_header("Accept-Encoding")
                # Compare against the defaulted local rather than
                # message["body"], which raises KeyError when the key is absent.
                if compressed != body:
                    headers["Content-Encoding"] = self.content_encoding
                    if more_body:
                        # Total length is not known yet for a streamed body.
                        del headers["Content-Length"]
                    else:
                        headers["Content-Length"] = str(len(compressed))
                    message["body"] = compressed

            await self.send(self.initial_message)
            await self.send(message)
        elif message_type == "http.response.body":
            # Remaining body in streaming response.
            body = message.get("body", b"")
            more_body = message.get("more_body", False)

            message["body"] = self.apply_compression(body, more_body=more_body)

            await self.send(message)
        elif message_type == "http.response.pathsend":  # pragma: no branch
            # Don't apply GZip to pathsend responses
            await self.send(self.initial_message)
            await self.send(message)

    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Apply compression on the response body.

        If more_body is False, any compression file should be closed. If it
        isn't, it won't be closed automatically until all background tasks
        complete.

        The base implementation returns ``body`` unchanged.
        """
        return body
117
118
class GZipResponder(IdentityResponder):
    """Responder whose ``apply_compression`` gzips the body via an in-memory buffer."""

    content_encoding = "gzip"

    def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> None:
        super().__init__(app, minimum_size)

        # The GzipFile writes compressed bytes into this buffer, which is
        # drained after every call to apply_compression.
        sink = io.BytesIO()
        self.gzip_buffer = sink
        self.gzip_file = gzip.GzipFile(mode="wb", fileobj=sink, compresslevel=compresslevel)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the parent responder, closing the gzip file and buffer afterwards."""
        # Nested so the gzip file is exited before the buffer it writes into.
        with self.gzip_buffer:
            with self.gzip_file:
                await super().__call__(scope, receive, send)

    def apply_compression(self, body: bytes, *, more_body: bool) -> bytes:
        """Feed ``body`` to the gzip stream and return the bytes produced so far.

        When ``more_body`` is False the gzip file is closed before the buffer
        is read, so the returned chunk is the last one of the stream.
        """
        compressor = self.gzip_file
        sink = self.gzip_buffer

        compressor.write(body)
        if not more_body:
            compressor.close()

        # Take whatever has accumulated, then reset the buffer to empty so the
        # next call only returns newly produced bytes.
        chunk = sink.getvalue()
        sink.seek(0)
        sink.truncate()
        return chunk
142
143
async def unattached_send(message: Message) -> NoReturn:
    """Placeholder ``send`` callable that always raises ``RuntimeError``.

    ``IdentityResponder.__init__`` assigns it as the initial ``send`` until a
    real one is supplied in ``__call__``.
    """
    raise RuntimeError("send awaitable not set")  # pragma: no cover