1import socket
2import ssl
3import sys
4import typing
5from functools import partial
6
7from .._exceptions import (
8 ConnectError,
9 ConnectTimeout,
10 ExceptionMapping,
11 ReadError,
12 ReadTimeout,
13 WriteError,
14 WriteTimeout,
15 map_exceptions,
16)
17from .._utils import is_socket_readable
18from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
19
20
class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    TLS stream layered over a socket that is already an `ssl.SSLSocket`.

    `SSLContext.wrap_socket` cannot wrap an `SSLSocket`, so this class drives
    an `ssl.SSLObject` through two in-memory BIOs instead and moves the
    resulting ciphertext across the underlying socket by hand. That is what
    allows a second TLS session to run inside an existing one.
    """

    # Number of bytes requested from the underlying socket per `recv`
    # (16384 == 2**14). Defined in RFC 8449.
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        """
        Wrap `sock` in a new TLS layer and run the handshake to completion.

        `timeout` is applied to `sock` before the (blocking) handshake.
        """
        self._sock = sock
        # Ciphertext received from the peer is fed in via `_incoming`;
        # ciphertext produced by the SSL object is drained from `_outgoing`.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=self._incoming,
            outgoing=self._outgoing,
            server_hostname=server_hostname,
        )

        self._sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        """
        Call `func` repeatedly until it finishes without needing more I/O.

        After every attempt, whatever is queued in the outgoing BIO is sent
        on the socket. If `func` raised `SSLWantReadError`, one chunk is read
        from the socket into the incoming BIO (an empty read marks the BIO
        as EOF) and `func` is retried. `SSLWantWriteError` simply retries.

        Returns the value of the first call that did not raise a "want" error.
        """
        while True:
            pending = None  # errno of the "want" condition, if any
            try:
                outcome = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as want:
                pending = want.errno

            # Flush unconditionally: anything `func` queued must reach the
            # peer whether or not `func` itself completed.
            self._sock.sendall(self._outgoing.read())

            if pending == ssl.SSL_ERROR_WANT_READ:
                data = self._sock.recv(self.TLS_RECORD_SIZE)
                if not data:
                    self._incoming.write_eof()
                else:
                    self._incoming.write(data)

            if pending is None:
                return outcome

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """
        Return up to `max_bytes` of decrypted application data.

        Socket timeouts become `ReadTimeout`; other `OSError`s become `ReadError`.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            read_chunk = partial(self.ssl_obj.read, max_bytes)
            data = self._perform_io(read_chunk)
            return typing.cast(bytes, data)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """
        Encrypt and send all of `buffer`.

        Socket timeouts become `WriteTimeout`; other `OSError`s become `WriteError`.
        """
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            remaining = buffer
            while remaining:
                sent = self._perform_io(partial(self.ssl_obj.write, remaining))
                remaining = remaining[sent:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> "NetworkStream":
        """A third TLS layer is not supported; always raises."""
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for the key `info`, or None if unknown.

        Values are computed lazily, only for the requested key.
        """
        lookups: typing.Dict[str, typing.Callable[[], typing.Any]] = {
            "ssl_object": lambda: self.ssl_obj,
            "client_addr": lambda: self._sock.getsockname(),
            "server_addr": lambda: self._sock.getpeername(),
            "socket": lambda: self._sock,
            "is_readable": lambda: is_socket_readable(self._sock),
        }
        lookup = lookups.get(info)
        return lookup() if lookup is not None else None
116
117
class SyncStream(NetworkStream):
    """
    Blocking `NetworkStream` over a plain or TLS-wrapped `socket.socket`.

    Low-level socket errors are translated into the package's own exception
    types (`ReadTimeout`, `WriteError`, `ConnectError`, ...) via `map_exceptions`.
    """

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """
        Read up to `max_bytes` from the socket.

        Socket timeouts become `ReadTimeout`; other `OSError`s become `ReadError`.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """
        Send all of `buffer`, looping over partial `send` calls.

        `timeout` applies to each individual `send`, not to the write as a
        whole. Socket timeouts become `WriteTimeout`; other `OSError`s become
        `WriteError`.
        """
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            # The timeout is a persistent socket setting, so it only needs to
            # be applied once; every `send` below still honours it.
            self._sock.settimeout(timeout)
            while buffer:
                n = self._sock.send(buffer)
                buffer = buffer[n:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> NetworkStream:
        """
        Upgrade this connection to TLS and return the resulting stream.

        If the socket is already an `ssl.SSLSocket`, a `TLSinTLSStream` is
        returned so the new TLS session runs inside the existing one.
        On any failure this stream is closed before the exception propagates.
        Socket timeouts become `ConnectTimeout`; other `OSError`s become
        `ConnectError`.
        """
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # `wrap_socket` cannot wrap an `SSLSocket`, so TLS-in-TLS
                    # needs the memory-BIO based stream instead.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                # Don't leak the underlying socket if the upgrade fails.
                self.close()
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for the key `info`, or None if unknown.

        Recognised keys:
        - "ssl_object": only when the socket is an `ssl.SSLSocket`.
        - "client_addr"
        - "server_addr"
        - "socket"
        - "is_readable"
        """
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            # NOTE: `_sslobj` is a private attribute of `ssl.SSLSocket`,
            # not a public API.
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None
183
184
class SyncBackend(NetworkBackend):
    """Blocking network backend built directly on the standard `socket` module."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:
        """
        Open a TCP connection to `(host, port)` and wrap it in a `SyncStream`.

        `local_address`, if given, is bound with an OS-chosen port. Each
        entry of `socket_options` is passed to `setsockopt`, and
        `TCP_NODELAY` is always enabled in addition. If configuring the
        socket fails, it is closed before the exception propagates.
        Socket timeouts become `ConnectTimeout`; other `OSError`s become
        `ConnectError`.
        """
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
        if socket_options is None:
            socket_options = []  # pragma: no cover
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }

        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address,
                timeout,
                source_address=source_address,
            )
            try:
                for option in socket_options:
                    sock.setsockopt(*option)  # pragma: no cover
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            except BaseException:
                # The connection is already open; don't leak it on failure.
                sock.close()
                raise
        return SyncStream(sock)

    def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:  # pragma: nocover
        """
        Connect to the UNIX domain socket at `path` and wrap it in a `SyncStream`.

        Raises `RuntimeError` on Windows. Each entry of `socket_options` is
        passed to `setsockopt` before connecting. If setup or connecting
        fails, the socket is closed before the exception propagates.
        Socket timeouts become `ConnectTimeout`; other `OSError`s become
        `ConnectError`.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )
        if socket_options is None:
            socket_options = []

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                for option in socket_options:
                    sock.setsockopt(*option)
                sock.settimeout(timeout)
                sock.connect(path)
            except BaseException:
                # E.g. a missing socket path: close rather than leak the fd.
                sock.close()
                raise
        return SyncStream(sock)