1"""
2_handshake.py
3websocket - WebSocket client library for Python
4
5Copyright 2025 engn33r
6
7Licensed under the Apache License, Version 2.0 (the "License");
8you may not use this file except in compliance with the License.
9You may obtain a copy of the License at
10
11 http://www.apache.org/licenses/LICENSE-2.0
12
13Unless required by applicable law or agreed to in writing, software
14distributed under the License is distributed on an "AS IS" BASIS,
15WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16See the License for the specific language governing permissions and
17limitations under the License.
18"""
19
20import hashlib
21import hmac
22import os
23from base64 import encodebytes as base64encode
24from http import HTTPStatus
25
26from ._cookiejar import SimpleCookieJar
27from ._exceptions import WebSocketException, WebSocketBadStatusException
28from ._http import read_headers
29from ._logging import dump, error
30from ._socket import send
31
# Public names exported by a star-import of this module.
__all__ = ["handshake_response", "handshake", "SUPPORTED_REDIRECT_STATUSES"]

# WebSocket protocol version; sent as the Sec-WebSocket-Version request header
# unless the caller supplies that header explicitly.
VERSION = 13

# Redirect statuses for which handshake() returns the response as-is,
# without validating the upgrade headers.
SUPPORTED_REDIRECT_STATUSES = (
    HTTPStatus.MOVED_PERMANENTLY,
    HTTPStatus.FOUND,
    HTTPStatus.SEE_OTHER,
    HTTPStatus.TEMPORARY_REDIRECT,
    HTTPStatus.PERMANENT_REDIRECT,
)
# Default set of statuses accepted by _get_resp_headers(): the redirects above
# plus 101 Switching Protocols. Any other status raises
# WebSocketBadStatusException there.
SUCCESS_STATUSES = SUPPORTED_REDIRECT_STATUSES + (HTTPStatus.SWITCHING_PROTOCOLS,)

# Module-level cookie jar shared by all handshakes in this module: it is fed
# from "set-cookie" response headers and queried by host when building the
# request headers.
CookieJar = SimpleCookieJar()
47
48
class handshake_response:
    """Result of an opening handshake.

    Holds the HTTP status, the response headers and the negotiated
    subprotocol. Constructing an instance also passes the response's
    "set-cookie" header (``None`` when absent) to the module-level cookie jar.
    """

    def __init__(self, status: int, headers: dict, subprotocol):
        self.status, self.headers, self.subprotocol = status, headers, subprotocol

        # Remember server cookies so that later handshakes can send them back.
        set_cookie = headers.get("set-cookie")
        CookieJar.add(set_cookie)
55
56
def handshake(
    sock, url: str, hostname: str, port: int, resource: str, **options
) -> handshake_response:
    """Perform the client side of the WebSocket opening handshake on *sock*.

    Builds the upgrade request from the connection parameters and *options*,
    sends it, then reads the response status and headers.

    Returns a ``handshake_response``. For a redirect status listed in
    ``SUPPORTED_REDIRECT_STATUSES`` the response is returned without header
    validation and with ``subprotocol`` set to ``None``.

    Raises ``WebSocketException`` when the response headers fail validation.
    """
    request_lines, key = _get_handshake_headers(resource, url, hostname, port, options)

    request = "\r\n".join(request_lines)
    send(sock, request)
    dump("request header", request)

    status, response_headers = _get_resp_headers(sock)

    if status not in SUPPORTED_REDIRECT_STATUSES:
        valid, subprotocol = _validate(
            response_headers, key, options.get("subprotocols")
        )
        if not valid:
            raise WebSocketException("Invalid WebSocket Header")
    else:
        # Redirects are handed back to the caller untouched.
        subprotocol = None

    return handshake_response(status, response_headers, subprotocol)
74
75
76def _pack_hostname(hostname: str) -> str:
77 # IPv6 address
78 if ":" in hostname:
79 return f"[{hostname}]"
80 return hostname
81
82
def _get_handshake_headers(
    resource: str, url: str, host: str, port: int, options: dict
) -> tuple:
    """Build the request lines for the opening handshake.

    Recognised *options* keys: ``host``, ``origin``, ``suppress_origin``,
    ``header`` (dict or list of ready-made header lines), ``connection``,
    ``subprotocols`` and ``cookie``.

    Returns ``(lines, key)`` where *lines* is the list of request lines
    (ending with two empty strings so that joining with CRLF terminates the
    header block) and *key* is the Sec-WebSocket-Key value in effect.
    """
    lines = [f"GET {resource} HTTP/1.1", "Upgrade: websocket"]

    # The port is left out of Host/Origin for ports 80 and 443.
    packed_host = _pack_hostname(host)
    hostport = packed_host if port in (80, 443) else f"{packed_host}:{port}"

    host_override = options.get("host")
    lines.append(f"Host: {host_override}" if host_override else f"Host: {hostport}")

    # scheme indicates whether http or https is used in Origin
    # The same approach is used in parse_url of _url.py to set default port
    scheme, _ = url.split(":", 1)
    if not options.get("suppress_origin"):
        explicit_origin = options.get("origin")
        if explicit_origin is not None:
            origin = explicit_origin
        elif scheme == "wss":
            origin = f"https://{hostport}"
        else:
            origin = f"http://{hostport}"
        lines.append(f"Origin: {origin}")

    key = _create_sec_websocket_key()
    custom = options.get("header")

    # Append Sec-WebSocket-Key & Sec-WebSocket-Version if not manually specified
    if custom and "Sec-WebSocket-Key" in custom:
        # The caller's key is emitted later, together with the custom headers.
        key = custom["Sec-WebSocket-Key"]
    else:
        lines.append(f"Sec-WebSocket-Key: {key}")

    if not (custom and "Sec-WebSocket-Version" in custom):
        lines.append(f"Sec-WebSocket-Version: {VERSION}")

    # A caller-supplied "connection" option is used verbatim as the full line.
    connection_line = options.get("connection")
    lines.append(connection_line if connection_line else "Connection: Upgrade")

    requested = options.get("subprotocols")
    if requested:
        lines.append(f'Sec-WebSocket-Protocol: {",".join(requested)}')

    if custom:
        if isinstance(custom, dict):
            # Entries whose value is None are dropped.
            extra = [
                ": ".join((name, value))
                for name, value in custom.items()
                if value is not None
            ]
        else:
            extra = custom
        lines.extend(extra)

    # Combine cookies stored for this host with the caller-supplied ones.
    stored_cookie = CookieJar.get(host)
    caller_cookie = options.get("cookie", None)
    cookie = "; ".join(c for c in (stored_cookie, caller_cookie) if c)
    if cookie:
        lines.append(f"Cookie: {cookie}")

    lines += ["", ""]
    return lines, key
139
140
def _get_resp_headers(sock, success_statuses: tuple = SUCCESS_STATUSES) -> tuple:
    """Read the handshake response and return ``(status, headers)``.

    Raises ``WebSocketBadStatusException`` when the status is not in
    *success_statuses*. The exception carries the status, status message,
    headers and the response body; the body is read only when a
    ``content-length`` header is present and is ``None`` otherwise.
    """
    status, resp_headers, status_message = read_headers(sock)
    if status in success_statuses:
        return status, resp_headers

    response_body = None
    declared_length = resp_headers.get("content-length")
    if declared_length:
        # Use chunked reading to avoid SSL BAD_LENGTH error on large responses
        from ._socket import recv

        pieces = []
        outstanding = int(declared_length)
        while outstanding > 0:
            # Read in 16KB chunks
            piece = recv(sock, min(outstanding, 16384))
            pieces.append(piece)
            outstanding -= len(piece)
        response_body = b"".join(pieces)

    raise WebSocketBadStatusException(
        f"Handshake status {status} {status_message} -+-+- {resp_headers} -+-+- {response_body}",
        status,
        status_message,
        resp_headers,
        response_body,
    )
166
167
168_HEADERS_TO_CHECK = {
169 "upgrade": "websocket",
170 "connection": "upgrade",
171}
172
173
174def _validate(headers, key: str, subprotocols) -> tuple:
175 subproto = None
176 for k, v in _HEADERS_TO_CHECK.items():
177 r = headers.get(k, None)
178 if not r:
179 return False, None
180 r = [x.strip().lower() for x in r.split(",")]
181 if v not in r:
182 return False, None
183
184 if subprotocols:
185 subproto = headers.get("sec-websocket-protocol", None)
186 if not subproto or subproto.lower() not in [s.lower() for s in subprotocols]:
187 error(f"Invalid subprotocol: {subprotocols}")
188 return False, None
189 subproto = subproto.lower()
190
191 result = headers.get("sec-websocket-accept", None)
192 if not result:
193 return False, None
194 result = result.lower()
195
196 if isinstance(result, str):
197 result = result.encode("utf-8")
198
199 value = f"{key}258EAFA5-E914-47DA-95CA-C5AB0DC85B11".encode("utf-8")
200 hashed = base64encode(hashlib.sha1(value).digest()).strip().lower()
201
202 if hmac.compare_digest(hashed, result):
203 return True, subproto
204 else:
205 return False, None
206
207
208def _create_sec_websocket_key() -> str:
209 randomness = os.urandom(16)
210 return base64encode(randomness).decode("utf-8").strip()