Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/streams/tls.py: 38%

138 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 07:19 +0000

1from __future__ import annotations 

2 

3import logging 

4import re 

5import ssl 

6from dataclasses import dataclass 

7from functools import wraps 

8from typing import Any, Callable, Mapping, Tuple, TypeVar 

9 

10from .. import ( 

11 BrokenResourceError, 

12 EndOfStream, 

13 aclose_forcefully, 

14 get_cancelled_exc_class, 

15) 

16from .._core._typedattr import TypedAttributeSet, typed_attribute 

17from ..abc import AnyByteStream, ByteStream, Listener, TaskGroup 

18 

# Peer-certificate tuple shapes used by TLSAttribute.peer_certificate:
# a sequence of (str, str) pairs, and a sequence of such sequences.
_PCTRTT = Tuple[Tuple[str, str], ...]
_PCTRTTT = Tuple[_PCTRTT, ...]

# Return type of the SSLObject method driven by TLSStream._call_sslobject_method
T_Retval = TypeVar("T_Retval")

22 

23 

class TLSAttribute(TypedAttributeSet):
    """
    Contains Transport Layer Security related attributes.

    The values are provided through :attr:`TLSStream.extra_attributes` (and, for
    ``standard_compatible`` only, through :attr:`TLSListener.extra_attributes`).
    """

    #: the selected ALPN protocol (``None`` if no protocol was negotiated)
    alpn_protocol: str | None = typed_attribute()
    #: the channel binding for type ``tls-unique``
    channel_binding_tls_unique: bytes = typed_attribute()
    #: the selected cipher as a ``(name, protocol version, secret bits)`` tuple
    cipher: tuple[str, str, int] = typed_attribute()
    #: the peer certificate in dictionary form (see :meth:`ssl.SSLSocket.getpeercert` for more
    #: information)
    peer_certificate: dict[str, str | _PCTRTTT | _PCTRTT] | None = typed_attribute()
    #: the peer certificate in binary (DER) form
    peer_certificate_binary: bytes | None = typed_attribute()
    #: ``True`` if this is the server side of the connection
    server_side: bool = typed_attribute()
    #: ciphers shared by the client during the TLS handshake (``None`` if this is the
    #: client side)
    shared_ciphers: list[tuple[str, str, int]] | None = typed_attribute()
    #: the :class:`~ssl.SSLObject` used for encryption
    ssl_object: ssl.SSLObject = typed_attribute()
    #: ``True`` if this stream does (and expects) a closing TLS handshake when the stream is being
    #: closed
    standard_compatible: bool = typed_attribute()
    #: the TLS protocol version (e.g. ``TLSv1.2``)
    tls_version: str = typed_attribute()

50 

51 

@dataclass(eq=False)
class TLSStream(ByteStream):
    """
    A stream wrapper that encrypts all sent data and decrypts received data.

    This class has no public initializer; use :meth:`wrap` instead.
    All extra attributes from :class:`~TLSAttribute` are supported.

    :var AnyByteStream transport_stream: the wrapped stream

    """

    transport_stream: AnyByteStream
    standard_compatible: bool
    _ssl_object: ssl.SSLObject
    # incoming ciphertext received from the transport is written here for the SSLObject
    _read_bio: ssl.MemoryBIO
    # outgoing ciphertext produced by the SSLObject is read from here and sent
    _write_bio: ssl.MemoryBIO

    @classmethod
    async def wrap(
        cls,
        transport_stream: AnyByteStream,
        *,
        server_side: bool | None = None,
        hostname: str | None = None,
        ssl_context: ssl.SSLContext | None = None,
        standard_compatible: bool = True,
    ) -> TLSStream:
        """
        Wrap an existing stream with Transport Layer Security.

        This performs a TLS handshake with the peer.

        :param transport_stream: a bytes-transporting stream to wrap
        :param server_side: ``True`` if this is the server side of the connection, ``False`` if
            this is the client side (if omitted, will be set to ``False`` if ``hostname`` has been
            provided, ``True`` otherwise). Used only to create a default context when an explicit
            context has not been provided.
        :param hostname: host name of the peer (if host name checking is desired)
        :param ssl_context: the SSLContext object to use (if not provided, a secure default will be
            created). Note that ``ssl.OP_IGNORE_UNEXPECTED_EOF`` is cleared on whichever context
            ends up being used, including one passed in here.
        :param standard_compatible: if ``False``, skip the closing handshake when closing the
            connection, and don't raise an exception if the peer does the same
        :raises ~ssl.SSLError: if the TLS handshake fails

        """
        if server_side is None:
            server_side = not hostname

        if not ssl_context:
            purpose = (
                ssl.Purpose.CLIENT_AUTH if server_side else ssl.Purpose.SERVER_AUTH
            )
            ssl_context = ssl.create_default_context(purpose)

        # Re-enable detection of unexpected EOFs if it was disabled by Python.
        # NOTE(review): this also mutates a caller-supplied context; kept as-is because
        # the EOF handling in _call_sslobject_method relies on the flag being cleared.
        if hasattr(ssl, "OP_IGNORE_UNEXPECTED_EOF"):
            ssl_context.options &= ~ssl.OP_IGNORE_UNEXPECTED_EOF

        bio_in = ssl.MemoryBIO()
        bio_out = ssl.MemoryBIO()
        ssl_object = ssl_context.wrap_bio(
            bio_in, bio_out, server_side=server_side, server_hostname=hostname
        )
        wrapper = cls(
            transport_stream=transport_stream,
            standard_compatible=standard_compatible,
            _ssl_object=ssl_object,
            _read_bio=bio_in,
            _write_bio=bio_out,
        )
        await wrapper._call_sslobject_method(ssl_object.do_handshake)
        return wrapper

    async def _call_sslobject_method(
        self, func: Callable[..., T_Retval], *args: object
    ) -> T_Retval:
        """
        Call ``func(*args)`` repeatedly until it completes, pumping ciphertext between the
        memory BIOs and the transport stream as the SSL object requests it.

        :return: the return value of ``func``
        :raises BrokenResourceError: if the transport raises :exc:`OSError`, on
            :exc:`ssl.SSLSyscallError`, or on an unexpected EOF when
            ``standard_compatible`` is ``True``
        :raises EndOfStream: on an unexpected EOF when ``standard_compatible`` is ``False``
        :raises ~ssl.SSLError: any other SSL error is re-raised unchanged

        """
        while True:
            try:
                result = func(*args)
            except ssl.SSLWantReadError:
                try:
                    # Flush any pending writes first
                    if self._write_bio.pending:
                        await self.transport_stream.send(self._write_bio.read())

                    data = await self.transport_stream.receive()
                except EndOfStream:
                    # Let the SSL object observe the EOF on the next iteration
                    self._read_bio.write_eof()
                except OSError as exc:
                    self._read_bio.write_eof()
                    self._write_bio.write_eof()
                    raise BrokenResourceError from exc
                else:
                    self._read_bio.write(data)
            except ssl.SSLWantWriteError:
                await self.transport_stream.send(self._write_bio.read())
            except ssl.SSLSyscallError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                raise BrokenResourceError from exc
            except ssl.SSLError as exc:
                self._read_bio.write_eof()
                self._write_bio.write_eof()
                # strerror may be None (e.g. an SSLError built from a single message),
                # so guard the substring test to avoid masking the error with a TypeError
                if isinstance(exc, ssl.SSLEOFError) or (
                    exc.strerror and "UNEXPECTED_EOF_WHILE_READING" in exc.strerror
                ):
                    if self.standard_compatible:
                        raise BrokenResourceError from exc
                    else:
                        raise EndOfStream from None

                raise
            else:
                # Flush any pending writes first
                if self._write_bio.pending:
                    await self.transport_stream.send(self._write_bio.read())

                return result

    async def unwrap(self) -> tuple[AnyByteStream, bytes]:
        """
        Does the TLS closing handshake.

        :return: a tuple of (wrapped byte stream, bytes left in the read buffer)

        """
        await self._call_sslobject_method(self._ssl_object.unwrap)
        self._read_bio.write_eof()
        self._write_bio.write_eof()
        return self.transport_stream, self._read_bio.read()

    async def aclose(self) -> None:
        # Perform the closing handshake only in standard-compatible mode; if it fails
        # (or is cancelled) the transport is closed forcefully and the error propagates
        if self.standard_compatible:
            try:
                await self.unwrap()
            except BaseException:
                await aclose_forcefully(self.transport_stream)
                raise

        await self.transport_stream.aclose()

    async def receive(self, max_bytes: int = 65536) -> bytes:
        data = await self._call_sslobject_method(self._ssl_object.read, max_bytes)
        # An empty read from the SSL object means the peer closed the TLS session
        if not data:
            raise EndOfStream

        return data

    async def send(self, item: bytes) -> None:
        await self._call_sslobject_method(self._ssl_object.write, item)

    async def send_eof(self) -> None:
        # Always raises NotImplementedError; the message differs depending on whether
        # the negotiated version is older than TLS 1.3
        tls_version = self.extra(TLSAttribute.tls_version)
        match = re.match(r"TLSv(\d+)(?:\.(\d+))?", tls_version)
        if match:
            major, minor = int(match.group(1)), int(match.group(2) or 0)
            if (major, minor) < (1, 3):
                raise NotImplementedError(
                    "send_eof() requires at least TLSv1.3; current "
                    f"session uses {tls_version}"
                )

        raise NotImplementedError(
            "send_eof() has not yet been implemented for TLS streams"
        )

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        # The transport's attributes are included first so the TLS ones take precedence
        return {
            **self.transport_stream.extra_attributes,
            TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
            TLSAttribute.channel_binding_tls_unique: self._ssl_object.get_channel_binding,
            TLSAttribute.cipher: self._ssl_object.cipher,
            TLSAttribute.peer_certificate: lambda: self._ssl_object.getpeercert(False),
            TLSAttribute.peer_certificate_binary: lambda: self._ssl_object.getpeercert(
                True
            ),
            TLSAttribute.server_side: lambda: self._ssl_object.server_side,
            TLSAttribute.shared_ciphers: lambda: self._ssl_object.shared_ciphers()
            if self._ssl_object.server_side
            else None,
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
            TLSAttribute.ssl_object: lambda: self._ssl_object,
            TLSAttribute.tls_version: self._ssl_object.version,
        }

239 

240 

@dataclass(eq=False)
class TLSListener(Listener[TLSStream]):
    """
    A convenience listener that wraps another listener and auto-negotiates a TLS session on every
    accepted connection.

    If the TLS handshake times out or raises an exception, :meth:`handle_handshake_error` is
    called to do whatever post-mortem processing is deemed necessary.

    Supports only the :attr:`~TLSAttribute.standard_compatible` extra attribute.

    :param Listener listener: the listener to wrap
    :param ssl_context: the SSL context object
    :param standard_compatible: a flag passed through to :meth:`TLSStream.wrap`
    :param handshake_timeout: time limit for the TLS handshake
        (passed to :func:`~anyio.fail_after`)
    """

    listener: Listener[Any]
    ssl_context: ssl.SSLContext
    standard_compatible: bool = True
    handshake_timeout: float = 30

    @staticmethod
    async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> None:
        """
        Handle an exception raised during the TLS handshake.

        This method does 3 things:

        #. Forcefully closes the original stream
        #. Logs the exception (unless it was a cancellation exception) using the
           ``anyio.streams.tls`` logger
        #. Reraises the exception if it was a base exception or a cancellation exception

        :param exc: the exception
        :param stream: the original stream

        """
        # NOTE: this must be a plain string literal; an f-string in this position is
        # not stored as the function's __doc__.
        await aclose_forcefully(stream)

        # Log all except cancellation exceptions. The exception is passed explicitly
        # instead of relying on sys.exc_info(), which may not reflect ``exc`` after the
        # await above (reported for asyncio as python/cpython#108668) or when this
        # method is invoked outside the original ``except`` block.
        if not isinstance(exc, get_cancelled_exc_class()):
            logging.getLogger(__name__).exception(
                "Error during TLS handshake", exc_info=exc
            )

        # Only reraise base exceptions and cancellation exceptions
        if not isinstance(exc, Exception) or isinstance(exc, get_cancelled_exc_class()):
            raise exc

    async def serve(
        self,
        handler: Callable[[TLSStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Serve connections from the wrapped listener, performing a TLS handshake (bounded by
        ``handshake_timeout``) on each one before passing the resulting :class:`TLSStream`
        to ``handler``.
        """

        @wraps(handler)
        async def handler_wrapper(stream: AnyByteStream) -> None:
            # Imported locally, as in the original implementation
            from .. import fail_after

            try:
                with fail_after(self.handshake_timeout):
                    wrapped_stream = await TLSStream.wrap(
                        stream,
                        ssl_context=self.ssl_context,
                        standard_compatible=self.standard_compatible,
                    )
            except BaseException as exc:
                # Dispatched through self so subclasses can override the handling
                await self.handle_handshake_error(exc, stream)
            else:
                await handler(wrapped_stream)

        await self.listener.serve(handler_wrapper, task_group)

    async def aclose(self) -> None:
        await self.listener.aclose()

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        return {
            TLSAttribute.standard_compatible: lambda: self.standard_compatible,
        }

320 }