Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/starlette/websockets.py: 25%

120 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import enum 

2import json 

3import typing 

4 

5from starlette.requests import HTTPConnection 

6from starlette.types import Message, Receive, Scope, Send 

7 

8 

class WebSocketState(enum.Enum):
    """Lifecycle phase of one side of a websocket connection.

    Members are ordered CONNECTING (0) -> CONNECTED (1) -> DISCONNECTED (2).
    """

    # Unpacking assigns each member in order, giving the values 0, 1 and 2.
    CONNECTING, CONNECTED, DISCONNECTED = range(3)

13 

14 

class WebSocketDisconnect(Exception):
    """Exception signalling that the websocket peer has disconnected.

    Attributes:
        code: close code carried by the disconnect (1000 by default).
        reason: close reason text; a missing or falsy reason becomes "".
    """

    def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None:
        # Normalise any falsy reason (including None) to an empty string.
        self.reason = reason if reason else ""
        self.code = code

19 

20 

class WebSocket(HTTPConnection):
    """A websocket connection built on an ASGI ``receive``/``send`` pair.

    Two state machines are tracked independently:

    * ``client_state`` follows messages read from the ASGI ``receive``
      callable: ``websocket.connect``, then any number of
      ``websocket.receive``, then ``websocket.disconnect``.
    * ``application_state`` follows messages this object sends: first
      ``websocket.accept`` (or ``websocket.close``), then any number of
      ``websocket.send``, then ``websocket.close``.

    ``receive`` and ``send`` raise ``RuntimeError`` on any message that is
    not valid for the current state.
    """

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        super().__init__(scope)
        # Only websocket scopes are valid here. Note that this is an assert,
        # so the check is skipped under ``python -O``.
        assert scope["type"] == "websocket"
        self._receive = receive
        self._send = send
        self.client_state = WebSocketState.CONNECTING
        self.application_state = WebSocketState.CONNECTING

    async def receive(self) -> Message:
        """
        Receive ASGI websocket messages, ensuring valid state transitions.

        Returns the raw ASGI message. Raises ``RuntimeError`` if the message
        type is unexpected for the current ``client_state``, or if called
        after a disconnect message has already been received.
        """
        if self.client_state == WebSocketState.CONNECTING:
            message = await self._receive()
            message_type = message["type"]
            if message_type != "websocket.connect":
                raise RuntimeError(
                    'Expected ASGI message "websocket.connect", '
                    f"but got {message_type!r}"
                )
            self.client_state = WebSocketState.CONNECTED
            return message
        elif self.client_state == WebSocketState.CONNECTED:
            message = await self._receive()
            message_type = message["type"]
            if message_type not in {"websocket.receive", "websocket.disconnect"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.receive" or '
                    f'"websocket.disconnect", but got {message_type!r}'
                )
            if message_type == "websocket.disconnect":
                self.client_state = WebSocketState.DISCONNECTED
            return message
        else:
            raise RuntimeError(
                'Cannot call "receive" once a disconnect message has been received.'
            )

    async def send(self, message: Message) -> None:
        """
        Send ASGI websocket messages, ensuring valid state transitions.

        Before acceptance only ``websocket.accept`` or ``websocket.close`` is
        allowed. Once connected, only ``websocket.send`` or
        ``websocket.close`` is allowed. After a close has been sent, any
        further send raises ``RuntimeError``.
        """
        if self.application_state == WebSocketState.CONNECTING:
            message_type = message["type"]
            if message_type not in {"websocket.accept", "websocket.close"}:
                # The error names the message types that are actually
                # accepted in this state; "websocket.connect" is never valid
                # as an outgoing message.
                raise RuntimeError(
                    'Expected ASGI message "websocket.accept" or '
                    f'"websocket.close", but got {message_type!r}'
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            else:
                self.application_state = WebSocketState.CONNECTED
            await self._send(message)
        elif self.application_state == WebSocketState.CONNECTED:
            message_type = message["type"]
            if message_type not in {"websocket.send", "websocket.close"}:
                raise RuntimeError(
                    'Expected ASGI message "websocket.send" or "websocket.close", '
                    f"but got {message_type!r}"
                )
            if message_type == "websocket.close":
                self.application_state = WebSocketState.DISCONNECTED
            await self._send(message)
        else:
            raise RuntimeError('Cannot call "send" once a close message has been sent.')

    async def accept(
        self,
        subprotocol: typing.Optional[str] = None,
        headers: typing.Optional[typing.Iterable[typing.Tuple[bytes, bytes]]] = None,
    ) -> None:
        """Accept the connection, optionally choosing a subprotocol and headers.

        If the ``websocket.connect`` message has not been received yet, it is
        read first.
        """
        headers = headers or []

        if self.client_state == WebSocketState.CONNECTING:
            # If we haven't yet seen the 'connect' message, then wait for it first.
            await self.receive()
        await self.send(
            {"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers}
        )

    def _raise_on_disconnect(self, message: Message) -> None:
        """Raise ``WebSocketDisconnect`` if *message* is a disconnect event.

        Both the close code and the optional close reason are forwarded, so
        handlers can tell why the peer went away.
        """
        if message["type"] == "websocket.disconnect":
            raise WebSocketDisconnect(message["code"], message.get("reason"))

    def _ensure_accepted(self) -> None:
        """Raise ``RuntimeError`` unless the application side is CONNECTED."""
        if self.application_state != WebSocketState.CONNECTED:
            raise RuntimeError(
                'WebSocket is not connected. Need to call "accept" first.'
            )

    async def receive_text(self) -> str:
        """Return the ``text`` payload of the next received message.

        Raises ``RuntimeError`` if the socket is not accepted, and
        ``WebSocketDisconnect`` if the peer disconnected.
        """
        self._ensure_accepted()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["text"]

    async def receive_bytes(self) -> bytes:
        """Return the ``bytes`` payload of the next received message.

        Raises ``RuntimeError`` if the socket is not accepted, and
        ``WebSocketDisconnect`` if the peer disconnected.
        """
        self._ensure_accepted()
        message = await self.receive()
        self._raise_on_disconnect(message)
        return message["bytes"]

    async def receive_json(self, mode: str = "text") -> typing.Any:
        """Receive the next message and decode it as JSON.

        *mode* selects the payload key: ``"text"`` reads ``message["text"]``,
        while ``"binary"`` reads ``message["bytes"]`` decoded as UTF-8.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        self._ensure_accepted()
        message = await self.receive()
        self._raise_on_disconnect(message)

        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)

    async def iter_text(self) -> typing.AsyncIterator[str]:
        """Yield text payloads until the peer disconnects."""
        try:
            while True:
                yield await self.receive_text()
        except WebSocketDisconnect:
            # A disconnect is the normal end of the stream, not an error.
            pass

    async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
        """Yield bytes payloads until the peer disconnects."""
        try:
            while True:
                yield await self.receive_bytes()
        except WebSocketDisconnect:
            # A disconnect is the normal end of the stream, not an error.
            pass

    async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
        """Yield decoded JSON text messages until the peer disconnects."""
        try:
            while True:
                yield await self.receive_json()
        except WebSocketDisconnect:
            # A disconnect is the normal end of the stream, not an error.
            pass

    async def send_text(self, data: str) -> None:
        """Send *data* as a text frame."""
        await self.send({"type": "websocket.send", "text": data})

    async def send_bytes(self, data: bytes) -> None:
        """Send *data* as a binary frame."""
        await self.send({"type": "websocket.send", "bytes": data})

    async def send_json(self, data: typing.Any, mode: str = "text") -> None:
        """Serialise *data* to JSON and send it.

        In ``"text"`` mode it is sent as a text frame. In ``"binary"`` mode it
        is UTF-8 encoded and sent as a binary frame.
        """
        if mode not in {"text", "binary"}:
            raise RuntimeError('The "mode" argument should be "text" or "binary".')
        text = json.dumps(data)
        if mode == "text":
            await self.send({"type": "websocket.send", "text": text})
        else:
            await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})

    async def close(
        self, code: int = 1000, reason: typing.Optional[str] = None
    ) -> None:
        """Send ``websocket.close`` with *code* and *reason* (None becomes "")."""
        await self.send(
            {"type": "websocket.close", "code": code, "reason": reason or ""}
        )

183 

184 

185class WebSocketClose: 

186 def __init__(self, code: int = 1000, reason: typing.Optional[str] = None) -> None: 

187 self.code = code 

188 self.reason = reason or "" 

189 

190 async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: 

191 await send( 

192 {"type": "websocket.close", "code": self.code, "reason": self.reason} 

193 )