Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/websockets.py: 25%
120 statements
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
« prev ^ index » next coverage.py v7.2.2, created at 2023-03-26 06:12 +0000
1import enum
2import json
3import typing
5from starlette.requests import HTTPConnection
6from starlette.types import Message, Receive, Scope, Send
class WebSocketState(enum.Enum):
    """Lifecycle phases used to track both sides of a WebSocket connection.

    ``WebSocket`` keeps one of these for the client side and one for the
    application side, advancing each as ASGI messages flow.
    """

    # Handshake not yet completed on this side.
    CONNECTING = 0
    # Handshake completed; data messages may flow.
    CONNECTED = 1
    # A disconnect/close message has been seen on this side.
    DISCONNECTED = 2
class WebSocketDisconnect(Exception):
    """Signals that the peer has disconnected from the WebSocket.

    Attributes:
        code: the close code associated with the disconnect (default 1000).
        reason: the close reason; any falsy value is normalised to ``""``.
    """

    def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
        self.reason = reason if reason else ""
        self.code = code
class WebSocket(HTTPConnection):
    """A WebSocket connection built on the raw ASGI ``receive``/``send`` callables.

    Two independent state machines are tracked:

    * ``client_state`` advances as messages are *received*
      (``websocket.connect`` -> CONNECTED, ``websocket.disconnect`` -> DISCONNECTED).
    * ``application_state`` advances as messages are *sent*
      (``websocket.accept`` -> CONNECTED, ``websocket.close`` -> DISCONNECTED).

    Messages that are invalid for the current state raise ``RuntimeError``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: if the message type is not valid for the current
                client state, or if called after a disconnect was received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.receive" or '
                    f'"websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: if the message type is not valid for the current
                application state, or if called after a close was sent.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close"}:
                # The message must name the types actually accepted here
                # (previously it wrongly mentioned "websocket.connect").
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept" or '
                    f'"websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.send" or "websocket.close", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: typing.Optional[str] = None,
        headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
    ) -> None:
        """Send ``websocket.accept``, optionally with a subprotocol and extra headers.

        If the client's ``websocket.connect`` message has not been received
        yet, it is awaited first so the receive-side state stays valid.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a disconnect message.

        The close code and (if present) the close reason from the message are
        carried on the exception.
        """
        if message["type"] == "websocket.disconnect":
            # "reason" may be absent, so use .get; WebSocketDisconnect
            # normalises None to "".
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    async def receive_text(self) -> str:
        """Receive one message and return its ``text`` payload.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["text"]

    async def receive_bytes(self) -> bytes:
        """Receive one message and return its ``bytes`` payload.

        Raises:
            RuntimeError: if the connection has not been accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["bytes"]

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """Receive one message and decode it as JSON.

        Args:
            mode: ``"text"`` reads the ``text`` payload; ``"binary"`` reads the
                ``bytes`` payload and decodes it as UTF-8.

        Raises:
            RuntimeError: on an invalid *mode* or if not yet accepted.
            WebSocketDisconnect: if the client disconnected.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A client disconnect is the normal end of iteration.
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A client disconnect is the normal end of iteration.
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        """Yield JSON-decoded text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A client disconnect is the normal end of iteration.
            pass

    async def send_text(self, data: str) -> None:
        """Send *data* as a text ``websocket.send`` message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send *data* as a binary ``websocket.send`` message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """Serialise *data* with ``json.dumps`` and send it.

        Args:
            mode: ``"text"`` sends a text frame; ``"binary"`` sends the UTF-8
                encoded JSON as bytes.

        Raises:
            RuntimeError: on an invalid *mode*.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(
        self, code: int = 1000, reason: typing.Optional[str] = None
    ) -> None:
        """Send ``websocket.close`` with *code* and *reason* (``""`` if None)."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )
class WebSocketClose:
    """ASGI application that responds by sending a single ``websocket.close``.

    Attributes:
        code: the close code to send (default 1000).
        reason: the close reason to send; any falsy value becomes ``""``.
    """

    def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
        self.reason = reason if reason else ""
        self.code = code

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # scope and receive are unused; the close is sent unconditionally.
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)