Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/websocket/_handshake.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

113 statements  

1""" 

2_handshake.py 

3websocket - WebSocket client library for Python 

4 

5Copyright 2024 engn33r 

6 

7Licensed under the Apache License, Version 2.0 (the "License"); 

8you may not use this file except in compliance with the License. 

9You may obtain a copy of the License at 

10 

11 http://www.apache.org/licenses/LICENSE-2.0 

12 

13Unless required by applicable law or agreed to in writing, software 

14distributed under the License is distributed on an "AS IS" BASIS, 

15WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

16See the License for the specific language governing permissions and 

17limitations under the License. 

18""" 

19 

20import hashlib 

21import hmac 

22import os 

23from base64 import encodebytes as base64encode 

24from http import HTTPStatus 

25 

26from ._cookiejar import SimpleCookieJar 

27from ._exceptions import WebSocketException, WebSocketBadStatusException 

28from ._http import read_headers 

29from ._logging import dump, error 

30from ._socket import send 

31 

__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]

# WebSocket protocol version advertised in the Sec-WebSocket-Version header.
VERSION = 13

# Redirect codes for which handshake() hands the response straight back to
# the caller without validating the WebSocket upgrade headers.
SUPPORTED_REDIRECT_STATUSES = tuple(
    HTTPStatus(code) for code in (301, 302, 303, 307, 308)
)
# Statuses _get_resp_headers accepts by default: the redirects above plus 101.
SUCCESS_STATUSES = (*SUPPORTED_REDIRECT_STATUSES, HTTPStatus.SWITCHING_PROTOCOLS)

# Module-wide cookie store: handshake_response adds "set-cookie" values to it
# and _get_handshake_headers reads it back per host.
CookieJar = SimpleCookieJar()

47 

48 

class handshake_response:
    """Result of a client handshake: HTTP status, response headers, subprotocol.

    Constructing an instance also passes the response's ``set-cookie`` value
    (which may be ``None``) to the module-level ``CookieJar``.
    """

    def __init__(self, status: int, headers: dict, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)

55 

56 

def handshake(
    sock, url: str, hostname: str, port: int, resource: str, **options
) -> handshake_response:
    """Run the client side of the WebSocket opening handshake over ``sock``.

    Sends the upgrade request built from the arguments and ``options``, reads
    the response, and wraps it in a ``handshake_response``.

    Redirect statuses (``SUPPORTED_REDIRECT_STATUSES``) are returned as-is with
    no header validation and a ``None`` subprotocol. Other non-success statuses
    raise inside ``_get_resp_headers``.

    Raises:
        WebSocketException: the response headers failed ``_validate``.
    """
    request_lines, expected_key = _get_handshake_headers(
        resource, url, hostname, port, options
    )
    request = "\r\n".join(request_lines)
    send(sock, request)
    dump("request header", request)

    status, response_headers = _get_resp_headers(sock)
    if status in SUPPORTED_REDIRECT_STATUSES:
        return handshake_response(status, response_headers, None)

    valid, negotiated = _validate(
        response_headers, expected_key, options.get("subprotocols")
    )
    if valid:
        return handshake_response(status, response_headers, negotiated)
    raise WebSocketException("Invalid WebSocket Header")

74 

75 

def _pack_hostname(hostname: str) -> str:
    """Return ``hostname`` in the form used in Host/Origin headers.

    IPv6 literals (detected by the presence of ``:``) are wrapped in square
    brackets. A value that is already bracketed is returned unchanged, so
    ``"[::1]"`` does not become the invalid ``"[[::1]]"``.
    """
    if ":" in hostname and not hostname.startswith("["):
        return f"[{hostname}]"
    return hostname

81 

82 

def _find_user_header(header, name: str) -> tuple:
    """Look up ``name`` case-insensitively in the user's ``header`` option.

    ``header`` is either a dict of name -> value or a list of raw
    ``"Name: value"`` lines. Returns ``(found, value)``; for dict entries the
    value is returned exactly as given (it may be ``None``).
    """
    if not header:
        return False, None
    wanted = name.lower()
    if isinstance(header, dict):
        for header_name, value in header.items():
            if header_name.lower() == wanted:
                return True, value
        return False, None
    for line in header:
        header_name, sep, value = line.partition(":")
        if sep and header_name.strip().lower() == wanted:
            return True, value.strip()
    return False, None


def _get_handshake_headers(
    resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
    """Build the client's opening-handshake request lines.

    Args:
        resource: Request target placed on the ``GET`` line.
        url: WebSocket URL; only its scheme (text before the first ``:``) is
            used, to pick the default port and the Origin scheme.
        host: Server hostname; IPv6 literals are bracketed by ``_pack_hostname``.
        port: Server port. It is omitted from Host/Origin when it equals the
            scheme's default (443 for ``wss``, 80 otherwise).
        options: Handshake options. Keys read here: ``host``,
            ``suppress_origin``, ``origin``, ``header`` (dict or list of
            ``"Name: value"`` lines), ``connection`` (a bare value or a full
            ``Connection: ...`` line), ``subprotocols`` and ``cookie``.

    Returns:
        ``(headers, key)``. ``headers`` is the list of request lines, ending
        in two empty strings so that joining with CRLF terminates the request
        with a blank line. ``key`` is the Sec-WebSocket-Key value the
        response must be validated against.
    """
    # The URL scheme decides both the default port and the Origin scheme.
    scheme = url.partition(":")[0]
    secure = scheme == "wss"
    default_port = 443 if secure else 80

    packed_host = _pack_hostname(host)
    # RFC 6455 section 4.1: include the port only when it is not the default
    # for this scheme (previously ws://h:443 and wss://h:80 lost their port).
    hostport = packed_host if port == default_port else f"{packed_host}:{port}"

    headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]
    if options.get("host"):
        headers.append(f'Host: {options["host"]}')
    else:
        headers.append(f"Host: {hostport}")

    if not options.get("suppress_origin"):
        origin = options.get("origin")
        if origin is not None:
            headers.append(f"Origin: {origin}")
        elif secure:
            headers.append(f"Origin: https://{hostport}")
        else:
            headers.append(f"Origin: http://{hostport}")

    user_header = options.get("header")

    # Generate Sec-WebSocket-Key / -Version only when the caller has not
    # supplied them, in either dict or list form and with any letter case.
    # Otherwise the request carries duplicates and the response is checked
    # against a key the caller did not send.
    key_supplied, user_key = _find_user_header(user_header, "Sec-WebSocket-Key")
    if key_supplied:
        key = user_key
    else:
        key = _create_sec_websocket_key()
        headers.append(f"Sec-WebSocket-Key: {key}")

    version_supplied, _ = _find_user_header(user_header, "Sec-WebSocket-Version")
    if not version_supplied:
        headers.append(f"Sec-WebSocket-Version: {VERSION}")

    connection = options.get("connection")
    if not connection:
        headers.append("Connection: Upgrade")
    else:
        field_name, sep, _ = connection.partition(":")
        if sep and field_name.strip().lower() == "connection":
            # Already a complete header line: keep the historical behaviour.
            headers.append(connection)
        else:
            # A bare value such as "keep-alive, Upgrade" needs the field name.
            headers.append(f"Connection: {connection}")

    if subprotocols := options.get("subprotocols"):
        headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')

    if user_header:
        if isinstance(user_header, dict):
            # None-valued dict entries are dropped, i.e. they suppress a header.
            user_header = [
                ": ".join([k, v]) for k, v in user_header.items() if v is not None
            ]
        headers.extend(user_header)

    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)
    if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
        headers.append(f"Cookie: {cookie}")

    headers.extend(("", ""))
    return headers, key

139 

140 

def _read_error_body(sock, content_len):
    """Read the body of a rejected handshake response, for diagnostics only.

    ``content_len`` is the raw Content-Length header value. Returns ``None``
    when it is absent, not an integer, or negative. Otherwise, returns up to
    that many bytes, looping because one ``recv`` may return fewer bytes than
    requested. Stops early if the peer closes the connection.
    """
    if not content_len:
        return None
    try:
        remaining = int(content_len)
    except (TypeError, ValueError):
        # A malformed header must not mask the bad-status error itself.
        return None
    if remaining < 0:
        return None

    chunks = []
    while remaining > 0:
        try:
            chunk = sock.recv(remaining)
        except OSError:
            if not chunks:
                raise  # same failure a single recv() would have surfaced
            break  # keep what arrived; the body is only informational
        if not chunk:
            break  # peer closed before sending the advertised length
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
    """Read the handshake response's status line and headers from ``sock``.

    Returns:
        ``(status, headers)`` when ``status`` is in ``success_statuses``.

    Raises:
        WebSocketBadStatusException: for any other status. The exception
            carries the status, status message, headers and, when a
            Content-Length was sent, the response body.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status not in success_statuses:
        response_body = _read_error_body(sock, resp_headers.get("content-length"))
        raise WebSocketBadStatusException(
            f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
            status,
            status_message,
            resp_headers,
            response_body,
        )
    return status, resp_headers

159 

160 

_HEADERS_TO_CHECK = {
    "upgrade": "websocket",
    "connection": "upgrade",
}

# Fixed GUID appended to Sec-WebSocket-Key before hashing (RFC 6455 section 1.3).
_WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def _validate(headers, key: str, subprotocols) -> tuple:
    """Validate the server's handshake response headers.

    Args:
        headers: Response headers, looked up by lowercase name.
        key: The Sec-WebSocket-Key value that was sent in the request.
        subprotocols: Subprotocols offered by the client, or a falsy value.

    Returns:
        ``(True, subprotocol)`` when every check passes. ``subprotocol`` is
        the server's choice lowercased, or ``None`` if none were offered.
        ``(False, None)`` otherwise.
    """
    # Each required header must list the expected token among its
    # comma-separated values (compared case-insensitively).
    for name, token in _HEADERS_TO_CHECK.items():
        value = headers.get(name, None)
        if not value:
            return False, None
        if token not in [part.strip().lower() for part in value.split(",")]:
            return False, None

    subproto = None
    if subprotocols:
        subproto = headers.get("sec-websocket-protocol", None)
        if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
            error(f"Invalid subprotocol: {subprotocols}")
            return False, None
        subproto = subproto.lower()

    accept = headers.get("sec-websocket-accept", None)
    if not accept:
        return False, None
    if isinstance(accept, str):
        accept = accept.encode("utf-8")

    # base64 is case-sensitive, so the accept value must match exactly.
    # Lowercasing both sides (the previous behaviour) let wrong values pass.
    digest = hashlib.sha1(f"{key}{_WEBSOCKET_GUID}".encode("utf-8")).digest()
    expected = base64encode(digest).strip()
    if hmac.compare_digest(expected, accept.strip()):
        return True, subproto
    return False, None

199 

200 

def _create_sec_websocket_key() -> str:
    """Return a new random Sec-WebSocket-Key value.

    The key is 16 bytes from ``os.urandom``, base64-encoded. The trailing
    newline that ``encodebytes`` appends is removed.
    """
    encoded = base64encode(os.urandom(16))
    return encoded.strip().decode("utf-8")