Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 38%

137 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import logging 

2import re 

3import ssl 

4from dataclasses import dataclass 

5from functools import wraps 

6from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union 

7 

8from .. import ( 

9 BrokenResourceError, 

10 EndOfStream, 

11 aclose_forcefully, 

12 get_cancelled_exc_class, 

13) 

14from .._core._typedattr import TypedAttributeSet, typed_attribute 

15from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup 

16 

# Generic type variable for the return value of a wrapped SSL object method call.
T_Retval = TypeVar("T_Retval")

# Aliases used to annotate TLSAttribute.peer_certificate: a variable-length
# tuple of (str, str) pairs, and a variable-length tuple of such tuples.
_PCTRTT = Tuple[Tuple[str, str], ...]
_PCTRTTT = Tuple[Tuple[Tuple[str, str], ...], ...]

20 

21 

class TLSAttribute(TypedAttributeSet):
    """
    Typed attributes describing a Transport Layer Security session.

    Each attribute is declared with :func:`typed_attribute`; the annotation gives
    the type of the value reported for it.
    """

    #: the selected ALPN protocol
    alpn_protocol: Union[str, None] = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher
    cipher: Tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` for more
    #: information)
    peer_certificate: Union[Dict[str, Union[str, _PCTRTTT, _PCTRTT]], None] = typed_attribute()
    #: the peer certificate in binary form
    peer_certificate_binary: Union[bytes, None] = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared between both ends of the TLS connection
    shared_ciphers: List[Tuple[str, str, int]] = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being
    #: closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()

49 

50 

@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    _read_bio: ssl.MemoryBIO
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: Optional[bool] = None,
        hostname: Optional[str] = None,
        ssl_context: Optional[ssl.SSLContext] = None,
        standard_compatible: bool = True,
    ) -> "TLSStream":
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection,
            ``False`` if this is the client side (if omitted, will be set to ``False``
            if ``hostname`` has been provided, ``True`` otherwise). Used only to
            create a default context when an explicit context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure
            default will be created)
        :param standard_compatible: if ``False``, skip the closing handshake when
            closing the connection, and don't raise an exception if the peer does the
            same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

            # Re-enable detection of unexpected EOFs if it was disabled by Python.
            # Clear the bit instead of toggling it with XOR: if the default context
            # did not have it set, XOR would *set* it and disable the detection.
            if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
                ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF  # type: ignore[attr-defined]

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[..., T_Retval], *args: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` on the SSL object, pumping data between the memory BIOs
        and the transport stream until the call completes.

        :return: whatever ``func`` returns
        :raises BrokenResourceError: if the transport fails, or the peer closed the
            connection without a closing handshake and ``standard_compatible`` is set
        :raises EndOfStream: if the peer closed the connection without a closing
            handshake and ``standard_compatible`` is not set
        :raises ~ssl.SSLError: for any other TLS-level error

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object see the EOF on the next iteration
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # strerror is None when the error was not constructed from an
                # (errno, message) pair, so guard the substring test.
                if isinstance(exc, ssl.SSLEOFError) or (
                    "UNEXPECTED_EOF_WHILE_READING" in (exc.strerror or "")
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> Tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        """
        Close the stream.

        When ``standard_compatible`` is set, the TLS closing handshake is performed
        first; if that fails, the transport stream is closed forcefully and the
        exception is re-raised.

        """
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive and decrypt up to ``max_bytes`` bytes.

        :raises EndOfStream: if the SSL object returns no data

        """
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        """Encrypt ``item`` and send it over the transport stream."""
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        """
        Not supported for TLS streams; always raises :exc:`NotImplementedError`.

        The message names the negotiated version when it is older than TLSv1.3.

        """
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    f"send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The transport's attributes come first so the TLS-specific entries below
        # take precedence on any key collision.
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers(),
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }

236 

237 

@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session on every
    accepted connection.

    If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
    called to do whatever post-mortem processing is deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the logger
           named after this module (``logging.getLogger(__name__)``)
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions. The exception is passed explicitly
        # instead of relying on sys.exc_info(), which is only populated when this
        # method is invoked from inside an ``except`` block.
        if not isinstance(exc, get_cancelled_exc_class()):
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions. Raise ``exc``
        # explicitly: a bare ``raise`` fails when there is no active exception.
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: Optional[TaskGroup] = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, performing a TLS handshake (bounded by
        ``handshake_timeout``) on each one before passing the resulting :class:`TLSStream` to
        ``handler``.
        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        """Close the wrapped listener."""
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }