1from __future__ import annotations
2
3import enum
4import json
5import typing
6
7from starlette.requests import HTTPConnection
8from starlette.responses import Response
9from starlette.types import Message, Receive, Scope, Send
10
11
class WebSocketState(enum.Enum):
    # Lifecycle states used by ``WebSocket`` for both its ``client_state``
    # (messages received) and its ``application_state`` (messages sent).

    # Initial state for both sides, before any handshake message.
    CONNECTING = 0
    # client side: "websocket.connect" was received;
    # application side: "websocket.accept" was sent.
    CONNECTED = 1
    # client side: "websocket.disconnect" was received;
    # application side: "websocket.close" was sent, a send raised OSError,
    # or the final denial-response body chunk was sent.
    DISCONNECTED = 2
    # Application side only: "websocket.http.response.start" was sent and
    # the denial-response body is still being streamed.
    RESPONSE = 3
17
18
class WebSocketDisconnect(Exception):
    """
    Signal that the WebSocket connection has been closed.

    ``code`` carries the close code and ``reason`` the close reason; a
    missing (falsy) reason is normalised to the empty string so callers can
    always treat it as a ``str``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        # Normalise None / "" to "" so ``reason`` is always string-typed.
        self.reason = reason if reason else ""
        self.code = code
23
24
class WebSocket(HTTPConnection):
    """
    A WebSocket connection built on the raw ASGI ``receive``/``send`` callables.

    Two independent state machines are tracked:

    * ``client_state`` follows messages *received* from the server:
      ``websocket.connect`` -> ``websocket.receive``* -> ``websocket.disconnect``.
    * ``application_state`` follows messages *sent* by the application:
      ``websocket.accept`` / ``websocket.close`` / a denial response.

    ``receive`` and ``send`` raise ``RuntimeError`` for any message that is
    not valid in the current state.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        The first message must be ``websocket.connect``; after that only
        ``websocket.receive`` and ``websocket.disconnect`` are accepted.

        Raises:
            RuntimeError: if the message type is unexpected for the current
                state, or if called after a ``websocket.disconnect`` message
                has already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.receive" or '
                    f'"websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Allowed messages per ``application_state``:

        * CONNECTING: ``websocket.accept`` (-> CONNECTED), ``websocket.close``
          (-> DISCONNECTED) or ``websocket.http.response.start`` (-> RESPONSE).
        * CONNECTED: ``websocket.send`` or ``websocket.close`` (-> DISCONNECTED).
        * RESPONSE: ``websocket.http.response.body``; the chunk without
          ``more_body`` moves to DISCONNECTED.

        Raises:
            RuntimeError: if the message type is invalid for the current state,
                or if called once the application side is DISCONNECTED.
            WebSocketDisconnect: with code 1006 if the underlying send raises
                ``OSError`` while CONNECTED (the ``OSError`` is chained as
                the cause).
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {
                "websocket.accept",
                "websocket.close",
                "websocket.http.response.start",
            }:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", '
                    '"websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.send" or "websocket.close", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # The transport failed mid-send. Surface it as a disconnect with
                # close code 1006 ("abnormal closure" in RFC 6455), keeping
                # the original error as __cause__ for debugging.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(
                    'Expected ASGI message "websocket.http.response.body", '
                    f"but got {message_type!r}"
                )
            # The last body chunk (no ``more_body``) ends the denial response.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the WebSocket handshake.

        If the ``websocket.connect`` message has not been received yet, it is
        awaited first. Then a ``websocket.accept`` message is sent with the
        given ``subprotocol`` and extra ``headers`` (default: none).
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _require_connected(self) -> None:
        """Raise RuntimeError unless the application side has sent "websocket.accept"."""
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise WebSocketDisconnect if ``message`` is a ``websocket.disconnect``."""
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive the next message and return its ``text`` payload.

        Raises:
            RuntimeError: if ``accept`` has not been called.
            WebSocketDisconnect: if the client disconnected.
        """
        self._require_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive the next message and return its ``bytes`` payload.

        Raises:
            RuntimeError: if ``accept`` has not been called.
            WebSocketDisconnect: if the client disconnected.
        """
        self._require_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive the next message and decode it as JSON.

        Args:
            mode: ``"text"`` parses the ``text`` payload; ``"binary"`` decodes
                the ``bytes`` payload as UTF-8 first.

        Raises:
            RuntimeError: if ``mode`` is invalid or ``accept`` has not been called.
            WebSocketDisconnect: if the client disconnected.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        self._require_connected()
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self, mode: str = "text") -> typing.AsyncIterator[typing.Any]:
        """
        Yield JSON-decoded messages until the client disconnects.

        Args:
            mode: forwarded to ``receive_json`` (``"text"`` or ``"binary"``).
        """
        try:
            while True:
                yield await self.receive_json(mode)
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as a text message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialize ``data`` as compact JSON and send it.

        Args:
            mode: ``"text"`` sends a text message; ``"binary"`` sends the
                UTF-8 encoded JSON as a binary message.

        Raises:
            RuntimeError: if ``mode`` is invalid.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        # Compact separators and no ASCII escaping keep the payload small.
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.close`` message with ``code`` and ``reason`` ("" if None)."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending ``response`` as an HTTP response.

        Requires the server to advertise the ``websocket.http.response``
        extension in ``scope["extensions"]``; the response is driven through
        this object's state-checked ``receive``/``send``.

        Raises:
            RuntimeError: if the extension is not advertised.
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError(
                "The server doesn't support the Websocket Denial Response extension."
            )
214
215
class WebSocketClose:
    """
    Minimal ASGI application that immediately closes a WebSocket.

    Calling an instance sends a single ``websocket.close`` message carrying
    the configured ``code`` and ``reason``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # A missing (falsy) reason is stored as the empty string.
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the ``websocket.close`` message; ``scope`` and ``receive`` are unused."""
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)