Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/starlette/websockets.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

144 statements  

1from __future__ import annotations 

2 

3import enum 

4import json 

5from collections.abc import AsyncIterator, Iterable 

6from typing import Any, cast 

7 

8from starlette.requests import HTTPConnection 

9from starlette.responses import Response 

10from starlette.types import Message, Receive, Scope, Send 

11 

12 

class WebSocketState(enum.Enum):
    """States used by ``WebSocket`` to track each side of a connection.

    ``WebSocket`` keeps two independent values of this enum:
    ``client_state`` (advanced by messages it receives) and
    ``application_state`` (advanced by messages it sends).
    """

    # Initial value of both state attributes.
    CONNECTING = 0
    # Set once "websocket.connect" was received / "websocket.accept" was sent.
    CONNECTED = 1
    # Set once a disconnect was received, a close was sent, or the final
    # denial-response body chunk was sent.
    DISCONNECTED = 2
    # Set once "websocket.http.response.start" was sent; only
    # "websocket.http.response.body" messages may follow.
    RESPONSE = 3

18 

19 

class WebSocketDisconnect(Exception):
    """Raised to signal that the WebSocket connection has been closed.

    Attributes:
        code: The close code associated with the disconnect.
        reason: The close reason; always a string (``None`` becomes ``""``).
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        # Normalise a missing reason to an empty string so callers never
        # have to handle ``None``.
        self.reason = reason or ""

24 

25 

class WebSocket(HTTPConnection):
    """WebSocket connection wrapper that enforces valid ASGI message ordering.

    Two state values are tracked independently:

    * ``client_state`` is advanced by messages read through :meth:`receive`.
    * ``application_state`` is advanced by messages written through
      :meth:`send`.

    Both start as ``WebSocketState.CONNECTING``.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Wrap the ASGI ``scope``/``receive``/``send`` triple.

        The scope's ``"type"`` must be ``"websocket"``.
        """
        super().__init__(scope)
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not valid for the current
                client state, or if a disconnect was already received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            # The very first message must be the connection handshake.
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Raises:
            RuntimeError: If the message type is not valid for the current
                application state, or if a close was already sent.
            WebSocketDisconnect: With code 1006 if the underlying send raises
                ``OSError`` while the application state is CONNECTED.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
                    f"but got {message_type!r}"
                )
            # The state is updated before the message is handed to the server.
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            elif message_type == "websocket.http.response.start":
                self.application_state = WebSocketState.RESPONSE
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            try:
                await self._send(message)
            except OSError:
                # A failed write is surfaced to the caller as a disconnect
                # rather than as a raw OSError.
                self.application_state = WebSocketState.DISCONNECTED
                raise WebSocketDisconnect(code=1006)
        elif self.application_state == WebSocketState.RESPONSE:
            message_type = message["type"]
            if message_type != "websocket.http.response.body":
                raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
            # The last body chunk (no "more_body") ends the denial response.
            if not message.get("more_body", False):
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: Iterable[tuple[bytes, bytes]] | None = None,
    ) -> None:
        """Accept the connection by sending ``"websocket.accept"``.

        Args:
            subprotocol: Value placed under the message's ``"subprotocol"`` key.
            headers: Value placed under the message's ``"headers"`` key; a
                falsy value is replaced with an empty list.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:  # pragma: no branch
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if ``message`` is a disconnect event.

        Any other message type is ignored.
        """
        if message["type"] == "websocket.disconnect":
            # The ASGI spec asks servers to supply code 1005 when the client
            # sent none; use the same value if the key is missing entirely
            # instead of failing with a KeyError.
            raise WebSocketDisconnect(message.get("code", 1005), message.get("reason"))

    async def receive_text(self) -> str:
        """Receive one message and return its ``"text"`` payload.

        Raises:
            RuntimeError: If the application state is not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(str, message["text"])

    async def receive_bytes(self) -> bytes:
        """Receive one message and return its ``"bytes"`` payload.

        Raises:
            RuntimeError: If the application state is not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)
        return cast(bytes, message["bytes"])

    async def receive_json(self, mode: str = "text") -> Any:
        """Receive one message and return its payload parsed as JSON.

        Args:
            mode: ``"text"`` reads the ``"text"`` payload; ``"binary"`` reads
                the ``"bytes"`` payload and decodes it as UTF-8.

        Raises:
            RuntimeError: If ``mode`` is invalid or the application state is
                not CONNECTED.
            WebSocketDisconnect: If the received message is a disconnect.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> AsyncIterator[str]:
        """Yield text payloads until the peer disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_bytes(self) -> AsyncIterator[bytes]:
        """Yield bytes payloads until the peer disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def iter_json(self) -> AsyncIterator[Any]:
        """Yield JSON-decoded payloads (text mode) until the peer disconnects."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect simply ends the iteration.
            pass

    async def send_text(self, data: str) -> None:
        """Send ``data`` as the ``"text"`` payload of a ``"websocket.send"``."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send ``data`` as the ``"bytes"`` payload of a ``"websocket.send"``."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: Any, mode: str = "text") -> None:
        """Serialise ``data`` to compact JSON and send it.

        Args:
            data: Object to serialise with ``json.dumps``.
            mode: ``"text"`` sends the JSON as text; ``"binary"`` sends it
                UTF-8 encoded as bytes.

        Raises:
            RuntimeError: If ``mode`` is neither ``"text"`` nor ``"binary"``.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send ``"websocket.close"`` with ``code`` and ``reason`` (``""`` if falsy)."""
        await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})

    async def send_denial_response(self, response: Response) -> None:
        """Reject the handshake by sending ``response`` through this socket.

        Raises:
            RuntimeError: If the scope's ``"extensions"`` mapping does not
                contain ``"websocket.http.response"``.
        """
        if "websocket.http.response" in self.scope.get("extensions", {}):
            await response(self.scope, self.receive, self.send)
        else:
            raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")

188 

189 

class WebSocketClose:
    """ASGI callable that sends a single ``"websocket.close"`` message.

    Attributes:
        code: Close code placed in the message.
        reason: Close reason placed in the message; ``None`` becomes ``""``.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        self.code = code
        self.reason = reason or ""

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Send the close message; ``scope`` and ``receive`` are not used."""
        await send({"type": "websocket.close", "code": self.code, "reason": self.reason})