Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/websocket/_handshake.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

123 statements  

1""" 

2_handshake.py 

3websocket - WebSocket client library for Python 

4 

5Copyright 2025 engn33r 

6 

7Licensed under the Apache License, Version 2.0 (the "License"); 

8you may not use this file except in compliance with the License. 

9You may obtain a copy of the License at 

10 

11 http://www.apache.org/licenses/LICENSE-2.0 

12 

13Unless required by applicable law or agreed to in writing, software 

14distributed under the License is distributed on an "AS IS" BASIS, 

15WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

16See the License for the specific language governing permissions and 

17limitations under the License. 

18""" 

19 

20import hashlib 

21import hmac 

22import os 

23import socket 

24from base64 import encodebytes as base64encode 

25from http import HTTPStatus 

26from typing import Any, List, Optional 

27 

28from ._cookiejar import SimpleCookieJar 

29from ._exceptions import WebSocketException, WebSocketBadStatusException 

30from ._http import read_headers 

31from ._logging import dump, error 

32from ._socket import send 

33 

__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]

# WebSocket protocol version sent in the Sec-WebSocket-Version request header.
VERSION = 13

# Redirect statuses: handshake() hands these back to the caller without
# validating the response as a WebSocket upgrade.
SUPPORTED_REDIRECT_STATUSES = (
    HTTPStatus.MOVED_PERMANENTLY,
    HTTPStatus.FOUND,
    HTTPStatus.SEE_OTHER,
    HTTPStatus.TEMPORARY_REDIRECT,
    HTTPStatus.PERMANENT_REDIRECT,
)

# Every status that does not make _get_resp_headers() raise: the redirects
# above plus "101 Switching Protocols".
SUCCESS_STATUSES = (*SUPPORTED_REDIRECT_STATUSES, HTTPStatus.SWITCHING_PROTOCOLS)

# Module-wide cookie store: fed by handshake_response(), read back when the
# request headers are built.
CookieJar = SimpleCookieJar()

49 

50 

class handshake_response:
    """Result of a handshake: HTTP status, response headers and subprotocol.

    Creating an instance also passes the response's ``set-cookie`` header
    value (``None`` when the header is absent) to the module-level
    ``CookieJar``.
    """

    def __init__(self, status: int, headers: dict, subprotocol: Optional[str]) -> None:
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        # Side effect on shared module state: remember any cookie the server set.
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)

57 

58 

def handshake(
    sock: socket.socket,
    url: str,
    hostname: str,
    port: int,
    resource: str,
    **options: Any,
) -> handshake_response:
    """Send the opening handshake request on *sock* and check the reply.

    The request lines come from ``_get_handshake_headers`` and are joined
    with CRLF before being sent and dumped to the log.

    Returns:
        A ``handshake_response``. For a status listed in
        ``SUPPORTED_REDIRECT_STATUSES`` the response is returned as-is with
        ``subprotocol`` set to ``None`` and no header validation.

    Raises:
        WebSocketException: the response headers fail ``_validate``.
    """
    request_lines, key = _get_handshake_headers(resource, url, hostname, port, options)
    request = "\r\n".join(request_lines)
    send(sock, request)
    dump("request header", request)

    status, response_headers = _get_resp_headers(sock)

    # Redirects are reported to the caller untouched; there is nothing to validate.
    if status in SUPPORTED_REDIRECT_STATUSES:
        return handshake_response(status, response_headers, None)

    is_valid, subprotocol = _validate(
        response_headers, key, options.get("subprotocols")
    )
    if is_valid:
        return handshake_response(status, response_headers, subprotocol)
    raise WebSocketException("Invalid WebSocket Header")

81 

82 

def _pack_hostname(hostname: str) -> str:
    """Return *hostname* in the form used inside Host/Origin header values.

    A hostname containing ``:`` is treated as an IPv6 address and wrapped in
    square brackets. A value that is already bracketed is returned unchanged,
    so the result never ends up double-bracketed (``[[::1]]``). Any other
    hostname is returned as-is.
    """
    already_bracketed = hostname.startswith("[") and hostname.endswith("]")
    if ":" in hostname and not already_bracketed:
        return f"[{hostname}]"
    return hostname

88 

89 

def _find_custom_header(custom: Any, name: str) -> tuple:
    """Look up header *name* in user-supplied headers; return ``(found, value)``.

    *custom* may be a dict (``name -> value``) or a sequence of complete
    ``"Name: value"`` lines. An exact-case dict key wins; otherwise names are
    compared case-insensitively.
    """
    wanted = name.lower()
    if isinstance(custom, dict):
        if name in custom:
            return True, custom[name]
        for field, value in custom.items():
            if isinstance(field, str) and field.lower() == wanted:
                return True, value
        return False, None
    for line in custom or ():
        if not isinstance(line, str):
            continue
        field, sep, value = line.partition(":")
        if sep and field.strip().lower() == wanted:
            return True, value.strip()
    return False, None


def _get_handshake_headers(
    resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
    """Build the request lines of the opening handshake.

    Options read from *options*: ``suppress_host``, ``host``,
    ``suppress_origin``, ``origin``, ``header`` (dict or sequence of complete
    header lines), ``connection`` (a complete header line), ``subprotocols``
    and ``cookie``.

    A ``Sec-WebSocket-Key`` or ``Sec-WebSocket-Version`` supplied through
    ``header`` (dict or list form, any letter case) suppresses the generated
    one, so the header is never sent twice and the returned key is the one
    actually put on the wire.

    Returns:
        ``(headers, key)``: the list of request lines, ending with two empty
        strings so that a CRLF join terminates the header block, and the
        ``Sec-WebSocket-Key`` value to validate the response against.
    """
    headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]

    # NOTE(review): the port is omitted for both 80 and 443 regardless of scheme.
    packed_host = _pack_hostname(host)
    hostport = packed_host if port in (80, 443) else f"{packed_host}:{port}"

    if not options.get("suppress_host"):
        headers.append(f'Host: {options.get("host") or hostport}')

    # The URL scheme decides whether the default Origin uses http or https.
    scheme = url.partition(":")[0]
    if not options.get("suppress_origin"):
        origin = options.get("origin")
        if origin is None:
            origin_scheme = "https" if scheme == "wss" else "http"
            origin = f"{origin_scheme}://{hostport}"
        headers.append(f"Origin: {origin}")

    custom = options.get("header")
    if custom and not isinstance(custom, dict):
        # Materialize once so a one-shot iterable survives the lookups below.
        custom = list(custom)

    # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
    key_given, custom_key = _find_custom_header(custom, "Sec-WebSocket-Key")
    if key_given:
        key = custom_key
    else:
        key = _create_sec_websocket_key()
        headers.append(f"Sec-WebSocket-Key: {key}")

    version_given, _ = _find_custom_header(custom, "Sec-WebSocket-Version")
    if not version_given:
        headers.append(f"Sec-WebSocket-Version: {VERSION}")

    headers.append(options.get("connection") or "Connection: Upgrade")

    if subprotocols := options.get("subprotocols"):
        headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')

    if custom:
        if isinstance(custom, dict):
            # Entries whose value is None are skipped; other values are
            # formatted, so non-string values no longer raise TypeError.
            headers.extend(
                f"{field}: {value}"
                for field, value in custom.items()
                if value is not None
            )
        else:
            headers.extend(custom)

    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)

    if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
        headers.append(f"Cookie: {cookie}")

    headers.extend(("", ""))
    return headers, key

147 

148 

def _get_resp_headers(
    sock: socket.socket, success_statuses: tuple = SUCCESS_STATUSES
) -> tuple:
    """Read the handshake response status line and headers from *sock*.

    Returns:
        ``(status, headers)`` when *status* is in *success_statuses*.

    Raises:
        WebSocketBadStatusException: for any other status. When the response
            has a ``content-length`` header, the body is read first and
            attached to the exception. A malformed ``content-length`` is
            treated as an empty body, so the status error is still raised
            rather than a ``ValueError``.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers

    response_body = None
    content_len = resp_headers.get("content-length")
    if content_len:
        # Use chunked reading to avoid SSL BAD_LENGTH error on large responses
        from ._socket import recv

        try:
            remaining = int(content_len)
        except ValueError:
            remaining = 0

        chunks = []
        while remaining > 0:
            chunk = recv(sock, min(remaining, 16384))  # Read in 16KB chunks
            if not chunk:
                # Defensive: never spin forever if recv() yields no data.
                break
            chunks.append(chunk)
            remaining -= len(chunk)
        response_body = b"".join(chunks)

    body_text = (
        response_body.decode("utf-8", errors="replace") if response_body else None
    )
    raise WebSocketBadStatusException(
        f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {body_text}",
        status,
        status_message,
        resp_headers,
        response_body,
    )

176 

177 

# Response headers that must be present, mapped to the token each one must
# contain (compared case-insensitively against its comma-separated value).
_HEADERS_TO_CHECK = {"upgrade": "websocket", "connection": "upgrade"}


def _validate(headers: dict, key: str, subprotocols: Optional[List[str]]) -> tuple:
    """Check the server's handshake response headers.

    Args:
        headers: response headers, looked up by lower-case name.
        key: the Sec-WebSocket-Key value sent in the request.
        subprotocols: subprotocols offered by the client, if any.

    Returns:
        ``(True, subprotocol)`` when the response is acceptable, where
        ``subprotocol`` is the lower-cased negotiated subprotocol, or ``None``
        if none were offered; ``(False, None)`` otherwise.
    """
    rejected = (False, None)

    for name, required_token in _HEADERS_TO_CHECK.items():
        raw_value = headers.get(name, None)
        if not raw_value:
            return rejected
        tokens = [token.strip().lower() for token in raw_value.split(",")]
        if required_token not in tokens:
            return rejected

    negotiated = None
    if subprotocols:
        negotiated = headers.get("sec-websocket-protocol", None)
        offered_match = bool(negotiated) and negotiated.lower() in [
            offered.lower() for offered in subprotocols
        ]
        if not offered_match:
            error(f"Invalid subprotocol: {subprotocols}")
            return rejected
        negotiated = negotiated.lower()

    accept = headers.get("sec-websocket-accept", None)
    if not accept:
        return rejected
    accept = accept.lower()
    if isinstance(accept, str):
        accept = accept.encode("utf-8")

    # Expected accept value: base64(SHA-1(key + fixed GUID)). Both sides are
    # lower-cased, so the comparison is case-insensitive.
    digest = hashlib.sha1(
        f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
    ).digest()
    expected = base64encode(digest).strip().lower()

    return (True, negotiated) if hmac.compare_digest(expected, accept) else rejected

216 

217 

def _create_sec_websocket_key() -> str:
    """Return a fresh Sec-WebSocket-Key: 16 random bytes, base64-encoded."""
    encoded = base64encode(os.urandom(16))
    # encodebytes() appends a trailing newline; drop it before returning text.
    return encoded.strip().decode("utf-8")