Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/websockets.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

140 statements  

1from __future__ import annotations 

2 

3import enum 

4import json 

5import typing 

6 

7from starlette.requests import HTTPConnection 

8from starlette.responses import Response 

9from starlette.types import Message, Receive, Scope, Send 

10 

11 

class WebSocketState(enum.Enum):
    """Phases of a websocket connection, tracked separately for each side."""

    # No connect/accept message has been processed yet.
    CONNECTING = 0
    # The handshake step for this side has completed.
    CONNECTED = 1
    # A disconnect/close message has been processed; no further traffic.
    DISCONNECTED = 2
    # An HTTP response is being sent instead of accepting the connection.
    RESPONSE = 3

17 

18 

class WebSocketDisconnect(Exception):
    """
    Raised when a websocket disconnect is observed.

    Attributes:
        code: The close code (defaults to 1000).
        reason: The close reason; always a string, empty when none was given.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # Normalise a missing (or empty) reason to "" so callers never see None.
        self.reason = reason if reason else ""

23 

24 

class WebSocket(HTTPConnection):
    """
    Wrapper around an ASGI websocket connection that validates message order.

    Two state values are tracked independently:

    * ``client_state`` advances as messages are *received* via ``receive()``.
    * ``application_state`` advances as messages are *sent* via ``send()``.

    Both start as ``WebSocketState.CONNECTING``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Returns:
            The raw ASGI message obtained from the underlying receive callable.

        Raises:
            RuntimeError: If the message type is not valid for the current
                client state, or if a disconnect has already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The very first message must be the connect event.
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.receive" or '
                    f'"websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not valid for the current
                application state, or if a close has already been sent.
            WebSocketDisconnect: With code 1006 if the underlying send raises
                ``OSError`` while the application side is connected.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {
                "websocket.accept",
                "websocket.close",
                "websocket.http.response.start",
            }:
                # Adjacent literals are concatenated verbatim, so each piece
                # carries its own trailing space to keep the message readable.
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", '
                    '"websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.send" or "websocket.close", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # A failed send is surfaced to callers as a disconnect with
                # close code 1006; the original error is kept as the cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(
                    'Expected ASGI message "websocket.http.response.body", '
                    f"but got {message_type!r}"
                )
            # The final body chunk (no "more_body") ends the response.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the connection by sending a "websocket.accept" message.

        Args:
            subprotocol: Value placed in the message's "subprotocol" field.
            headers: Value placed in the message's "headers" field; a falsy
                value is replaced by an empty list.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a disconnect event."""
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive one message and return its "text" field.

        Raises:
            RuntimeError: If the application state is not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive one message and return its "bytes" field.

        Raises:
            RuntimeError: If the application state is not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive one message and decode its payload as JSON.

        Args:
            mode: "text" reads the "text" field; "binary" reads the "bytes"
                field and decodes it as UTF-8.

        Raises:
            RuntimeError: If *mode* is invalid or the application state is
                not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text messages until a ``WebSocketDisconnect`` ends the stream."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect simply terminates the iteration.
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield bytes messages until a ``WebSocketDisconnect`` ends the stream."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect simply terminates the iteration.
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        """Yield JSON-decoded messages until a ``WebSocketDisconnect`` ends the stream."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect simply terminates the iteration.
            pass

    async def send_text(self, data: str) -> None:
        """Send *data* in the "text" field of a "websocket.send" message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send *data* in the "bytes" field of a "websocket.send" message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialise *data* as compact JSON and send it.

        Args:
            data: Any value accepted by ``json.dumps``.
            mode: "text" sends the JSON string; "binary" sends it encoded
                as UTF-8 bytes.

        Raises:
            RuntimeError: If *mode* is neither "text" nor "binary".
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        # Compact separators and ensure_ascii=False keep the payload minimal
        # and emit non-ASCII characters unescaped.
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a "websocket.close" message; a missing reason becomes ""."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )

    async def send_denial_response(self, response: Response) -> None:
        """
        Send *response* instead of accepting the connection.

        Raises:
            RuntimeError: If the scope's "extensions" mapping does not
                contain "websocket.http.response".
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError(
                "The server doesn't support the Websocket Denial Response extension."
            )

214 

215 

216class WebSocketClose: 

217 def __init__(self, code: int = 1000, reason: str | None = None) -> None: 

218 self.code = code 

219 self.reason = reason or "" 

220 

221 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

222 await send( 

223 {"type": "websocket.close", "code": self.code, "reason": self.reason} 

224 )