Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/websocket/_handshake.py: 21%

116 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:34 +0000

1""" 

_handshake.py
websocket - WebSocket client library for Python

Copyright 2023 engn33r

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

18""" 

19import hashlib 

20import hmac 

21import os 

22from base64 import encodebytes as base64encode 

23from http import client as HTTPStatus 

24from ._cookiejar import SimpleCookieJar 

25from ._exceptions import * 

26from ._http import * 

27from ._logging import * 

28from ._socket import * 

29 

__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]

# websocket supported version; sent as the Sec-WebSocket-Version value.
VERSION = 13

# Redirect statuses that handshake() returns to the caller as-is instead of
# validating them as a WebSocket upgrade.
SUPPORTED_REDIRECT_STATUSES = (
    HTTPStatus.MOVED_PERMANENTLY,
    HTTPStatus.FOUND,
    HTTPStatus.SEE_OTHER,
    HTTPStatus.TEMPORARY_REDIRECT,
    HTTPStatus.PERMANENT_REDIRECT,
)

# Statuses _get_resp_headers accepts without raising: the redirects above
# plus 101 Switching Protocols.
SUCCESS_STATUSES = (*SUPPORTED_REDIRECT_STATUSES, HTTPStatus.SWITCHING_PROTOCOLS)

# Module-level cookie store: filled from Set-Cookie in handshake responses and
# read back when building the Cookie header of later handshake requests.
CookieJar = SimpleCookieJar()

39 

40 

class handshake_response:
    """Outcome of a WebSocket opening handshake.

    Attributes:
        status: HTTP status code of the server response.
        headers: response header mapping (this module looks names up in
            lower case, e.g. ``"set-cookie"``).
        subprotocol: negotiated subprotocol, or ``None``.

    Creating an instance also hands the response's ``set-cookie`` value to the
    module-level ``CookieJar``.
    """

    def __init__(self, status: int, headers: dict, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        # Remember server cookies so later handshake requests can send them back.
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)

48 

49 

def handshake(sock, url: str, hostname: str, port: int, resource: str, **options):
    """Run the client side of the WebSocket opening handshake over *sock*.

    Builds and sends the upgrade request, reads the response headers and
    validates them.

    :return: a ``handshake_response``. If the status is one of
        ``SUPPORTED_REDIRECT_STATUSES`` it is returned unvalidated with
        ``subprotocol`` set to ``None``.
    :raises WebSocketException: if the upgrade response fails validation.
        Exceptions from ``_get_resp_headers`` (e.g. for an unexpected status)
        propagate unchanged.
    """
    request_lines, sec_key = _get_handshake_headers(resource, url, hostname, port, options)

    request_text = "\r\n".join(request_lines)
    send(sock, request_text)
    dump("request header", request_text)

    status, response_headers = _get_resp_headers(sock)
    is_redirect = status in SUPPORTED_REDIRECT_STATUSES
    if is_redirect:
        return handshake_response(status, response_headers, None)

    requested = options.get("subprotocols")
    valid, negotiated = _validate(response_headers, sec_key, requested)
    if valid:
        return handshake_response(status, response_headers, negotiated)
    raise WebSocketException("Invalid WebSocket Header")

65 

66 

67def _pack_hostname(hostname: str) -> str: 

68 # IPv6 address 

69 if ':' in hostname: 

70 return '[' + hostname + ']' 

71 

72 return hostname 

73 

74 

def _find_custom_header(header, name: str) -> tuple:
    """Look up header *name* in the user-supplied ``header`` option.

    *header* may be a dict mapping names to values, or an iterable of
    ``"Name: value"`` strings (both forms are sent by
    ``_get_handshake_headers``). Names are matched case-insensitively. For a
    dict, an exact-case key takes precedence, so existing callers get the same
    value as before.

    :return: ``(True, value)`` when the header is present (for a dict entry
        ``value`` may be ``None``), otherwise ``(False, None)``.
    """
    if not header:
        return False, None
    wanted = name.lower()
    if isinstance(header, dict):
        if name in header:
            return True, header[name]
        for field, value in header.items():
            if isinstance(field, str) and field.lower() == wanted:
                return True, value
        return False, None
    for line in header:
        if not isinstance(line, str):
            continue
        field, sep, value = line.partition(":")
        if sep and field.strip().lower() == wanted:
            return True, value.strip()
    return False, None


def _get_handshake_headers(resource: str, url: str, host: str, port: int, options: dict):
    """Build the request lines of the WebSocket opening handshake.

    :param resource: request target placed on the ``GET`` request line.
    :param url: the WebSocket URL. Only its scheme (text before the first
        ``:``) is used, to pick the Origin scheme and the default port.
    :param host: server host name or IP literal.
    :param port: server port.
    :param options: connection options. Keys read here: ``host``,
        ``suppress_origin``, ``origin``, ``header`` (dict or list of
        ``"Name: value"`` strings), ``connection``, ``subprotocols``,
        ``cookie``.
    :return: ``(headers, key)``. ``headers`` is the list of request lines,
        terminated by two empty strings. ``key`` is the Sec-WebSocket-Key the
        response must be validated against: a custom one from ``header``
        when given, otherwise a freshly generated one.
    :raises ValueError: if *url* contains no ``:``.
    """
    headers = [
        "GET {resource} HTTP/1.1".format(resource=resource),
        "Upgrade: websocket",
    ]

    # scheme indicates whether http or https is used in Origin, and which port
    # is the default one. The same approach is used in parse_url of _url.py to
    # set the default port: 443 for wss, 80 otherwise.
    scheme, _ = url.split(":", 1)
    default_port = 443 if scheme == "wss" else 80

    # RFC 6455 section 4.1: Host carries ":port" only when the port is not the
    # default for the scheme (previously ws on 443 / wss on 80 lost the port).
    packed_host = _pack_hostname(host)
    if port == default_port:
        hostport = packed_host
    else:
        hostport = "{h}:{p}".format(h=packed_host, p=port)

    if options.get("host"):
        headers.append("Host: {h}".format(h=options["host"]))
    else:
        headers.append("Host: {hp}".format(hp=hostport))

    if not options.get("suppress_origin"):
        if "origin" in options and options["origin"] is not None:
            headers.append("Origin: {origin}".format(origin=options["origin"]))
        elif scheme == "wss":
            headers.append("Origin: https://{hp}".format(hp=hostport))
        else:
            headers.append("Origin: http://{hp}".format(hp=hostport))

    custom_header = options.get("header")

    # Append Sec-WebSocket-Key unless the caller supplied a usable one in
    # ``header`` (dict or list form, any casing). Otherwise a list header
    # would produce a duplicate key and validation would use the wrong one.
    # A dict entry of None is treated as "not supplied", because a None key
    # would only fail later when the response is validated.
    has_key, custom_key = _find_custom_header(custom_header, "Sec-WebSocket-Key")
    if has_key and custom_key is not None:
        key = custom_key
    else:
        key = _create_sec_websocket_key()
        headers.append("Sec-WebSocket-Key: {key}".format(key=key))

    # Append Sec-WebSocket-Version unless the caller mentioned it. A dict
    # entry of None still counts, so it suppresses the header as before.
    has_version, _ = _find_custom_header(custom_header, "Sec-WebSocket-Version")
    if not has_version:
        headers.append("Sec-WebSocket-Version: {version}".format(version=VERSION))

    if not options.get("connection"):
        headers.append("Connection: Upgrade")
    else:
        # NOTE: the option is appended verbatim as a complete header line.
        headers.append(options["connection"])

    subprotocols = options.get("subprotocols")
    if subprotocols:
        headers.append("Sec-WebSocket-Protocol: {protocols}".format(protocols=",".join(subprotocols)))

    if custom_header:
        if isinstance(custom_header, dict):
            # Entries whose value is None are skipped. str.format also accepts
            # non-str values (e.g. ints), where ": ".join raised TypeError.
            custom_header = [
                "{k}: {v}".format(k=k, v=v)
                for k, v in custom_header.items()
                if v is not None
            ]
        headers.extend(custom_header)

    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)

    cookie = "; ".join(filter(None, [server_cookie, client_cookie]))

    if cookie:
        headers.append("Cookie: {cookie}".format(cookie=cookie))

    # Two empty entries so that "\r\n".join(headers) ends with the blank line
    # that terminates the request header block.
    headers.extend(("", ""))
    return headers, key

140 

141 

# Upper bound on the number of error-response body bytes read into the
# exception. Content-Length is server-controlled, and socket.recv() allocates
# a buffer of the requested size up front in CPython, so the header value
# must not be passed to recv() unchecked.
_MAX_ERROR_BODY_SIZE = 65536


def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
    """Read the handshake response headers and reject unexpected statuses.

    :param sock: socket the handshake request was sent on.
    :param success_statuses: status codes accepted without raising.
    :return: ``(status, resp_headers)`` as produced by ``read_headers``.
    :raises WebSocketBadStatusException: if the status is not in
        *success_statuses*. When the response carries a valid, non-negative
        ``Content-Length``, one ``recv`` of at most that many bytes (capped at
        ``_MAX_ERROR_BODY_SIZE``) is attached as the body; otherwise the body
        is ``None``.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers

    response_body = None
    content_len = resp_headers.get('content-length')
    if content_len:
        try:
            body_len = int(content_len)
        except (TypeError, ValueError):
            # Malformed header: still report the bad status, just without a
            # body, instead of letting a ValueError mask it.
            body_len = -1
        if body_len >= 0:
            # Read the body of the HTTP error response so it can be included
            # in the exception. A single recv may return fewer bytes.
            response_body = sock.recv(min(body_len, _MAX_ERROR_BODY_SIZE))

    raise WebSocketBadStatusException(
        "Handshake status {status} {message} -+-+- {headers} -+-+- {body}".format(
            status=status, message=status_message, headers=resp_headers, body=response_body),
        status, status_message, resp_headers, response_body)

152 

153 

154_HEADERS_TO_CHECK = { 

155 "upgrade": "websocket", 

156 "connection": "upgrade", 

157} 

158 

159 

160def _validate(headers, key: str, subprotocols): 

161 subproto = None 

162 for k, v in _HEADERS_TO_CHECK.items(): 

163 r = headers.get(k, None) 

164 if not r: 

165 return False, None 

166 r = [x.strip().lower() for x in r.split(',')] 

167 if v not in r: 

168 return False, None 

169 

170 if subprotocols: 

171 subproto = headers.get("sec-websocket-protocol", None) 

172 if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]: 

173 error("Invalid subprotocol: " + str(subprotocols)) 

174 return False, None 

175 subproto = subproto.lower() 

176 

177 result = headers.get("sec-websocket-accept", None) 

178 if not result: 

179 return False, None 

180 result = result.lower() 

181 

182 if isinstance(result, str): 

183 result = result.encode('utf-8') 

184 

185 value = (key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").encode('utf-8') 

186 hashed = base64encode(hashlib.sha1(value).digest()).strip().lower() 

187 success = hmac.compare_digest(hashed, result) 

188 

189 if success: 

190 return True, subproto 

191 else: 

192 return False, None 

193 

194 

195def _create_sec_websocket_key() -> str: 

196 randomness = os.urandom(16) 

197 return base64encode(randomness).decode('utf-8').strip()