1from __future__ import annotations
2
3import functools
4import socket
5import ssl
6import sys
7import typing
8
9from .._exceptions import (
10 ConnectError,
11 ConnectTimeout,
12 ExceptionMapping,
13 ReadError,
14 ReadTimeout,
15 WriteError,
16 WriteTimeout,
17 map_exceptions,
18)
19from .._utils import is_socket_readable
20from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
21
22
class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    Because the standard `SSLContext.wrap_socket` method does
    not work for `SSLSocket` objects, we need this class
    to implement TLS stream using an underlying `SSLObject`
    instance in order to support TLS on top of TLS.

    The inner TLS session is driven through a pair of in-memory BIOs:
    ciphertext produced by the inner `SSLObject` is flushed to the outer
    socket with `sendall`, and ciphertext received from the outer socket is
    written into the inner session's incoming BIO.
    """

    # Maximum TLS plaintext record size, 2**14 bytes (RFC 8446 section 5.1;
    # see also RFC 8449). Used as the chunk size for reads from the outer
    # socket.
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ):
        """
        Wrap an already-connected (outer) socket in a second TLS layer and
        perform the inner handshake immediately.

        :param sock: The outer socket, typically an `ssl.SSLSocket`.
        :param ssl_context: Context used to create the inner `SSLObject`.
        :param server_hostname: Hostname for SNI / certificate verification
            of the inner session.
        :param timeout: Socket timeout applied to each outer-socket operation
            during the handshake.
        """
        self._sock = sock
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=self._incoming,
            outgoing=self._outgoing,
            server_hostname=server_hostname,
        )

        self._sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        """
        Call `func` (an operation on the inner `SSLObject`) repeatedly,
        pumping ciphertext between the memory BIOs and the outer socket until
        it completes, and return its result.

        Pending outgoing ciphertext is flushed to the outer socket after every
        attempt, whether or not the attempt succeeded. When the inner session
        needs more input, up to one TLS record is read from the outer socket;
        an empty read marks EOF on the incoming BIO so the inner session can
        observe the end of the stream on its next attempt.
        """
        while True:
            try:
                result = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as exc:
                want = exc.errno
            else:
                self._sock.sendall(self._outgoing.read())
                return result

            self._sock.sendall(self._outgoing.read())

            if want == ssl.SSL_ERROR_WANT_READ:
                chunk = self._sock.recv(self.TLS_RECORD_SIZE)
                if chunk:
                    self._incoming.write(chunk)
                else:
                    self._incoming.write_eof()

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """
        Read up to `max_bytes` of decrypted data from the inner TLS session.

        Returns `b""` once the connection has been closed, including when the
        outer connection is closed without a TLS close_notify.

        :raises ReadTimeout: If an outer-socket operation times out.
        :raises ReadError: On any other socket or TLS error.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            try:
                return typing.cast(
                    bytes,
                    self._perform_io(functools.partial(self.ssl_obj.read, max_bytes)),
                )
            except ssl.SSLEOFError:
                # The peer closed the outer connection without a TLS
                # close_notify ("ragged EOF"). Report it as end-of-stream,
                # matching the default `suppress_ragged_eofs=True` behaviour
                # of the `SSLSocket`s created by `SSLContext.wrap_socket`,
                # so nested and non-nested TLS streams behave the same.
                return b""

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """
        Encrypt and send all of `buffer` through the inner TLS session.

        :raises WriteTimeout: If an outer-socket operation times out.
        :raises WriteError: On any other socket or TLS error.
        """
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            # A memoryview lets us skip past written bytes without copying the
            # remainder of the buffer after every partial write.
            remaining = memoryview(buffer)
            while remaining:
                nsent = self._perform_io(
                    functools.partial(self.ssl_obj.write, remaining)
                )
                remaining = remaining[nsent:]

    def close(self) -> None:
        """Close the outer socket (no inner close_notify is sent)."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        """A third TLS layer is not supported."""
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for the key `info`, or `None` if unknown.

        Supported keys: "ssl_object" (the inner `SSLObject`), "client_addr",
        "server_addr", "socket" (the outer socket) and "is_readable".
        """
        if info == "ssl_object":
            return self.ssl_obj
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None
118
119
class SyncStream(NetworkStream):
    """
    A blocking `NetworkStream` backed by a `socket.socket`, or by an
    `ssl.SSLSocket` once TLS has been started on it.
    """

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """
        Receive up to `max_bytes` from the socket.

        As with `socket.recv`, an empty result signals that the peer closed
        the connection.

        :raises ReadTimeout: If the receive times out.
        :raises ReadError: On any other socket error.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """
        Send all of `buffer`, looping over partial sends.

        `timeout` applies to each individual `send` call, not to the write
        as a whole.

        :raises WriteTimeout: If a send times out.
        :raises WriteError: On any other socket error.
        """
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            # The timeout is a socket attribute; setting it once is enough.
            self._sock.settimeout(timeout)
            # A memoryview avoids copying the unsent remainder on every
            # partial send.
            remaining = memoryview(buffer)
            while remaining:
                sent = self._sock.send(remaining)
                remaining = remaining[sent:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        """
        Upgrade this connection to TLS and return a new stream for it.

        If the socket is already an `ssl.SSLSocket`, a `TLSinTLSStream` is
        returned; otherwise the socket is wrapped with
        `ssl_context.wrap_socket`. On failure this stream is closed before
        the error propagates.

        :raises ConnectTimeout: If the handshake times out.
        :raises ConnectError: On any other socket or TLS error.
        """
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # If the underlying socket has already been upgraded
                    # to the TLS layer (i.e. is an instance of SSLSocket),
                    # we need some additional smarts to support TLS-in-TLS.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                # Don't leak the socket when the upgrade fails. A bare `raise`
                # re-raises with the original traceback intact.
                self.close()
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for the key `info`, or `None` if unknown.

        Supported keys: "ssl_object" (only when the socket is an
        `ssl.SSLSocket`), "client_addr", "server_addr", "socket" and
        "is_readable".
        """
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            # Private attribute of SSLSocket holding the low-level TLS object.
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None
185
186
class SyncBackend(NetworkBackend):
    """A `NetworkBackend` that opens blocking sockets and returns `SyncStream`s."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        """
        Open a TCP connection to `(host, port)`.

        :param timeout: Passed to `socket.create_connection` as the socket
            timeout.
        :param local_address: Optional local IP to bind to. Port 0 lets the
            OS pick the local port.
        :param socket_options: Extra `setsockopt` argument tuples, applied
            before `TCP_NODELAY`.
        :raises ConnectTimeout: If connecting times out.
        :raises ConnectError: On any other socket error.
        """
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
        if socket_options is None:
            socket_options = []  # pragma: no cover
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }

        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address,
                timeout,
                source_address=source_address,
            )
            try:
                for option in socket_options:
                    sock.setsockopt(*option)  # pragma: no cover
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            except BaseException:
                # Don't leak the connected socket if option setup fails.
                sock.close()
                raise
        return SyncStream(sock)

    def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:  # pragma: nocover
        """
        Open a stream connection to the UNIX domain socket at `path`.

        :param timeout: Socket timeout set before connecting.
        :param socket_options: Extra `setsockopt` argument tuples, applied
            before connecting.
        :raises RuntimeError: When called on Windows.
        :raises ConnectTimeout: If connecting times out.
        :raises ConnectError: On any other socket error.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )
        if socket_options is None:
            socket_options = []

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                for option in socket_options:
                    sock.setsockopt(*option)
                sock.settimeout(timeout)
                sock.connect(path)
            except BaseException:
                # Don't leak the socket if option setup or connect fails.
                sock.close()
                raise
        return SyncStream(sock)