Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/websockets.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

140 statements  

1from __future__ import annotations 

2 

3import enum 

4import json 

5import typing 

6 

7from starlette.requests import HTTPConnection 

8from starlette.responses import Response 

9from starlette.types import Message, Receive, Scope, Send 

10 

11 

@enum.unique
class WebSocketState(enum.Enum):
    """
    Lifecycle state of one side (client or application) of a websocket.

    ``WebSocket`` keeps one of these for messages it receives and another for
    messages it sends; both start at ``CONNECTING``.
    """

    # Initial state: no connect (received) / accept-or-close (sent) message yet.
    CONNECTING = 0
    # Handshake done: regular receive/send messages are allowed.
    CONNECTED = 1
    # A disconnect was received or a close / final response body was sent.
    DISCONNECTED = 2
    # The application began an HTTP denial response instead of accepting.
    RESPONSE = 3

17 

18 

class WebSocketDisconnect(Exception):
    """
    Raised when the websocket peer has disconnected.

    Attributes:
        code: The websocket close code (defaults to 1000).
        reason: The close reason; ``None`` is normalized to ``""``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        self.reason = reason or ""
        # Forward the values to Exception so ``args``/``str(exc)`` are populated
        # consistently, including for keyword construction such as
        # ``WebSocketDisconnect(code=1006)`` (otherwise ``args`` would be empty).
        super().__init__(self.code, self.reason)

23 

24 

class WebSocket(HTTPConnection):
    """
    An ASGI websocket connection that validates message ordering.

    Two independent state machines are tracked:
    ``client_state`` advances as messages are read through ``receive``, and
    ``application_state`` advances as messages are written through ``send``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the server delivers a message type that is not
                valid for the current ``client_state``, or if called after a
                disconnect message has already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If ``message`` is not valid for the current
                ``application_state``, or if called after the connection was
                closed / the denial response was completed.
            WebSocketDisconnect: With code 1006 when the underlying send fails
                with ``OSError`` while connected.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError as exc:
                # Translate a transport failure into an abnormal-closure
                # disconnect, keeping the original error as the cause.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006) from exc
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
            # The denial response is finished once a body chunk without
            # ``more_body`` is sent.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """
        Accept the websocket handshake.

        Args:
            subprotocol: Optional subprotocol to report to the client.
            headers: Optional extra response headers as ``(name, value)`` byte pairs.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if ``message`` is a disconnect message."""
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    async def receive_text(self) -> str:
        """
        Receive one message and return its ``text`` payload.

        Raises:
            RuntimeError: If the connection has not been accepted.
            WebSocketDisconnect: If the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """
        Receive one message and return its ``bytes`` payload.

        Raises:
            RuntimeError: If the connection has not been accepted.
            WebSocketDisconnect: If the client disconnected.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return typing.cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """
        Receive one message and decode it as JSON.

        Args:
            mode: ``"text"`` to read the ``text`` payload, ``"binary"`` to read
                the ``bytes`` payload decoded as UTF-8.

        Raises:
            RuntimeError: If ``mode`` is invalid or the connection has not been accepted.
            WebSocketDisconnect: If the client disconnected.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield binary messages until the client disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            pass

    async def iter_json(self, mode: str = "text") -> typing.AsyncIterator[typing.Any]:
        """
        Yield JSON-decoded messages until the client disconnects.

        Args:
            mode: Forwarded to ``receive_json`` (``"text"`` or ``"binary"``).
        """
        try:
            while True:
                yield await self.receive_json(mode=mode)
        except WebSocketDisconnect:
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as a text message."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as a binary message."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """
        Serialize ``data`` as compact JSON and send it.

        Args:
            data: JSON-serializable value.
            mode: ``"text"`` sends a text message; ``"binary"`` sends the
                UTF-8 encoded JSON as a binary message.

        Raises:
            RuntimeError: If ``mode`` is invalid.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a close message with ``code`` and ``reason`` (``None`` becomes ``""``)."""
        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})

    async def send_denial_response(self, response: Response) -> None:
        """
        Reject the handshake by sending an HTTP ``response``.

        Raises:
            RuntimeError: If the server does not advertise the
                ``websocket.http.response`` extension in the scope.
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            # Route through self.receive/self.send so state transitions are validated.
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")

187 

188 

class WebSocketClose:
    """
    ASGI application that closes a websocket with a fixed code and reason.

    Args:
        code: Close code to send (defaults to 1000).
        reason: Close reason; a falsy value is sent as ``""``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code, self.reason = code, (reason or "")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only ``send`` is used: closing needs nothing from the scope or client.
        close_message = {
            "type": "websocket.close",
            "code": self.code,
            "reason": self.reason,
        }
        await send(close_message)