1"""
2_handshake.py
3websocket - WebSocket client library for Python
4
5Copyright 2024 engn33r
6
7Licensed under the Apache License, Version 2.0 (the "License");
8you may not use this file except in compliance with the License.
9You may obtain a copy of the License at
10
11 http://www.apache.org/licenses/LICENSE-2.0
12
13Unless required by applicable law or agreed to in writing, software
14distributed under the License is distributed on an "AS IS" BASIS,
15WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16See the License for the specific language governing permissions and
17limitations under the License.
18"""
19
20import hashlib
21import hmac
22import os
23from base64 import encodebytes as base64encode
24from http import HTTPStatus
25
26from ._cookiejar import SimpleCookieJar
27from ._exceptions import WebSocketException, WebSocketBadStatusException
28from ._http import read_headers
29from ._logging import dump, error
30from ._socket import send
31
# Public names exported by ``from ._handshake import *``.
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]
33
# WebSocket protocol version advertised in the Sec-WebSocket-Version request
# header (unless the caller supplies that header themselves).
VERSION = 13

# Redirect statuses that handshake() returns to the caller without
# validating them as a WebSocket upgrade:
# 301 Moved Permanently, 302 Found, 303 See Other,
# 307 Temporary Redirect, 308 Permanent Redirect.
SUPPORTED_REDIRECT_STATUSES = tuple(map(HTTPStatus, (301, 302, 303, 307, 308)))

# Every status _get_resp_headers accepts by default: the redirects above
# followed by 101 Switching Protocols.
SUCCESS_STATUSES = (*SUPPORTED_REDIRECT_STATUSES, HTTPStatus.SWITCHING_PROTOCOLS)
45
# Module-level cookie store shared by all handshakes in this process:
# handshake_response adds each response's "set-cookie" header to it, and
# _get_handshake_headers sends whatever CookieJar.get(host) returns as the
# Cookie header of later requests to that host.
CookieJar = SimpleCookieJar()
47
48
class handshake_response:
    """Outcome of the client's opening handshake.

    Attributes:
        status: HTTP status code of the server's response.
        headers: response headers as returned by the header reader.
        subprotocol: negotiated subprotocol, or None (handshake() passes None
            for redirect responses).

    Constructing an instance also records the response's ``set-cookie``
    header (possibly None) in the module-level CookieJar.
    """

    def __init__(self, status: int, headers: dict, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)
55
56
def handshake(
    sock, url: str, hostname: str, port: int, resource: str, **options
) -> handshake_response:
    """Run the client side of the WebSocket opening handshake over *sock*.

    The upgrade request is built from *resource*, *url*, *hostname*, *port*
    and *options*, sent, and logged via ``dump``; then the response is read.

    Returns:
        handshake_response. Redirect statuses (SUPPORTED_REDIRECT_STATUSES)
        are returned unvalidated with no subprotocol.

    Raises:
        WebSocketBadStatusException: from _get_resp_headers when the status
            is neither 101 nor a supported redirect.
        WebSocketException: when the upgrade response fails header validation.
    """
    request_lines, key = _get_handshake_headers(resource, url, hostname, port, options)
    request_text = "\r\n".join(request_lines)
    send(sock, request_text)
    dump("request header", request_text)

    status, resp = _get_resp_headers(sock)
    if status not in SUPPORTED_REDIRECT_STATUSES:
        accepted, subproto = _validate(resp, key, options.get("subprotocols"))
        if not accepted:
            raise WebSocketException("Invalid WebSocket Header")
        return handshake_response(status, resp, subproto)

    # Redirects are handed back untouched so the caller can follow them.
    return handshake_response(status, resp, None)
74
75
76def _pack_hostname(hostname: str) -> str:
77 # IPv6 address
78 if ":" in hostname:
79 return f"[{hostname}]"
80 return hostname
81
82
def _get_handshake_headers(
    resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
    """Build the request lines of the client's opening handshake.

    Args:
        resource: request target placed on the ``GET`` request line.
        url: the WebSocket URL; only its scheme (text before the first ``:``)
            is used, to choose the default port and the Origin scheme.
        host: server hostname; IPv6 literals are bracketed.
        port: server port. It is left out of Host/Origin only when it is the
            default port for the scheme (443 for ``wss``, 80 otherwise).
        options: handshake options. Keys read here:
            ``host``, ``suppress_origin``, ``origin``,
            ``header`` (dict of name -> value, where a None value suppresses
            that header, or a sequence of raw ``"Name: value"`` lines),
            ``connection`` (Connection header value; a complete
            ``"Connection: ..."`` line is also accepted for backward
            compatibility), ``subprotocols`` and ``cookie``.

    Returns:
        ``(headers, key)``. *headers* is the list of request lines, ending
        with two empty strings so that ``"\\r\\n".join(headers)`` terminates
        the header block. *key* is the Sec-WebSocket-Key that the response's
        Sec-WebSocket-Accept must be validated against.
    """
    # The scheme decides the default port and whether Origin uses http or
    # https; parse_url in _url.py uses the same scheme-to-port convention.
    scheme = url.partition(":")[0]
    secure = scheme == "wss"
    default_port = 443 if secure else 80

    headers = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]

    # Only the scheme's *default* port may be omitted from Host/Origin; e.g.
    # ws://host:443 must still send "host:443".
    hostport = _pack_hostname(host)
    if port != default_port:
        hostport = f"{hostport}:{port}"
    if options.get("host"):
        headers.append(f'Host: {options["host"]}')
    else:
        headers.append(f"Host: {hostport}")

    if not options.get("suppress_origin"):
        if options.get("origin") is not None:
            headers.append(f'Origin: {options["origin"]}')
        elif secure:
            headers.append(f"Origin: https://{hostport}")
        else:
            headers.append(f"Origin: http://{hostport}")

    custom = options.get("header")
    if custom and not isinstance(custom, dict):
        # Materialize once: it is searched below and appended afterwards.
        custom = list(custom)

    # Use a caller-supplied Sec-WebSocket-Key (matched case-insensitively, in
    # dict or list form) and validate against it; otherwise send a fresh key.
    # Without this, a list-form or lower-cased key produced two keys and the
    # response was validated against the wrong one.
    has_key, user_key = _find_header(custom, "Sec-WebSocket-Key")
    if has_key and user_key is not None:
        key = user_key
    else:
        key = _create_sec_websocket_key()
        headers.append(f"Sec-WebSocket-Key: {key}")

    # Any caller-supplied Sec-WebSocket-Version (even one mapped to None,
    # which suppresses the header) replaces the default.
    has_version, _ = _find_header(custom, "Sec-WebSocket-Version")
    if not has_version:
        headers.append(f"Sec-WebSocket-Version: {VERSION}")

    connection = options.get("connection")
    if not connection:
        headers.append("Connection: Upgrade")
    elif connection.lstrip().lower().startswith("connection:"):
        # Older callers passed the complete header line; keep accepting it.
        headers.append(connection)
    else:
        # A bare value such as "keep-alive, Upgrade" needs the header name,
        # otherwise the request line is malformed.
        headers.append(f"Connection: {connection}")

    if subprotocols := options.get("subprotocols"):
        headers.append(f'Sec-WebSocket-Protocol: {",".join(subprotocols)}')

    if custom:
        if isinstance(custom, dict):
            # f-string rather than ": ".join so non-str values (e.g. ints) work.
            custom = [f"{k}: {v}" for k, v in custom.items() if v is not None]
        headers.extend(custom)

    server_cookie = CookieJar.get(host)
    client_cookie = options.get("cookie", None)

    if cookie := "; ".join(filter(None, [server_cookie, client_cookie])):
        headers.append(f"Cookie: {cookie}")

    headers.extend(("", ""))
    return headers, key


def _find_header(header, name: str) -> tuple:
    """Look up *name* case-insensitively in the user-supplied ``header`` option.

    *header* is either a dict of name -> value or a sequence of raw
    ``"Name: value"`` lines. Returns ``(found, value)``. *value* is None when
    the header is absent or when a dict maps it to None.
    """
    if not header:
        return False, None
    wanted = name.lower()
    if isinstance(header, dict):
        for field, value in header.items():
            if str(field).lower() == wanted:
                return True, value
        return False, None
    for line in header:
        field, sep, value = str(line).partition(":")
        if sep and field.strip().lower() == wanted:
            return True, value.strip()
    return False, None
139
140
def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
    """Read the handshake response status line and headers from *sock*.

    Args:
        sock: the socket the handshake request was sent on.
        success_statuses: statuses that are returned instead of raised.

    Returns:
        ``(status, headers)`` when *status* is in *success_statuses*.

    Raises:
        WebSocketBadStatusException: for any other status. If the response
            declares a Content-Length, the body is read (best effort) and
            attached to the exception.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers

    # read the body of the HTTP error message response and include it in the exception
    response_body = _read_error_body(sock, resp_headers.get("content-length"))
    raise WebSocketBadStatusException(
        f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
        status,
        status_message,
        resp_headers,
        response_body,
    )


def _read_error_body(sock, content_len):
    """Read an error response body of *content_len* bytes from *sock*.

    Returns None when no usable Content-Length was given. ``recv`` may return
    fewer bytes than requested, so it is called until the declared length
    has arrived or the peer closes the connection. Whatever was received is
    returned as bytes.
    """
    if not content_len:
        return None
    try:
        remaining = int(content_len)
    except (TypeError, ValueError):
        # A malformed Content-Length must not hide the bad-status error.
        return None
    chunks = []
    while remaining > 0:
        chunk = sock.recv(remaining)
        if not chunk:
            break  # peer closed before the full body arrived
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)
159
160
# Response headers that must be present in the upgrade response, each mapped
# to the lower-case token its comma-separated value list must contain.
_HEADERS_TO_CHECK = dict(
    upgrade="websocket",
    connection="upgrade",
)
165
166
def _validate(headers, key: str, subprotocols) -> tuple:
    """Check the server's upgrade response headers.

    Args:
        headers: response headers, looked up by lower-case name (presumably
            already lower-cased by read_headers — confirm).
        key: the Sec-WebSocket-Key value sent in the request.
        subprotocols: subprotocols offered by the client, or a falsy value if
            none were requested.

    Returns:
        ``(True, subprotocol)`` when the response is acceptable. *subprotocol*
        is the server's lower-cased selection, or None when no subprotocols
        were requested. Otherwise returns ``(False, None)``.
    """
    for name, required_token in _HEADERS_TO_CHECK.items():
        header_value = headers.get(name, None)
        if not header_value:
            return False, None
        # The value may be a token list such as "keep-alive, Upgrade".
        tokens = [t.strip().lower() for t in header_value.split(",")]
        if required_token not in tokens:
            return False, None

    subproto = None
    if subprotocols:
        subproto = headers.get("sec-websocket-protocol", None)
        if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
            error(f"Invalid subprotocol: {subprotocols}")
            return False, None
        subproto = subproto.lower()

    accept = headers.get("sec-websocket-accept", None)
    if not accept:
        return False, None
    accept = accept.strip()
    if isinstance(accept, str):
        accept = accept.encode("utf-8")

    # Expected value: base64(SHA-1(key + fixed GUID)). Base64 is
    # case-sensitive, so compare exactly. Lower-casing both sides, as before,
    # accepted incorrect Sec-WebSocket-Accept values.
    value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
    expected = base64encode(hashlib.sha1(value).digest()).strip()

    if hmac.compare_digest(expected, accept):
        return True, subproto
    return False, None
199
200
201def _create_sec_websocket_key() -> str:
202 randomness = os.urandom(16)
203 return base64encode(randomness).decode("utf-8").strip()