1import errno
2import selectors
3import socket
4from typing import Optional, Union, Any
5
6from ._exceptions import (
7 WebSocketConnectionClosedException,
8 WebSocketTimeoutException,
9)
10from ._ssl_compat import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
11from ._utils import extract_error_code, extract_err_message
12
13"""
14_socket.py
15websocket - WebSocket client library for Python
16
17Copyright 2025 engn33r
18
19Licensed under the Apache License, Version 2.0 (the "License");
20you may not use this file except in compliance with the License.
21You may obtain a copy of the License at
22
23 http://www.apache.org/licenses/LICENSE-2.0
24
25Unless required by applicable law or agreed to in writing, software
26distributed under the License is distributed on an "AS IS" BASIS,
27WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28See the License for the specific language governing permissions and
29limitations under the License.
30"""
31
# Socket options applied by default. TCP_NODELAY is always requested; each
# keepalive-related option is included only when the running platform's
# ``socket`` module defines the corresponding constant (see tcp(7) for the
# meaning of TCP_KEEPIDLE / TCP_KEEPINTVL / TCP_KEEPCNT).
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)]
DEFAULT_SOCKET_OPTION.extend(
    (level, getattr(socket, option_name), value)
    for level, option_name, value in (
        (socket.SOL_SOCKET, "SO_KEEPALIVE", 1),
        (socket.SOL_TCP, "TCP_KEEPIDLE", 30),
        (socket.SOL_TCP, "TCP_KEEPINTVL", 10),
        (socket.SOL_TCP, "TCP_KEEPCNT", 3),
    )
    if hasattr(socket, option_name)
)
41
# Process-wide default timeout written by setdefaulttimeout() and read by
# getdefaulttimeout(); starts out as None (nothing configured).
_default_timeout: Optional[Union[int, float]] = None

# Public names of this module (what ``from ... import *`` exposes).
__all__ = [
    "DEFAULT_SOCKET_OPTION",
    "sock_opt",
    "setdefaulttimeout",
    "getdefaulttimeout",
    "recv",
    "recv_line",
    "send",
]
53
54
class sock_opt:
    """
    Container for the socket and SSL options of a connection.

    Parameters
    ----------
    sockopt: list of tuple or None
        Socket options; ``None`` is replaced by a new empty list.
    sslopt: dict or None
        SSL options; ``None`` is replaced by a new empty dict.
    """

    def __init__(
        self, sockopt: Optional[list[tuple]], sslopt: Optional[dict[str, Any]]
    ) -> None:
        # Explicit ``is None`` tests: caller-supplied (even empty) containers
        # are kept by reference, only a missing value gets a fresh default.
        self.sockopt = [] if sockopt is None else sockopt
        self.sslopt = {} if sslopt is None else sslopt
        # Unset at construction; presumably filled in by connection code — TODO confirm.
        self.timeout: Optional[Union[int, float]] = None
66
67
def setdefaulttimeout(timeout: Optional[Union[int, float]]) -> None:
    """
    Set the global timeout setting to connect.

    Parameters
    ----------
    timeout: int, float or None
        Default socket timeout in seconds. The value is stored unchanged
        (no validation) and is what getdefaulttimeout() returns afterwards.
    """
    # Rebinds the module-level setting rather than creating a local.
    global _default_timeout
    _default_timeout = timeout
79
80
def getdefaulttimeout() -> Optional[Union[int, float]]:
    """
    Get the global default timeout.

    Returns
    -------
    _default_timeout: int, float or None
        The value most recently passed to setdefaulttimeout(), or the
        module's initial value (None) if it was never called.
    """
    current_timeout = _default_timeout
    return current_timeout
91
92
def recv(sock: socket.socket, bufsize: int) -> bytes:
    """
    Receive up to ``bufsize`` bytes from ``sock``.

    Parameters
    ----------
    sock: socket.socket
        Connected socket (plain or SSL-wrapped).
    bufsize: int
        Maximum number of bytes to read.

    Returns
    -------
    bytes
        The (non-empty) data read from the socket.

    Raises
    ------
    WebSocketConnectionClosedException
        If ``sock`` is falsy, or the read returned no data (peer closed).
    WebSocketTimeoutException
        If the read timed out, including a wait-for-readability timeout.
    Other OSError / SSLError subclasses propagate unchanged.
    """
    if not sock:
        raise WebSocketConnectionClosedException("socket is already closed.")

    def _recv() -> bytes:
        # First attempt. Only "would block" conditions fall through to the
        # wait-and-retry below; every other error propagates.
        try:
            return sock.recv(bufsize)
        except SSLWantReadError:
            pass
        except socket.error as exc:
            error_code = extract_error_code(exc)
            if error_code not in [errno.EAGAIN, errno.EWOULDBLOCK]:
                raise

        # Wait until the socket is readable. The context manager closes the
        # selector (and any OS handle it holds) even if register()/select()
        # raises, which the previous explicit close() did not guarantee.
        with selectors.DefaultSelector() as sel:
            sel.register(sock, selectors.EVENT_READ)
            ready = sel.select(sock.gettimeout())

        if not ready:
            # A selector timeout is a timeout, not a closed connection.
            raise WebSocketTimeoutException("Connection timed out waiting for data")
        return sock.recv(bufsize)

    try:
        if sock.gettimeout() == 0:
            # Non-blocking socket: a single attempt, no waiting.
            bytes_ = sock.recv(bufsize)
        else:
            bytes_ = _recv()
    except TimeoutError as e:
        # Since Python 3.10 socket.timeout is an alias of TimeoutError, so
        # this clause handles it there; the next clause only matters on 3.9.
        raise WebSocketTimeoutException("Connection timed out") from e
    except socket.timeout as e:
        message = extract_err_message(e)
        raise WebSocketTimeoutException(message) from e
    except SSLError as e:
        message = extract_err_message(e)
        if isinstance(message, str) and "timed out" in message:
            raise WebSocketTimeoutException(message) from e
        raise

    # socket.recv() reports an orderly shutdown by the peer as b"".
    if not bytes_:
        raise WebSocketConnectionClosedException("Connection to remote host was lost.")

    return bytes_
146
147
def recv_line(sock: socket.socket) -> bytes:
    """
    Read from ``sock`` one byte at a time until a newline byte arrives.

    Reading single bytes means nothing after the line terminator is taken
    from the socket. Errors raised by recv() propagate unchanged.

    Returns
    -------
    bytes
        The received line, including its trailing newline.
    """
    received = bytearray()
    while not received.endswith(b"\n"):
        received += recv(sock, 1)
    return bytes(received)
156
157
def send(sock: socket.socket, data: Union[bytes, str]) -> int:
    """
    Send ``data`` over ``sock``.

    A single send is attempted. On a socket with a non-zero timeout, if that
    attempt would block, the function waits for writability (up to the
    socket's timeout) and tries once more.

    Parameters
    ----------
    sock: socket.socket
        Connected socket (plain or SSL-wrapped).
    data: bytes or str
        Payload; ``str`` is encoded as UTF-8 first.

    Returns
    -------
    int
        Number of bytes sent by the underlying ``send`` call, which may be
        less than the payload length.

    Raises
    ------
    WebSocketConnectionClosedException
        If ``sock`` is falsy, or the SSL layer reports EOF (blocking path).
    WebSocketTimeoutException
        If the send times out, including when the socket does not become
        writable within its timeout.
    Other OSError / SSLError subclasses propagate unchanged.
    """
    if isinstance(data, str):
        data = data.encode("utf-8")

    if not sock:
        raise WebSocketConnectionClosedException("socket is already closed.")

    def _send() -> int:
        # First attempt. Only "would block" conditions fall through to the
        # wait-and-retry below; other errors (including a missing error
        # code) propagate.
        try:
            return sock.send(data)
        except SSLEOFError as e:
            raise WebSocketConnectionClosedException("socket is already closed.") from e
        except SSLWantWriteError:
            pass
        except socket.error as exc:
            if extract_error_code(exc) not in [errno.EAGAIN, errno.EWOULDBLOCK]:
                raise

        # Wait for writability; the context manager guarantees the selector
        # is closed even if register()/select() raises.
        with selectors.DefaultSelector() as sel:
            sel.register(sock, selectors.EVENT_WRITE)
            ready = sel.select(sock.gettimeout())

        if not ready:
            # Previously this returned 0, which a "send until empty" caller
            # would retry indefinitely; surface it as a timeout like recv().
            raise WebSocketTimeoutException("Connection timed out waiting to send data")
        try:
            return sock.send(data)
        except SSLEOFError as e:
            # Same mapping as the first attempt.
            raise WebSocketConnectionClosedException("socket is already closed.") from e

    try:
        if sock.gettimeout() == 0:
            # Non-blocking socket: a single attempt, no waiting.
            return sock.send(data)
        return _send()
    except socket.timeout as e:
        message = extract_err_message(e)
        raise WebSocketTimeoutException(message) from e
    except (OSError, SSLError) as e:
        message = extract_err_message(e)
        if isinstance(message, str) and "timed out" in message:
            raise WebSocketTimeoutException(message) from e
        raise