Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_backends/sync.py: 30%

57 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:38 +0000

1import socket 

2import ssl 

3import sys 

4import typing 

5from functools import partial 

6 

7from .._exceptions import ( 

8 ConnectError, 

9 ConnectTimeout, 

10 ExceptionMapping, 

11 ReadError, 

12 ReadTimeout, 

13 WriteError, 

14 WriteTimeout, 

15 map_exceptions, 

16) 

17from .._utils import is_socket_readable 

18from .base import SOCKET_OPTION, NetworkBackend, NetworkStream 

19 

20 

class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    TLS stream that runs on top of a socket which is already TLS-wrapped.

    The standard `SSLContext.wrap_socket` method does not work for
    `SSLSocket` objects, so the inner TLS session is driven manually
    through an `SSLObject` bound to a pair of in-memory BIOs, with the
    underlying socket used purely as the byte transport.
    """

    # Defined in RFC 8449
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ):
        """
        Wrap `sock` in a new TLS session and perform the handshake.

        `timeout` is applied to the underlying socket before handshaking.
        """
        self._sock = sock
        # Bytes received from the socket are fed into `_incoming`; bytes the
        # TLS engine wants to transmit are drained from `_outgoing`.
        self._incoming = ssl.MemoryBIO()
        self._outgoing = ssl.MemoryBIO()

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=self._incoming,
            outgoing=self._outgoing,
            server_hostname=server_hostname,
        )

        self._sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        """
        Call `func` repeatedly until it completes without asking for more I/O.

        After every attempt, whatever the TLS engine queued in the outgoing
        BIO is flushed to the socket. If the attempt asked for more input,
        one chunk is read from the socket into the incoming BIO (or EOF is
        signalled when the peer closed) before retrying.
        """
        while True:
            wanted = None
            result = None
            try:
                result = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as exc:
                wanted = exc.errno

            # Always flush pending TLS records, including after success.
            self._sock.sendall(self._outgoing.read())

            if wanted is None:
                return result

            if wanted == ssl.SSL_ERROR_WANT_READ:
                chunk = self._sock.recv(self.TLS_RECORD_SIZE)
                if chunk:
                    self._incoming.write(chunk)
                else:
                    self._incoming.write_eof()

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """Read up to `max_bytes` of decrypted data from the stream."""
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            reader = partial(self.ssl_obj.read, max_bytes)
            data = self._perform_io(reader)
            return typing.cast(bytes, data)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """Encrypt and send all of `buffer`, looping over partial writes."""
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            remaining = buffer
            while remaining:
                written = self._perform_io(partial(self.ssl_obj.write, remaining))
                remaining = remaining[written:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> "NetworkStream":
        """Not supported: a further TLS layer cannot be added to this stream."""
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return transport details for the given key, or `None` if unknown.

        Recognised keys: "ssl_object", "client_addr", "server_addr",
        "socket" and "is_readable".
        """
        # Values are wrapped in callables so that only the requested
        # lookup touches the socket.
        providers: typing.Dict[str, typing.Callable[[], typing.Any]] = {
            "ssl_object": lambda: self.ssl_obj,
            "client_addr": lambda: self._sock.getsockname(),
            "server_addr": lambda: self._sock.getpeername(),
            "socket": lambda: self._sock,
            "is_readable": lambda: is_socket_readable(self._sock),
        }
        provider = providers.get(info)
        if provider is None:
            return None
        return provider()

116 

117 

class SyncStream(NetworkStream):
    """Blocking `NetworkStream` implementation backed by a connected socket."""

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """
        Receive up to `max_bytes` from the socket.

        `socket.timeout` and `OSError` are passed through `map_exceptions`
        with `ReadTimeout` and `ReadError` as their respective targets.
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """
        Send all of `buffer`, looping until every byte has been written.

        An empty buffer is a no-op. `socket.timeout` and `OSError` are passed
        through `map_exceptions` with `WriteTimeout` and `WriteError` as
        their respective targets.
        """
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            # The timeout does not change between partial sends, so it is
            # configured once rather than on every iteration.
            self._sock.settimeout(timeout)
            while buffer:
                n = self._sock.send(buffer)
                buffer = buffer[n:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> NetworkStream:
        """
        Upgrade this stream to TLS and return the resulting stream.

        A plain socket is wrapped with `ssl_context.wrap_socket`. A socket
        that is already an `ssl.SSLSocket` is handed to `TLSinTLSStream`,
        which layers a second TLS session on top of the existing one.

        If the upgrade fails, the socket is closed and the exception is
        re-raised; `socket.timeout` and `OSError` are passed through
        `map_exceptions` with `ConnectTimeout` and `ConnectError` as their
        respective targets.
        """
        # NOTE: an earlier guard raised RuntimeError for any SSLSocket before
        # reaching this point, which made the TLS-in-TLS branch below
        # unreachable. It has been removed so that branch can actually run.
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # If the underlying socket has already been upgraded
                    # to the TLS layer (i.e. is an instance of SSLSocket),
                    # we need some additional smarts to support TLS-in-TLS.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                # Do not leave a half-upgraded socket open on failure.
                self.close()
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return transport details for the given key, or `None` if unknown.

        Recognised keys: "ssl_object" (only when the socket is an
        `ssl.SSLSocket`), "client_addr", "server_addr", "socket" and
        "is_readable".
        """
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None

189 

190 

class SyncBackend(NetworkBackend):
    """Network backend that opens blocking sockets and wraps them in `SyncStream`."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:
        """
        Open a TCP connection to `(host, port)` and return it as a stream.

        `local_address`, when given, is used as the source address with an
        ephemeral port. Each entry of `socket_options` is unpacked into
        `setsockopt`. `socket.timeout` and `OSError` are passed through
        `map_exceptions` with `ConnectTimeout` and `ConnectError` as their
        respective targets.
        """
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
        if socket_options is None:
            socket_options = []  # pragma: no cover
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }

        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address,
                timeout,
                source_address=source_address,
            )
            try:
                for option in socket_options:
                    sock.setsockopt(*option)  # pragma: no cover
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            except Exception:  # pragma: no cover
                # The connection is already open here; close it rather than
                # leaking the descriptor when configuring options fails.
                sock.close()
                raise
        return SyncStream(sock)

    def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:  # pragma: nocover
        """
        Connect to the UNIX domain socket at `path` and return it as a stream.

        Raises `RuntimeError` on Windows. `socket.timeout` and `OSError` are
        passed through `map_exceptions` with `ConnectTimeout` and
        `ConnectError` as their respective targets.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )
        if socket_options is None:
            socket_options = []

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                for option in socket_options:
                    sock.setsockopt(*option)
                sock.settimeout(timeout)
                sock.connect(path)
            except Exception:
                # Release the descriptor if setup or connect fails.
                sock.close()
                raise
        return SyncStream(sock)