Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_backends/sync.py: 30%
57 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:38 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:38 +0000
1import socket
2import ssl
3import sys
4import typing
5from functools import partial
7from .._exceptions import (
8 ConnectError,
9 ConnectTimeout,
10 ExceptionMapping,
11 ReadError,
12 ReadTimeout,
13 WriteError,
14 WriteTimeout,
15 map_exceptions,
16)
17from .._utils import is_socket_readable
18from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    A TLS stream layered on top of a socket that is already TLS-wrapped.

    The standard `SSLContext.wrap_socket` method does not work for
    `SSLSocket` objects, so the additional TLS session is driven through
    an `SSLObject` attached to a pair of in-memory BIOs. This class moves
    the encrypted bytes between those BIOs and the underlying socket.
    """

    # Defined in RFC 8449
    # Upper bound on the number of bytes requested from the socket per recv.
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        """
        Attach a new TLS session to `sock` and perform its handshake.

        `timeout` is applied to the socket before the handshake starts.
        """
        self._sock = sock

        inbound = ssl.MemoryBIO()
        outbound = ssl.MemoryBIO()
        self._incoming = inbound
        self._outgoing = outbound

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=inbound,
            outgoing=outbound,
            server_hostname=server_hostname,
        )

        sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        """
        Call `func` repeatedly until it completes without asking for more I/O.

        After every attempt, whatever the SSL object has queued in the
        outgoing BIO is sent over the socket. If the attempt asked for more
        input, one chunk is received from the socket and fed to the incoming
        BIO (or EOF is signalled when the peer sent nothing).
        """
        result = None

        while True:
            wanted = None
            try:
                result = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as exc:
                wanted = exc.errno

            # Flush pending ciphertext regardless of how the attempt ended.
            pending = self._outgoing.read()
            self._sock.sendall(pending)

            if wanted == ssl.SSL_ERROR_WANT_READ:
                chunk = self._sock.recv(self.TLS_RECORD_SIZE)
                if not chunk:
                    self._incoming.write_eof()
                else:
                    self._incoming.write(chunk)

            if wanted is None:
                return result

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """Read up to `max_bytes` of decrypted data from the TLS session."""
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            data = self._perform_io(lambda: self.ssl_obj.read(max_bytes))
            return typing.cast(bytes, data)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """Write all of `buffer` through the TLS session."""
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            remaining = buffer
            while remaining:
                # The SSL object reports how many bytes it consumed.
                current = remaining
                consumed = self._perform_io(lambda: self.ssl_obj.write(current))
                remaining = current[consumed:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> "NetworkStream":
        """Not implemented for this stream type."""
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return the piece of connection information named by `info`.

        Recognised names are "ssl_object", "client_addr", "server_addr",
        "socket" and "is_readable"; any other name yields `None`.
        """
        # Values are computed lazily so only the requested lookup runs.
        providers: typing.Dict[str, typing.Callable[[], typing.Any]] = {
            "ssl_object": lambda: self.ssl_obj,
            "client_addr": lambda: self._sock.getsockname(),
            "server_addr": lambda: self._sock.getpeername(),
            "socket": lambda: self._sock,
            "is_readable": lambda: is_socket_readable(self._sock),
        }
        provider = providers.get(info)
        if provider is None:
            return None
        return provider()
class SyncStream(NetworkStream):
    """Blocking network stream backed by a plain or TLS-wrapped socket."""

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """Receive up to `max_bytes` from the socket, using `timeout`."""
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """Send all of `buffer` over the socket; an empty buffer is a no-op."""
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            # The timeout does not change between partial sends, so it is
            # set once rather than on every loop iteration.
            self._sock.settimeout(timeout)
            while buffer:
                n = self._sock.send(buffer)
                buffer = buffer[n:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> NetworkStream:
        """
        Upgrade this stream to TLS and return the resulting stream.

        A plain socket is wrapped with `ssl_context.wrap_socket` and returned
        as a new `SyncStream`. A socket that is already an `SSLSocket` is
        handed to `TLSinTLSStream`, which layers a second TLS session on top.

        The work runs under `map_exceptions`, with `socket.timeout` paired to
        `ConnectTimeout` and `OSError` paired to `ConnectError`. On any
        failure the stream is closed before the exception propagates.
        """
        # NOTE: there used to be an unconditional RuntimeError here for
        # SSLSocket instances, which made the TLS-in-TLS branch below
        # unreachable. It is removed so that the branch can actually run.
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # If the underlying socket has already been upgraded
                    # to the TLS layer (i.e. is an instance of SSLSocket),
                    # we need some additional smarts to support TLS-in-TLS.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                self.close()
                # Bare raise keeps the original traceback intact.
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return the piece of connection information named by `info`.

        Recognised names are "ssl_object" (only when the socket is an
        `SSLSocket`), "client_addr", "server_addr", "socket" and
        "is_readable"; anything else yields `None`.
        """
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None
class SyncBackend(NetworkBackend):
    """Network backend that opens blocking sockets from the standard library."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:
        """
        Open a TCP connection to `(host, port)` and return it as a stream.

        When `local_address` is given, the connection is bound to
        `(local_address, 0)` as its source address. Each entry of
        `socket_options` is unpacked into `setsockopt`.

        The work runs under `map_exceptions`, with `socket.timeout` paired to
        `ConnectTimeout` and `OSError` paired to `ConnectError`.
        """
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
        if socket_options is None:
            socket_options = []  # pragma: no cover
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }

        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address,
                timeout,
                source_address=source_address,
            )
            try:
                for option in socket_options:
                    sock.setsockopt(*option)  # pragma: no cover
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            except BaseException:  # pragma: no cover
                # The socket is already open here; close it so a failing
                # option does not leak the file descriptor.
                sock.close()
                raise
        return SyncStream(sock)

    def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:  # pragma: nocover
        """
        Connect to the UNIX domain socket at `path` and return it as a stream.

        Raises `RuntimeError` when `sys.platform` is "win32". Otherwise the
        work runs under `map_exceptions`, with `socket.timeout` paired to
        `ConnectTimeout` and `OSError` paired to `ConnectError`.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )
        if socket_options is None:
            socket_options = []

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                for option in socket_options:
                    sock.setsockopt(*option)
                sock.settimeout(timeout)
                sock.connect(path)
            except BaseException:
                # Close the socket if setup or connect fails, so the
                # file descriptor is not leaked.
                sock.close()
                raise
        return SyncStream(sock)