Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_backends/sync.py: 30%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

57 statements  

1import socket 

2import ssl 

3import sys 

4import typing 

5from functools import partial 

6 

7from .._exceptions import ( 

8 ConnectError, 

9 ConnectTimeout, 

10 ExceptionMapping, 

11 ReadError, 

12 ReadTimeout, 

13 WriteError, 

14 WriteTimeout, 

15 map_exceptions, 

16) 

17from .._utils import is_socket_readable 

18from .base import SOCKET_OPTION, NetworkBackend, NetworkStream 

19 

20 

class TLSinTLSStream(NetworkStream):  # pragma: no cover
    """
    A TLS stream running on top of a socket that is already TLS-wrapped.

    `SSLContext.wrap_socket` cannot be applied to an `SSLSocket`, so this
    class drives an `SSLObject` by hand. Ciphertext produced by the TLS
    state machine is drained from an outgoing `MemoryBIO` and sent over the
    underlying socket. Bytes received from the socket are pushed into an
    incoming `MemoryBIO` for the state machine to consume.
    """

    # Receive size used when the TLS layer asks for more input: the maximum
    # TLS record payload of 2**14 bytes (see RFC 8449 / RFC 8446 section 5.1).
    TLS_RECORD_SIZE = 16384

    def __init__(
        self,
        sock: socket.socket,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> None:
        self._sock = sock
        # Network -> TLS state machine.
        self._incoming = ssl.MemoryBIO()
        # TLS state machine -> network.
        self._outgoing = ssl.MemoryBIO()

        self.ssl_obj = ssl_context.wrap_bio(
            incoming=self._incoming,
            outgoing=self._outgoing,
            server_hostname=server_hostname,
        )

        # The timeout applies to each blocking socket call made during the
        # handshake below.
        sock.settimeout(timeout)
        self._perform_io(self.ssl_obj.do_handshake)

    def _perform_io(
        self,
        func: typing.Callable[..., typing.Any],
    ) -> typing.Any:
        """
        Call `func` until it completes, pumping bytes between the socket and
        the BIOs whenever it raises `SSLWantReadError`/`SSLWantWriteError`.

        Returns whatever the final successful call to `func` returned.
        """
        result = None

        while True:
            try:
                result = func()
            except (ssl.SSLWantReadError, ssl.SSLWantWriteError) as want:
                pending = want.errno
            else:
                pending = None

            # Whatever the outcome, flush any ciphertext the SSL object queued
            # (handshake messages, encrypted application data, alerts).
            self._sock.sendall(self._outgoing.read())

            if pending is None:
                return result

            if pending == ssl.SSL_ERROR_WANT_READ:
                chunk = self._sock.recv(self.TLS_RECORD_SIZE)
                if not chunk:
                    # recv() gave no bytes: tell the TLS layer no more input is coming.
                    self._incoming.write_eof()
                else:
                    self._incoming.write(chunk)

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """
        Read up to `max_bytes` of decrypted data.

        Raises ReadTimeout / ReadError (mapped from socket.timeout / OSError).
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            reader = partial(self.ssl_obj.read, max_bytes)
            payload = self._perform_io(reader)
            return typing.cast(bytes, payload)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """
        Encrypt and send all of `buffer`, looping over partial writes.

        Raises WriteTimeout / WriteError (mapped from socket.timeout / OSError).
        """
        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            remaining = buffer
            while remaining:
                written = self._perform_io(partial(self.ssl_obj.write, remaining))
                remaining = remaining[written:]

    def close(self) -> None:
        """Close the underlying socket. No TLS close_notify is sent."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> "NetworkStream":
        """A third TLS layer is not supported."""
        raise NotImplementedError()

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for the key `info`, or None if the key is
        not recognised.
        """
        lookups = (
            ("ssl_object", lambda: self.ssl_obj),
            ("client_addr", lambda: self._sock.getsockname()),
            ("server_addr", lambda: self._sock.getpeername()),
            ("socket", lambda: self._sock),
            ("is_readable", lambda: is_socket_readable(self._sock)),
        )
        for key, getter in lookups:
            if info == key:
                return getter()
        return None

116 

117 

class SyncStream(NetworkStream):
    """A `NetworkStream` backed by a blocking `socket.socket` (plain or TLS)."""

    def __init__(self, sock: socket.socket) -> None:
        self._sock = sock

    def read(self, max_bytes: int, timeout: typing.Optional[float] = None) -> bytes:
        """
        Receive up to `max_bytes` bytes.

        An empty result means the peer has closed the connection.

        Raises ReadTimeout / ReadError (mapped from socket.timeout / OSError).
        """
        exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
        with map_exceptions(exc_map):
            self._sock.settimeout(timeout)
            return self._sock.recv(max_bytes)

    def write(self, buffer: bytes, timeout: typing.Optional[float] = None) -> None:
        """
        Send all of `buffer`, looping over partial `send()` calls.

        `timeout` applies to each individual `send()` call, not to the whole
        write.

        Raises WriteTimeout / WriteError (mapped from socket.timeout / OSError).
        """
        if not buffer:
            return

        exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
        with map_exceptions(exc_map):
            # The timeout never changes inside the loop, so set it once.
            self._sock.settimeout(timeout)
            # Slice a memoryview so partial sends don't copy the remaining tail
            # each time (re-slicing bytes is quadratic for large payloads).
            view = memoryview(buffer)
            while view:
                sent = self._sock.send(view)
                view = view[sent:]

    def close(self) -> None:
        """Close the underlying socket."""
        self._sock.close()

    def start_tls(
        self,
        ssl_context: ssl.SSLContext,
        server_hostname: typing.Optional[str] = None,
        timeout: typing.Optional[float] = None,
    ) -> NetworkStream:
        """
        Upgrade this connection to TLS and return the new stream.

        If the socket is already an `SSLSocket`, a `TLSinTLSStream` is
        returned instead. On any failure this stream is closed before the
        exception propagates.

        Raises ConnectTimeout / ConnectError (mapped from socket.timeout /
        OSError).
        """
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            try:
                if isinstance(self._sock, ssl.SSLSocket):  # pragma: no cover
                    # The underlying socket has already been upgraded to TLS
                    # (i.e. is an SSLSocket), so extra handling is needed to
                    # support TLS-in-TLS.
                    return TLSinTLSStream(
                        self._sock, ssl_context, server_hostname, timeout
                    )
                self._sock.settimeout(timeout)
                sock = ssl_context.wrap_socket(
                    self._sock, server_hostname=server_hostname
                )
            except Exception:  # pragma: nocover
                self.close()
                # A bare raise re-raises the active exception unchanged.
                raise
        return SyncStream(sock)

    def get_extra_info(self, info: str) -> typing.Any:
        """
        Return connection metadata for the key `info`, or None if the key is
        not recognised (or if "ssl_object" is requested on a non-TLS socket).
        """
        if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket):
            # Private attribute of SSLSocket; relied upon deliberately.
            return self._sock._sslobj  # type: ignore
        if info == "client_addr":
            return self._sock.getsockname()
        if info == "server_addr":
            return self._sock.getpeername()
        if info == "socket":
            return self._sock
        if info == "is_readable":
            return is_socket_readable(self._sock)
        return None

183 

184 

class SyncBackend(NetworkBackend):
    """`NetworkBackend` that opens blocking sockets wrapped in `SyncStream`."""

    def connect_tcp(
        self,
        host: str,
        port: int,
        timeout: typing.Optional[float] = None,
        local_address: typing.Optional[str] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:
        """
        Open a TCP connection to (`host`, `port`) and return it as a stream.

        `timeout` is passed to `socket.create_connection`. When
        `local_address` is given, the socket binds to (`local_address`, 0),
        where port 0 means the OS picks the source port. Each entry of
        `socket_options` is unpacked into `setsockopt`, and `TCP_NODELAY` is
        always enabled. If configuring the socket fails, it is closed rather
        than leaked.

        Raises ConnectTimeout / ConnectError (mapped from socket.timeout /
        OSError).
        """
        # Note that we automatically include `TCP_NODELAY`
        # in addition to any other custom socket options.
        if socket_options is None:
            socket_options = []  # pragma: no cover
        address = (host, port)
        source_address = None if local_address is None else (local_address, 0)
        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }

        with map_exceptions(exc_map):
            sock = socket.create_connection(
                address,
                timeout,
                source_address=source_address,
            )
            try:
                for option in socket_options:
                    sock.setsockopt(*option)  # pragma: no cover
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
            except BaseException:
                # Don't leak the already-connected socket; the original
                # exception is still mapped by map_exceptions above.
                sock.close()
                raise
            return SyncStream(sock)

    def connect_unix_socket(
        self,
        path: str,
        timeout: typing.Optional[float] = None,
        socket_options: typing.Optional[typing.Iterable[SOCKET_OPTION]] = None,
    ) -> NetworkStream:  # pragma: nocover
        """
        Connect to the UNIX domain socket at `path` and return it as a stream.

        Each entry of `socket_options` is unpacked into `setsockopt` before
        connecting. If configuring or connecting fails, the socket is closed
        rather than leaked.

        Raises RuntimeError on Windows, and ConnectTimeout / ConnectError
        (mapped from socket.timeout / OSError) on failure.
        """
        if sys.platform == "win32":
            raise RuntimeError(
                "Attempted to connect to a UNIX socket on a Windows system."
            )
        if socket_options is None:
            socket_options = []

        exc_map: ExceptionMapping = {
            socket.timeout: ConnectTimeout,
            OSError: ConnectError,
        }
        with map_exceptions(exc_map):
            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                for option in socket_options:
                    sock.setsockopt(*option)
                sock.settimeout(timeout)
                sock.connect(path)
            except BaseException:
                # e.g. missing path, refused connection or timeout: release
                # the socket before the exception is mapped and propagated.
                sock.close()
                raise
            return SyncStream(sock)