1from __future__ import annotations
2
3import enum
4import json
5from collections.abc import AsyncIterator, Iterable
6from typing import Any, cast
7
8from starlette.requests import HTTPConnection
9from starlette.responses import Response
10from starlette.types import Message, Receive, Scope, Send
11
12
class WebSocketState(enum.Enum):
    """
    Lifecycle phase of one side (client or application) of a websocket.

    The integer values are fixed explicitly so they stay stable regardless of
    declaration order.
    """

    # Initial phase: no connect/accept exchange has happened yet.
    CONNECTING = 0
    # Handshake completed; regular messages may flow.
    CONNECTED = 1
    # A disconnect was received or a close/final response was sent.
    DISCONNECTED = 2
    # The application started an HTTP denial response instead of accepting.
    RESPONSE = 3
18
19
class WebSocketDisconnect(Exception):
    """
    Raised when the websocket peer disconnects or the connection breaks.

    Attributes:
        code: websocket close code (defaults to 1000, normal closure).
        reason: close reason; ``None`` is normalised to ``""``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        self.reason = reason or ""
        # Populate ``args`` so tracebacks, logging and pickling carry the code
        # and reason instead of an empty exception.
        super().__init__(code, self.reason)

    def __str__(self) -> str:
        return f"{self.code}: {self.reason}"

    def __repr__(self) -> str:
        return f"{type(self).__name__}(code={self.code!r}, reason={self.reason!r})"
24
25
class WebSocket(HTTPConnection):
    """
    Websocket connection wrapper around the raw ASGI ``receive``/``send`` callables.

    Two independent state machines are tracked, both starting in ``CONNECTING``:

    * ``client_state`` is driven by messages *received*:
      ``websocket.connect`` -> CONNECTED, ``websocket.disconnect`` -> DISCONNECTED.
    * ``application_state`` is driven by messages *sent*:
      ``websocket.accept`` -> CONNECTED, ``websocket.close`` -> DISCONNECTED,
      ``websocket.http.response.start`` -> RESPONSE, and the final
      ``websocket.http.response.body`` (no ``more_body``) -> DISCONNECTED.

    ``receive`` and ``send`` raise ``RuntimeError`` for messages that are not
    valid in the current state.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive the next ASGI websocket message, ensuring valid state transitions.

        Raises:
            RuntimeError: if the message type is unexpected for the current
                ``client_state``, or if a disconnect was already received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    async def send(self, message: Message) -> None:
        """
        Send an ASGI websocket message, ensuring valid state transitions.

        The state is updated before the message is handed to the server, so a
        close is recorded even if the underlying send then fails.

        Raises:
            RuntimeError: if the message type is unexpected for the current
                ``application_state``, or if the connection was already closed.
            WebSocketDisconnect: with code 1006 if sending on an accepted
                connection fails with ``OSError``.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # The transport broke underneath us: treat it as an abnormal
                # closure and keep the original error as the explicit cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
            # The last body chunk (no ``more_body``) ends the denial response.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the websocket connection.

        If the client's ``websocket.connect`` message has not been received yet,
        wait for it first. Then send ``websocket.accept`` with the optional
        subprotocol and extra response headers.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:  # pragma: no branch
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if ``message`` is a disconnect event."""
        if message["type"] == "websocket.disconnect":
            # The ASGI spec tells servers to report 1005 when the client sent no
            # close code; fall back to it rather than failing with KeyError if a
            # server leaves the key out.
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    def _raise_if_not_accepted(self) -> None:
        """Raise ``RuntimeError`` unless the application has accepted the connection."""
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')

    @staticmethod
    def _check_json_mode(mode: str) -> None:
        """Validate the ``mode`` argument shared by the JSON helpers."""
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')

    async def receive_text(self) -> str:
        """
        Receive the payload of the next message from its ``text`` key.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        self._raise_if_not_accepted()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive the payload of the next message from its ``bytes`` key.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        self._raise_if_not_accepted()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> Any:
        """
        Receive the next message and decode it as JSON.

        Args:
            mode: ``"text"`` reads the ``text`` key; ``"binary"`` reads the
                ``bytes`` key and decodes it as UTF-8.

        Raises:
            RuntimeError: on an invalid ``mode`` or if not yet accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        self._check_json_mode(mode)
        self._raise_if_not_accepted()
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self) -> AsyncIterator[Any]:
        """Yield JSON-decoded text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as a text message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """
        Serialise ``data`` as compact JSON and send it.

        Args:
            mode: ``"text"`` sends a text message; ``"binary"`` sends the
                UTF-8 encoded JSON as a binary message.

        Raises:
            RuntimeError: on an invalid ``mode``.
        """
        self._check_json_mode(mode)
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Close the connection with the given close code and optional reason."""
        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending an HTTP ``response`` instead of accepting.

        Raises:
            RuntimeError: if ``scope["extensions"]`` does not advertise
                ``websocket.http.response`` support.
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
188
189
class WebSocketClose:
    """
    Tiny ASGI application that immediately closes a websocket.

    Calling an instance sends a single ``websocket.close`` message carrying the
    configured close code and reason (``None`` becomes ``""``).
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)