Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/websockets.py: 24%
Shortcuts on this page
  r m x     toggle line displays
  j k       next/prev highlighted chunk
  0 (zero)  top of page
  1 (one)   first highlighted chunk
1from __future__ import annotations
3import enum
4import json
5from collections.abc import AsyncIterator, Iterable
6from typing import Any, cast
8from starlette.requests import HTTPConnection
9from starlette.responses import Response
10from starlette.types import Message, Receive, Scope, Send
class WebSocketState(enum.Enum):
    """
    Lifecycle states used to track each side of a WebSocket connection.

    ``WebSocket`` keeps one of these for the client side (driven by
    ``receive``) and one for the application side (driven by ``send``).
    """

    # Handshake not yet completed.
    CONNECTING = 0
    # Handshake completed; data messages may flow.
    CONNECTED = 1
    # A close/disconnect has been sent or received; no further traffic allowed.
    DISCONNECTED = 2
    # The application is sending an HTTP denial response instead of accepting.
    RESPONSE = 3
class WebSocketDisconnect(Exception):
    """
    Exception signalling that the WebSocket connection has been disconnected.

    Attributes:
        code: The close code (defaults to 1000).
        reason: The close reason; any falsy value (including ``None``) is
            normalised to the empty string.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # Normalise "no reason" to "" so callers can always treat it as a str.
        self.reason = reason if reason else ""
class WebSocket(HTTPConnection):
    """
    A WebSocket connection built on top of the raw ASGI ``receive``/``send`` callables.

    Two independent state machines are tracked:

    * ``client_state`` advances as messages are read via ``receive``.
    * ``application_state`` advances as messages are written via ``send``.

    Any message that is not valid for the current state raises ``RuntimeError``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        # Both sides start out waiting for the opening handshake.
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        * CONNECTING: the next message must be ``websocket.connect`` (-> CONNECTED).
        * CONNECTED: the next message must be ``websocket.receive`` or
          ``websocket.disconnect`` (the latter moves to DISCONNECTED).

        Raises:
            RuntimeError: on an unexpected message type, or when called after a
                disconnect message has already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Allowed message types depend on ``application_state``:

        * CONNECTING: ``websocket.accept`` (-> CONNECTED), ``websocket.close``
          (-> DISCONNECTED) or ``websocket.http.response.start`` (-> RESPONSE).
        * CONNECTED: ``websocket.send`` or ``websocket.close`` (-> DISCONNECTED).
        * RESPONSE: ``websocket.http.response.body``; the state becomes
          DISCONNECTED once a body message without a truthy ``more_body`` is sent.

        Raises:
            RuntimeError: if the message type is not valid for the current state,
                or if called after the connection has been closed.
            WebSocketDisconnect: with code 1006 if the underlying send raises
                ``OSError`` while connected.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # A transport failure is surfaced as an abnormal closure (1006) so
                # callers only need to handle WebSocketDisconnect; chain the
                # original error so its cause is not lost from tracebacks.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the connection, optionally selecting a subprotocol and extra headers.

        Waits for the client's ``websocket.connect`` message first if it has not
        been received yet, then sends ``websocket.accept``.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:  # pragma: no branch
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a ``websocket.disconnect`` event."""
        if message["type"] == "websocket.disconnect":
            # The ASGI spec asks servers to report 1005 ("no status received")
            # when the client sent no close code; fall back to that rather than
            # failing with a KeyError if a server omits the key.
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive the next message and return its ``text`` payload.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive the next message and return its ``bytes`` payload.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> Any:
        """
        Receive the next message and decode it as JSON.

        Args:
            mode: ``"text"`` reads the ``text`` payload; ``"binary"`` reads the
                ``bytes`` payload and decodes it as UTF-8.

        Raises:
            RuntimeError: if *mode* is invalid or the connection is not accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self, mode: str = "text") -> AsyncIterator[Any]:
        """
        Yield decoded JSON messages until the client disconnects.

        Args:
            mode: forwarded to ``receive_json`` (``"text"`` or ``"binary"``), so
                JSON carried in binary frames can be iterated as well.
        """
        try:
            while True:
                yield await self.receive_json(mode=mode)
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        """Send *data* as a text message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send *data* as a binary message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """
        Serialise *data* as compact JSON and send it.

        Args:
            mode: ``"text"`` sends a text message; ``"binary"`` sends the
                UTF-8 encoded JSON as a binary message.

        Raises:
            RuntimeError: if *mode* is not ``"text"`` or ``"binary"``.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.close`` message with *code* and *reason* (``None`` becomes ``""``)."""
        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending an HTTP *response* instead of accepting.

        The response is driven through this object's validating ``receive`` and
        ``send`` methods.

        Raises:
            RuntimeError: if the server does not advertise the
                ``websocket.http.response`` extension in the scope.
        """
        # Treat a missing or explicitly-None "extensions" entry as "no extensions".
        extensions = self.scope.get("extensions") or {}
        if "websocket.http.response" in extensions:
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
class WebSocketClose:
    """
    Minimal ASGI callable that closes a WebSocket connection.

    When invoked, it sends a single ``websocket.close`` message carrying the
    configured ``code`` and ``reason``; ``receive`` is never used.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # A missing (or otherwise falsy) reason is sent as an empty string.
        self.reason = reason if reason else ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        close_message = {"type": "websocket.close", "code": self.code, "reason": self.reason}
        await send(close_message)