Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/anyio/_core/_sockets.py: 23%

189 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 07:19 +0000

1from __future__ import annotations 

2 

3import socket 

4import ssl 

5import sys 

6from ipaddress import IPv6Address, ip_address 

7from os import PathLike, chmod 

8from pathlib import Path 

9from socket import AddressFamily, SocketKind 

10from typing import Awaitable, List, Tuple, cast, overload 

11 

12from .. import to_thread 

13from ..abc import ( 

14 ConnectedUDPSocket, 

15 IPAddressType, 

16 IPSockAddrType, 

17 SocketListener, 

18 SocketStream, 

19 UDPSocket, 

20 UNIXSocketStream, 

21) 

22from ..streams.stapled import MultiListener 

23from ..streams.tls import TLSStream 

24from ._eventloop import get_asynclib 

25from ._resources import aclose_forcefully 

26from ._synchronization import Event 

27from ._tasks import create_task_group, move_on_after 

28 

29if sys.version_info >= (3, 8): 

30 from typing import Literal 

31else: 

32 from typing_extensions import Literal 

33 

# Some platforms' socket modules lack IPPROTO_IPV6 (https://bugs.python.org/issue29515),
# so fall back to the numeric value 41 in that case.
IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41)

# Entries returned by this module's getaddrinfo() wrapper:
# (family, type, proto, canonname, (host, port))
GetAddrInfoReturnType = List[Tuple[AddressFamily, SocketKind, int, str, Tuple[str, int]]]

# IP address families, including "unspecified" for callers that let resolution decide
AnyIPAddressFamily = Literal[
    AddressFamily.AF_UNSPEC,
    AddressFamily.AF_INET,
    AddressFamily.AF_INET6,
]

# Concrete IP address families only
IPAddressFamily = Literal[AddressFamily.AF_INET, AddressFamily.AF_INET6]

43 

44 

# Overload: passing ``tls_hostname`` implies a TLS connection
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# Overload: passing ``ssl_context`` implies a TLS connection
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    ssl_context: ssl.SSLContext,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# Overload: TLS explicitly requested with ``tls=True``
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[True],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> TLSStream: ...


# Overload: TLS explicitly declined with ``tls=False``
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    tls: Literal[False],
    ssl_context: ssl.SSLContext | None = ...,
    tls_standard_compatible: bool = ...,
    tls_hostname: str | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream: ...


# Overload: no TLS-related arguments at all -> plain socket stream
@overload
async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = ...,
    happy_eyeballs_delay: float = ...,
) -> SocketStream: ...

117 

118 

async def connect_tcp(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    local_host: IPAddressType | None = None,
    tls: bool = False,
    ssl_context: ssl.SSLContext | None = None,
    tls_standard_compatible: bool = True,
    tls_hostname: str | None = None,
    happy_eyeballs_delay: float = 0.25,
) -> SocketStream | TLSStream:
    """
    Connect to a host using the TCP protocol.

    This function implements the stateless version of the Happy Eyeballs algorithm (RFC
    6555). If ``remote_host`` is a host name that resolves to multiple IP addresses,
    each one is tried until one connection attempt succeeds. If the first attempt does
    not connect within ``happy_eyeballs_delay`` seconds (250 milliseconds by default), a
    second attempt is started using the next address in the list, and so on. On IPv6
    enabled systems, an IPv6 address (if available) is tried first.

    When the connection has been established, a TLS handshake will be done if either
    ``ssl_context`` or ``tls_hostname`` is not ``None``, or if ``tls`` is ``True``.

    :param remote_host: the IP address or host name to connect to
    :param remote_port: port on the target host to connect to
    :param local_host: the interface address or name to bind the socket to before connecting
    :param tls: ``True`` to do a TLS handshake with the connected stream and return a
        :class:`~anyio.streams.tls.TLSStream` instead
    :param ssl_context: the SSL context object to use (if omitted, a default context is created)
    :param tls_standard_compatible: If ``True``, performs the TLS shutdown handshake before closing
        the stream and requires that the server does this as well. Otherwise,
        :exc:`~ssl.SSLEOFError` may be raised during reads from the stream.
        Some protocols, such as HTTP, require this option to be ``False``.
        See :meth:`~ssl.SSLContext.wrap_socket` for details.
    :param tls_hostname: host name to check the server certificate against (defaults to the value
        of ``remote_host``)
    :param happy_eyeballs_delay: delay (in seconds) before starting the next connection attempt
    :return: a socket stream object if no TLS handshake was done, otherwise a TLS stream
    :raises OSError: if the connection attempt fails

    """
    # Placed here due to https://github.com/python/mypy/issues/7057
    connected_stream: SocketStream | None = None

    async def try_connect(remote_host: str, event: Event) -> None:
        # One connection attempt: the first success is kept and cancels the remaining
        # attempts, later successes are closed, and failures are collected in oserrors.
        nonlocal connected_stream
        try:
            stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
        except OSError as exc:
            oserrors.append(exc)
            return
        else:
            if connected_stream is None:
                connected_stream = stream
                tg.cancel_scope.cancel()
            else:
                await stream.aclose()
        finally:
            # Wake the scheduling loop below so that, after a failure, the next attempt
            # starts immediately instead of waiting out the full delay
            event.set()

    asynclib = get_asynclib()
    local_address: IPSockAddrType | None = None
    family = socket.AF_UNSPEC
    if local_host:
        gai_res = await getaddrinfo(str(local_host), None)
        family, *_, local_address = gai_res[0]

    target_host = str(remote_host)
    try:
        addr_obj = ip_address(remote_host)
    except ValueError:
        # getaddrinfo() will raise an exception if name resolution fails
        gai_res = await getaddrinfo(
            target_host, remote_port, family=family, type=socket.SOCK_STREAM
        )

        # Organize the list so that the first address is an IPv6 address (if available)
        # and the second one is an IPv4 address. The rest can be in whatever order.
        v6_found = v4_found = False
        target_addrs: list[tuple[socket.AddressFamily, str]] = []
        for af, *_, sa in gai_res:
            if af == socket.AF_INET6 and not v6_found:
                v6_found = True
                target_addrs.insert(0, (af, sa[0]))
            elif af == socket.AF_INET and not v4_found and v6_found:
                v4_found = True
                target_addrs.insert(1, (af, sa[0]))
            else:
                target_addrs.append((af, sa[0]))
    else:
        if isinstance(addr_obj, IPv6Address):
            target_addrs = [(socket.AF_INET6, addr_obj.compressed)]
        else:
            target_addrs = [(socket.AF_INET, addr_obj.compressed)]

    oserrors: list[OSError] = []
    async with create_task_group() as tg:
        for _af, addr in target_addrs:
            event = Event()
            tg.start_soon(try_connect, addr, event)
            # Give this attempt up to happy_eyeballs_delay seconds before starting the
            # next one in parallel
            with move_on_after(happy_eyeballs_delay):
                await event.wait()

    if connected_stream is None:
        cause = oserrors[0] if len(oserrors) == 1 else asynclib.ExceptionGroup(oserrors)
        raise OSError("All connection attempts failed") from cause

    if tls or tls_hostname or ssl_context:
        try:
            return await TLSStream.wrap(
                connected_stream,
                server_side=False,
                hostname=tls_hostname or str(remote_host),
                ssl_context=ssl_context,
                standard_compatible=tls_standard_compatible,
            )
        except BaseException:
            await aclose_forcefully(connected_stream)
            raise

    return connected_stream

241 

242 

async def connect_unix(path: str | PathLike[str]) -> UNIXSocketStream:
    """
    Connect to the given UNIX socket.

    Not available on Windows.

    :param path: path to the socket, as a string or path-like object
    :return: a socket stream object

    """
    # Round-trip through pathlib so any path-like object becomes a plain string
    socket_path = str(Path(path))
    backend = get_asynclib()
    return await backend.connect_unix(socket_path)

255 

256 

async def create_tcp_listener(
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    family: AnyIPAddressFamily = socket.AddressFamily.AF_UNSPEC,
    backlog: int = 65536,
    reuse_port: bool = False,
) -> MultiListener[SocketStream]:
    """
    Create a TCP socket listener.

    :param local_port: port number to listen on
    :param local_host: IP address of the interface to listen on. If omitted, listen on
        all IPv4 and IPv6 interfaces. To listen on all interfaces on a specific address
        family, use ``0.0.0.0`` for IPv4 or ``::`` for IPv6.
    :param family: address family (used if ``local_host`` was omitted)
    :param backlog: maximum number of queued incoming connections (up to a maximum of
        2**16, or 65536)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same
        address/port (not supported on Windows)
    :return: a multi-listener wrapping one listener per (deduplicated) stream address
        returned by name resolution

    """
    asynclib = get_asynclib()
    backlog = min(backlog, 65536)
    local_host = str(local_host) if local_host is not None else None
    gai_res = await getaddrinfo(
        local_host,  # type: ignore[arg-type]
        local_port,
        family=family,
        type=socket.SocketKind.SOCK_STREAM if sys.platform == "win32" else 0,
        flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
    )
    listeners: list[SocketListener] = []
    try:
        # The set() is here to work around a glibc bug:
        # https://sourceware.org/bugzilla/show_bug.cgi?id=14969
        sockaddr: tuple[str, int] | tuple[str, int, int, int]
        for fam, kind, *_, sockaddr in sorted(set(gai_res)):
            # Workaround for an uvloop bug where we don't get the correct scope ID for
            # IPv6 link-local addresses when passing type=socket.SOCK_STREAM to
            # getaddrinfo(): https://github.com/MagicStack/uvloop/issues/539
            if sys.platform != "win32" and kind is not SocketKind.SOCK_STREAM:
                continue

            raw_socket = socket.socket(fam)
            try:
                raw_socket.setblocking(False)

                # For Windows, enable exclusive address use. For others, enable
                # address reuse.
                if sys.platform == "win32":
                    raw_socket.setsockopt(
                        socket.SOL_SOCKET, socket.SO_EXCLUSIVEADDRUSE, 1
                    )
                else:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

                if reuse_port:
                    raw_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)

                # If only IPv6 was requested, disable dual stack operation
                if fam == socket.AF_INET6:
                    raw_socket.setsockopt(IPPROTO_IPV6, socket.IPV6_V6ONLY, 1)

                # Workaround for #554: rebuild a 4-tuple so the scope ID survives
                if "%" in sockaddr[0]:
                    addr, scope_id = sockaddr[0].split("%", 1)
                    sockaddr = (addr, sockaddr[1], 0, int(scope_id))

                raw_socket.bind(sockaddr)
                raw_socket.listen(backlog)
                listener = asynclib.TCPSocketListener(raw_socket)
            except BaseException:
                # No listener owns this socket yet, so close it here; otherwise a
                # failed bind() (e.g. address already in use) leaks the descriptor
                raw_socket.close()
                raise

            listeners.append(listener)
    except BaseException:
        for listener in listeners:
            await listener.aclose()

        raise

    return MultiListener(listeners)

334 

335 

async def create_unix_listener(
    path: str | PathLike[str],
    *,
    mode: int | None = None,
    backlog: int = 65536,
) -> SocketListener:
    """
    Create a UNIX socket listener.

    Not available on Windows.

    :param path: path of the socket
    :param mode: permissions to set on the socket
    :param backlog: maximum number of queued incoming connections (up to a maximum of 2**16, or
        65536)
    :return: a listener object

    .. versionchanged:: 3.0
        If a socket already exists on the file system in the given path, it will be removed first.

    """
    path_str = str(path)
    socket_path = Path(path)
    # Clear out a leftover socket at this path; other file types are left alone
    if socket_path.is_socket():
        socket_path.unlink()

    capped_backlog = min(backlog, 65536)
    raw_socket = socket.socket(socket.AF_UNIX)
    raw_socket.setblocking(False)
    try:
        # bind() and chmod() act on the file system, so run them off the event loop
        await to_thread.run_sync(raw_socket.bind, path_str, cancellable=True)
        if mode is not None:
            await to_thread.run_sync(chmod, path_str, mode, cancellable=True)

        raw_socket.listen(capped_backlog)
        return get_asynclib().UNIXSocketListener(raw_socket)
    except BaseException:
        # Never leak the descriptor if any setup step fails or is cancelled
        raw_socket.close()
        raise

375 

376 

async def create_udp_socket(
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    *,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> UDPSocket:
    """
    Create a UDP socket.

    If ``local_port`` has been given, the socket will be bound to this port on the local
    machine, making this socket suitable for providing UDP based services.

    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
        ``local_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to (if omitted,
        the wildcard address of ``family`` is used)
    :param local_port: local port to bind to
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a UDP socket
    :raises ValueError: if neither ``family`` nor ``local_host`` was given

    """
    if family == AddressFamily.AF_UNSPEC and not local_host:
        raise ValueError('Either "family" or "local_host" must be given')

    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]
    elif family == AddressFamily.AF_INET6:
        # Wildcard host, but honor the requested port (it used to be forced to 0)
        local_address = ("::", local_port)
    else:
        local_address = ("0.0.0.0", local_port)

    return await get_asynclib().create_udp_socket(
        family, local_address, None, reuse_port
    )

420 

421 

async def create_connected_udp_socket(
    remote_host: IPAddressType,
    remote_port: int,
    *,
    family: AnyIPAddressFamily = AddressFamily.AF_UNSPEC,
    local_host: IPAddressType | None = None,
    local_port: int = 0,
    reuse_port: bool = False,
) -> ConnectedUDPSocket:
    """
    Create a connected UDP socket.

    Connected UDP sockets can only communicate with the specified remote host/port, and any packets
    sent from other sources are dropped.

    :param remote_host: remote host to set as the default target
    :param remote_port: port on the remote host to set as the default target
    :param family: address family (``AF_INET`` or ``AF_INET6``) – automatically determined from
        ``local_host`` or ``remote_host`` if omitted
    :param local_host: IP address or host name of the local interface to bind to
    :param local_port: local port to bind to (if ``local_host`` is omitted, the socket is
        bound to the wildcard address of the resolved family on this port)
    :param reuse_port: ``True`` to allow multiple sockets to bind to the same address/port
        (not supported on Windows)
    :return: a connected UDP socket

    """
    local_address = None
    if local_host:
        gai_res = await getaddrinfo(
            str(local_host),
            local_port,
            family=family,
            type=socket.SOCK_DGRAM,
            flags=socket.AI_PASSIVE | socket.AI_ADDRCONFIG,
        )
        family = cast(AnyIPAddressFamily, gai_res[0][0])
        local_address = gai_res[0][-1]

    gai_res = await getaddrinfo(
        str(remote_host), remote_port, family=family, type=socket.SOCK_DGRAM
    )
    family = cast(AnyIPAddressFamily, gai_res[0][0])
    remote_address = gai_res[0][-1]

    if local_address is None and local_port:
        # A port without a host used to be ignored entirely; bind to the wildcard
        # address of the family the remote address resolved to so the port is honored
        wildcard = "::" if family == AddressFamily.AF_INET6 else "0.0.0.0"
        local_address = (wildcard, local_port)

    return await get_asynclib().create_udp_socket(
        family, local_address, remote_address, reuse_port
    )

469 

470 

async def getaddrinfo(
    host: bytearray | bytes | str,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> GetAddrInfoReturnType:
    """
    Look up a numeric IP address given a host name.

    Internationalized domain names are translated according to the (non-transitional) IDNA 2008
    standard.

    .. note:: 4-tuple IPv6 socket addresses are automatically converted to 2-tuples of
        (host, port), unlike what :func:`socket.getaddrinfo` does.

    :param host: host name
    :param port: port number
    :param family: socket family (``AF_INET``, ...)
    :param type: socket type (``SOCK_STREAM``, ...)
    :param proto: protocol number
    :param flags: flags to pass to upstream ``getaddrinfo()``
    :return: list of tuples containing (family, type, proto, canonname, sockaddr)

    .. seealso:: :func:`socket.getaddrinfo`

    """
    # Handle unicode hostnames: pure-ASCII names are passed on as bytes; anything else
    # is IDNA-encoded (idna is imported lazily, only when such a name is seen)
    if isinstance(host, str):
        try:
            encoded_host = host.encode("ascii")
        except UnicodeEncodeError:
            import idna

            encoded_host = idna.encode(host, uts46=True)
    else:
        encoded_host = host

    gai_res = await get_asynclib().getaddrinfo(
        encoded_host, port, family=family, type=type, proto=proto, flags=flags
    )
    # Use distinct names so the results are not confused with the lookup parameters
    return [
        (res_family, res_type, res_proto, canonname, convert_ipv6_sockaddr(sockaddr))
        for res_family, res_type, res_proto, canonname, sockaddr in gai_res
    ]

518 

519 

def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str, str]]:
    """
    Look up the host name of an IP address.

    This is a plain function, not a coroutine function: it returns the backend's
    awaitable directly, and the caller awaits it.

    :param sockaddr: socket address (e.g. (ipaddress, port) for IPv4)
    :param flags: flags to pass to upstream ``getnameinfo()``
    :return: a tuple of (host name, service name)

    .. seealso:: :func:`socket.getnameinfo`

    """
    backend = get_asynclib()
    return backend.getnameinfo(sockaddr, flags)

532 

533 

def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
    """
    Wait until the given socket has data to be read.

    This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
    (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
        constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become readable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become readable

    """
    # Delegates to the running backend; the returned awaitable is awaited by the caller
    backend = get_asynclib()
    return backend.wait_socket_readable(sock)

552 

553 

def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
    """
    Wait until the given socket can be written to.

    This does **NOT** work on Windows when using the asyncio backend with a proactor event loop
    (default on py3.8+).

    .. warning:: Only use this on raw sockets that have not been wrapped by any higher level
        constructs like socket streams!

    :param sock: a socket object
    :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
        socket to become writable
    :raises ~anyio.BusyResourceError: if another task is already waiting for the socket
        to become writable

    """
    # Delegates to the running backend; the returned awaitable is awaited by the caller
    backend = get_asynclib()
    return backend.wait_socket_writable(sock)

572 

573 

574# 

575# Private API 

576# 

577 

578 

def convert_ipv6_sockaddr(
    sockaddr: tuple[str, int, int, int] | tuple[str, int]
) -> tuple[str, int]:
    """
    Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format.

    A nonzero scope ID is appended to the address after a ``%`` separator. With a zero
    scope ID, the flow info and scope ID are simply dropped. Socket addresses of any
    other shape are returned unchanged.

    :param sockaddr: the result of :meth:`~socket.socket.getsockname`
    :return: the converted socket address

    """
    # The casts below exist only to keep MyPy happy
    if not (isinstance(sockaddr, tuple) and len(sockaddr) == 4):
        return cast(Tuple[str, int], sockaddr)

    host, port, _flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr)
    if not scope_id:
        return host, port

    # PyPy (as of v7.3.11) leaves the interface name in the result, so keep only the
    # part before any "%" and rely on the numeric scope ID instead
    # (https://foss.heptapod.net/pypy/pypy/-/issues/3938)
    bare_host = host.partition("%")[0]
    return f"{bare_host}%{scope_id}", port