1import errno
2import selectors
3import socket
4from typing import Optional, Union
5
6from ._exceptions import (
7 WebSocketConnectionClosedException,
8 WebSocketTimeoutException,
9)
10from ._ssl_compat import SSLError, SSLEOFError, SSLWantReadError, SSLWantWriteError
11from ._utils import extract_error_code, extract_err_message
12
13"""
14_socket.py
15websocket - WebSocket client library for Python
16
17Copyright 2024 engn33r
18
19Licensed under the Apache License, Version 2.0 (the "License");
20you may not use this file except in compliance with the License.
21You may obtain a copy of the License at
22
23 http://www.apache.org/licenses/LICENSE-2.0
24
25Unless required by applicable law or agreed to in writing, software
26distributed under the License is distributed on an "AS IS" BASIS,
27WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28See the License for the specific language governing permissions and
29limitations under the License.
30"""
31
# Default (level, option, value) triples for the connection socket.
# TCP_NODELAY is always included; the keepalive tuning options are appended
# (in this order) only when the running platform's ``socket`` module defines
# them: enable SO_KEEPALIVE, then idle time 30, probe interval 10, and probe
# count 3.
DEFAULT_SOCKET_OPTION = [(socket.SOL_TCP, socket.TCP_NODELAY, 1)] + [
    (level, getattr(socket, option_name), value)
    for level, option_name, value in (
        (socket.SOL_SOCKET, "SO_KEEPALIVE", 1),
        (socket.SOL_TCP, "TCP_KEEPIDLE", 30),
        (socket.SOL_TCP, "TCP_KEEPINTVL", 10),
        (socket.SOL_TCP, "TCP_KEEPCNT", 3),
    )
    if hasattr(socket, option_name)
]

# Module-wide default timeout; None until setdefaulttimeout() stores a value.
_default_timeout = None
43
# Public names exported by a star import of this module; everything else
# (e.g. the _default_timeout global) is treated as private.
__all__ = [
    "DEFAULT_SOCKET_OPTION",
    "sock_opt",
    "setdefaulttimeout",
    "getdefaulttimeout",
    "recv",
    "recv_line",
    "send",
]
53
54
class sock_opt:
    """
    Bundle of socket options, SSL options and a timeout.

    Attributes
    ----------
    sockopt: list
        The ``sockopt`` argument itself, or a fresh empty list when None was
        passed. Presumably (level, option, value) triples in the same shape
        as DEFAULT_SOCKET_OPTION; not validated here.
    sslopt: dict
        The ``sslopt`` argument itself, or a fresh empty dict when None was
        passed.
    timeout: int, float or None
        Always starts as None.
    """

    def __init__(self, sockopt: Optional[list], sslopt: Optional[dict]) -> None:
        # Explicit None checks (not ``or``) so a caller-supplied empty
        # container is kept by identity rather than replaced.
        self.sockopt = [] if sockopt is None else sockopt
        self.sslopt = {} if sslopt is None else sslopt
        self.timeout: Union[int, float, None] = None
64
65
def setdefaulttimeout(timeout: Union[int, float, None]) -> None:
    """
    Store a module-wide default connect timeout.

    The value is kept in the module global ``_default_timeout`` and can be
    read back with getdefaulttimeout().

    Parameters
    ----------
    timeout: int, float or None
        Timeout in seconds; None is accepted and stored as-is.
    """
    global _default_timeout
    _default_timeout = timeout
77
78
def getdefaulttimeout() -> Union[int, float, None]:
    """
    Return the module-wide default connect timeout.

    Returns
    -------
    int, float or None
        The value held in the module global ``_default_timeout``, in seconds,
        as last stored by setdefaulttimeout().
    """
    return _default_timeout
89
90
def recv(sock: socket.socket, bufsize: int) -> bytes:
    """
    Receive up to ``bufsize`` bytes from ``sock``.

    A non-blocking socket (timeout 0) is read directly. Otherwise, if the
    first read would block (EAGAIN/EWOULDBLOCK, or SSLWantReadError on a TLS
    socket), the socket is polled for readability for at most its own
    timeout and read once more.

    Parameters
    ----------
    sock: socket.socket
        Socket to read from.
    bufsize: int
        Maximum number of bytes to read.

    Returns
    -------
    bytes
        The received data; never empty.

    Raises
    ------
    WebSocketConnectionClosedException
        If ``sock`` is falsy (e.g. None) or the read returned no data.
    WebSocketTimeoutException
        If the read, or the wait for readability, timed out.
    """
    if not sock:
        raise WebSocketConnectionClosedException("socket is already closed.")

    def _recv():
        try:
            return sock.recv(bufsize)
        except SSLWantReadError:
            pass
        except socket.error as exc:
            error_code = extract_error_code(exc)
            if error_code not in (errno.EAGAIN, errno.EWOULDBLOCK):
                raise

        # The read would have blocked: wait until the socket is readable,
        # bounded by the socket's own timeout (None waits indefinitely).
        # The context manager closes the selector even if register/select
        # raises.
        with selectors.DefaultSelector() as sel:
            sel.register(sock, selectors.EVENT_READ)
            ready = sel.select(sock.gettimeout())

        if not ready:
            # An expired wait is a timeout, not a lost connection; returning
            # None here would be misreported as a closed connection below.
            raise WebSocketTimeoutException("Connection timed out")
        return sock.recv(bufsize)

    try:
        if sock.gettimeout() == 0:
            bytes_ = sock.recv(bufsize)
        else:
            bytes_ = _recv()
    except TimeoutError:
        raise WebSocketTimeoutException("Connection timed out")
    except socket.timeout as e:
        # Distinct from TimeoutError before Python 3.10.
        message = extract_err_message(e)
        raise WebSocketTimeoutException(message)
    except SSLError as e:
        message = extract_err_message(e)
        if isinstance(message, str) and "timed out" in message:
            raise WebSocketTimeoutException(message)
        raise

    if not bytes_:
        # An empty read means the peer performed an orderly shutdown.
        raise WebSocketConnectionClosedException("Connection to remote host was lost.")

    return bytes_
135
136
def recv_line(sock: socket.socket) -> bytes:
    """
    Read from ``sock`` one byte at a time until a newline byte arrives.

    Each byte is fetched with a separate ``recv(sock, 1)`` call, so nothing
    past the newline is consumed. Exceptions raised by ``recv`` propagate
    unchanged.

    Returns
    -------
    bytes
        All bytes read, including the terminating newline.
    """
    # iter(callable, sentinel) keeps calling until the newline is returned;
    # the sentinel itself is not yielded, so it is re-appended afterwards.
    head = b"".join(iter(lambda: recv(sock, 1), b"\n"))
    return head + b"\n"
145
146
def send(sock: socket.socket, data: Union[bytes, str]) -> int:
    """
    Send ``data`` on ``sock`` and return the number of bytes sent.

    ``str`` data is UTF-8 encoded first. A single ``socket.send`` is
    performed, so the count may be smaller than ``len(data)``. A
    non-blocking socket (timeout 0) is written directly. Otherwise, if the
    first write would block (EAGAIN/EWOULDBLOCK, or SSLWantWriteError), the
    socket is polled for writability for at most its own timeout and
    written once more.

    Parameters
    ----------
    sock: socket.socket
        Socket to write to.
    data: bytes or str
        Payload; ``str`` is encoded as UTF-8.

    Returns
    -------
    int
        Number of bytes sent.

    Raises
    ------
    WebSocketConnectionClosedException
        If ``sock`` is falsy (e.g. None) or the TLS layer reports EOF.
    WebSocketTimeoutException
        If the write, or the wait for writability, timed out.
    """
    if isinstance(data, str):
        data = data.encode("utf-8")

    if not sock:
        raise WebSocketConnectionClosedException("socket is already closed.")

    def _send():
        try:
            return sock.send(data)
        except SSLEOFError:
            raise WebSocketConnectionClosedException("socket is already closed.")
        except SSLWantWriteError:
            pass
        except socket.error as exc:
            # A None error code is not in the tuple, so it is re-raised too.
            if extract_error_code(exc) not in (errno.EAGAIN, errno.EWOULDBLOCK):
                raise

        # The write would have blocked: wait until the socket is writable,
        # bounded by the socket's own timeout (None waits indefinitely).
        # The context manager closes the selector even if register/select
        # raises.
        with selectors.DefaultSelector() as sel:
            sel.register(sock, selectors.EVENT_WRITE)
            ready = sel.select(sock.gettimeout())

        if not ready:
            # Returning None would break the int contract; callers that
            # advance by the returned count could loop forever.
            raise WebSocketTimeoutException("Connection timed out")
        return sock.send(data)

    try:
        if sock.gettimeout() == 0:
            return sock.send(data)
        return _send()
    except socket.timeout as e:
        message = extract_err_message(e)
        raise WebSocketTimeoutException(message)
    except Exception as e:
        # Map any error whose message mentions a timeout (including the one
        # raised by _send above) to WebSocketTimeoutException; re-raise the
        # rest unchanged.
        message = extract_err_message(e)
        if isinstance(message, str) and "timed out" in message:
            raise WebSocketTimeoutException(message)
        raise