Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/httpcore/_backends/sync.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

58 statements  

1from __future__ import annotations 

2 

3import functools 

4import socket 

5import ssl 

6import sys 

7import typing 

8 

9from .._exceptions import ( 

10 ConnectError, 

11 ConnectTimeout, 

12 ExceptionMapping, 

13 ReadError, 

14 ReadTimeout, 

15 WriteError, 

16 WriteTimeout, 

17 map_exceptions, 

18) 

19from .._utils import is_socket_readable 

20from .base import SOCKET_OPTION, NetworkBackend, NetworkStream 

21 

22 

class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    TLS stream running on top of a socket that already speaks TLS.

    ``SSLContext.wrap_socket`` cannot wrap an object that is already an
    ``SSLSocket``. To get TLS-on-top-of-TLS, this class instead drives an
    ``SSLObject`` through two in-memory BIOs:
    - ciphertext the inner session produces is sent over the outer socket;
    - bytes received from the outer socket are fed back into the inner session.
    """

    # Number of bytes requested from the outer socket per recv() call:
    # 2**14, the maximum TLS plaintext record size (see RFC 8449).
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ):
        """
        Start an inner TLS session over *sock* and run its handshake.

        :param sock: Connected socket the inner TLS records travel over.
        :param ssl_context: Context whose ``wrap_bio`` creates the inner session.
        :param server_hostname: Forwarded unchanged to ``SSLContext.wrap_bio``.
        :param timeout: Socket timeout set before the handshake is performed.
        """
        self._sock = sock
        # `_incoming` receives bytes read from the socket; `_outgoing`
        # collects bytes the SSLObject wants sent over the socket.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=self._incoming,
            outgoing=self._outgoing,
            server_hostname=server_hostname,
        )

        self._sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        """
        Call *func* repeatedly, shuttling bytes between the BIOs and the socket.

        Each attempt does three things:
        1. Call *func* (a bound ``SSLObject`` operation).
        2. Flush any pending outgoing ciphertext to the socket.
        3. If *func* asked for more input, read one chunk from the socket into
           the incoming BIO (or signal EOF when the peer sent nothing).

        Returns the value of the first call to *func* that does not raise
        ``SSLWantReadError`` / ``SSLWantWriteError``.
        """
        result = None

        while True:
            try:
                result = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as want:
                pending = want.errno
            else:
                pending = None

            # Flush unconditionally: even a successful call may have queued
            # records (e.g. the final handshake flight) that must reach the peer.
            self._sock.sendall(self._outgoing.read())

            if pending == ssl.SSL_ERROR_WANT_READ:
                chunk = self._sock.recv(self.TLS_RECORD_SIZE)
                if chunk:
                    self._incoming.write(chunk)
                else:
                    # An empty recv() means the peer closed the outer stream.
                    self._incoming.write_eof()

            if pending is None:
                return result

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """
        Read up to *max_bytes* of decrypted data.

        Socket timeouts become ``ReadTimeout``; other ``OSError``s become
        ``ReadError``.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            reader = functools.partial(self.ssl_obj.read, max_bytes)
            data = self._perform_io(reader)
            return typing.cast(bytes, data)

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """
        Encrypt and send all of *buffer*.

        Socket timeouts become ``WriteTimeout``; other ``OSError``s become
        ``WriteError``.
        """
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            remaining = buffer
            while remaining:
                writer = functools.partial(self.ssl_obj.write, remaining)
                written = self._perform_io(writer)
                remaining = remaining[written:]

    def close(self) -> None:
        """Close the outer socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        """A third TLS layer is not supported."""
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata selected by *info*.

        Unknown keys return ``None``. Socket queries run only when their key
        is requested.
        """
        getters: dict[str, typing.Callable[[], typing.Any]] = {
            "ssl_object": lambda: self.ssl_obj,
            "client_addr": lambda: self._sock.getsockname(),
            "server_addr": lambda: self._sock.getpeername(),
            "socket": lambda: self._sock,
            "is_readable": lambda: is_socket_readable(self._sock),
        }
        getter = getters.get(info)
        return None if getter is None else getter()

118 

119 

class SyncStream(NetworkStream):
    """``NetworkStream`` implemented on a blocking ``socket.socket``."""

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
        """
        Receive up to *max_bytes* from the socket.

        Socket timeouts become ``ReadTimeout``; other ``OSError``s become
        ``ReadError``.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: float | None = None) -> None:
        """
        Send all of *buffer*, looping over partial ``send()`` results.

        The timeout is re-applied before every ``send()`` call, so it bounds
        each individual send rather than the whole write. Socket timeouts
        become ``WriteTimeout``; other ``OSError``s become ``WriteError``.
        """
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        # Slicing a memoryview is zero-copy. Re-slicing `bytes` after each
        # partial send would copy the whole unsent remainder every time
        # (quadratic in the payload size).
        view = memoryview(buffer)
        with map_exceptions(exc_map):
            while view:
                self._sock.settimeout(timeout)
                n = self._sock.send(view)
                view = view[n:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: str | None = None,
        timeout: float | None = None,
    ) -> NetworkStream:
        """
        Upgrade this stream to TLS and return the new stream.

        If the socket is already an ``SSLSocket``, a ``TLSinTLSStream`` is
        returned. Otherwise the socket is wrapped with
        ``ssl_context.wrap_socket``.

        On any failure this stream is closed before the exception propagates.
        Socket timeouts become ``ConnectTimeout``; other ``OSError``s become
        ``ConnectError``.
        """
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # The underlying socket already speaks TLS
                    # (i.e. is an instance of SSLSocket), so TLS-in-TLS
                    # needs the MemoryBIO-based stream.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                # Don't leak the socket on a failed upgrade. Use a bare
                # `raise` so the original traceback is preserved unchanged.
                self.close()
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """Return connection metadata selected by *info*, or ``None`` if unknown."""
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None

185 

186 

class SyncBackend(NetworkBackend):
    """``NetworkBackend`` that opens blocking sockets wrapped in ``SyncStream``."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: float | None = None,
        local_address: str | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:
        """
        Open a TCP connection to ``(host, port)``.

        *local_address*, if given, is bound as the source address with an
        ephemeral port. Each entry of *socket_options* is applied with
        ``setsockopt(*option)``, and ``TCP_NODELAY`` is always enabled.

        If setting any option fails, the socket is closed before the error
        propagates. Socket timeouts become ``ConnectTimeout``; other
        ``OSError``s become ``ConnectError``.
        """
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
        if socket_options is None:
            socket_options = []  # pragma: no cover
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }

        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address,
                timeout,
                source_address=source_address,
            )
            try:
                for option in socket_options:
                    sock.setsockopt(*option)  # pragma: no cover
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            except BaseException:
                # The socket is already connected; close it rather than
                # leaving it open until garbage collection.
                sock.close()
                raise
        return SyncStream(sock)

    def connect_unix_socket(
        self,
        path: str,
        timeout: float | None = None,
        socket_options: typing.Iterable[SOCKET_OPTION] | None = None,
    ) -> NetworkStream:  # pragma: nocover
        """
        Connect a stream-mode UNIX domain socket to *path*.

        Raises ``RuntimeError`` on Windows. If applying options or connecting
        fails, the socket is closed before the error propagates. Socket
        timeouts become ``ConnectTimeout``; other ``OSError``s become
        ``ConnectError``.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )
        if socket_options is None:
            socket_options = []

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                for option in socket_options:
                    sock.setsockopt(*option)
                sock.settimeout(timeout)
                sock.connect(path)
            except BaseException:
                # e.g. connection refused, missing path, or timeout: don't
                # leak the file descriptor of the unconnected socket.
                sock.close()
                raise
        return SyncStream(sock)