1from __future__ import annotations
2
3import enum
4import json
5import typing
6
7from starlette.requests import HTTPConnection
8from starlette.responses import Response
9from starlette.types import Message, Receive, Scope, Send
10
11
class WebSocketState(enum.Enum):
    """Lifecycle stage of one side (client or application) of a websocket."""

    # Initial stage: the handshake has not completed yet.
    CONNECTING = 0
    # Handshake done ("websocket.connect" received / "websocket.accept" sent).
    CONNECTED = 1
    # A close was sent or a disconnect was received.
    DISCONNECTED = 2
    # The application started an HTTP denial response instead of accepting.
    RESPONSE = 3
17
18
class WebSocketDisconnect(Exception):
    """
    Raised when the peer has gone away.

    ``code`` is the websocket close code (1000 unless given) and ``reason``
    is the close reason, normalised so that ``None`` becomes ``""``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # Falsy reasons (None or "") are both stored as the empty string.
        self.reason = reason if reason else ""
23
24
class WebSocket(HTTPConnection):
    """
    A websocket connection wrapping the raw ASGI ``receive``/``send`` callables.

    Two independent state machines are tracked:

    * ``client_state`` advances as messages are *received*
      (``websocket.connect`` -> ``websocket.receive``* -> ``websocket.disconnect``).
    * ``application_state`` advances as messages are *sent*
      (``websocket.accept`` / ``websocket.close`` / denial response).

    Any message that is not valid for the current state raises ``RuntimeError``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: if the message type is not allowed in the current
                ``client_state``, or if called after a disconnect was received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: if the message type is not allowed in the current
                ``application_state``, or if called after a close was sent.
            WebSocketDisconnect: with code 1006 if the transport raises
                ``OSError`` while sending on an established connection.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # The transport died underneath us; surface it as an abnormal
                # closure (1006) while keeping the original error as the cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
            # The final body chunk (no more_body) ends the denial response.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the websocket handshake.

        Args:
            subprotocol: the negotiated subprotocol to report, if any.
            headers: extra ``(name, value)`` byte pairs for the accept response.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a disconnect message."""
        if message["type"] == "websocket.disconnect":
            # Per the ASGI spec / RFC 6455, 1005 means "no status code received";
            # use it rather than crashing with KeyError if a server omits "code".
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive the next message and return its ``text`` field.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive the next message and return its ``bytes`` field.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive the next message and decode it as JSON.

        Args:
            mode: ``"text"`` to parse the ``text`` field, or ``"binary"`` to
                parse the UTF-8 decoded ``bytes`` field.

        Raises:
            RuntimeError: on an invalid *mode* or if not yet accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self, mode: str = "text") -> typing.AsyncIterator[typing.Any]:
        """
        Yield JSON-decoded messages until the client disconnects.

        Args:
            mode: forwarded to ``receive_json`` (``"text"`` or ``"binary"``).
        """
        try:
            while True:
                yield await self.receive_json(mode=mode)
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        """Send *data* as a text message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send *data* as a binary message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialise *data* as compact JSON and send it.

        Args:
            mode: ``"text"`` sends a text frame; ``"binary"`` sends the UTF-8
                encoded JSON as a binary frame.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.close`` message with *code* and *reason*."""
        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending an HTTP *response* instead of accepting.

        Requires the server to advertise the ``websocket.http.response``
        extension in the scope; otherwise ``RuntimeError`` is raised.
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            # Route through our own send/receive so state transitions are checked.
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
187
188
class WebSocketClose:
    """
    Minimal ASGI app that sends a single ``websocket.close`` message.

    The ``scope`` and ``receive`` arguments are accepted for ASGI
    compatibility but are not used.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # None and "" are both stored as the empty string.
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)