Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiodns/__init__.py: 29%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

250 statements  

1from __future__ import annotations 

2 

3import asyncio 

4import contextlib 

5import functools 

6import logging 

7import socket 

8import sys 

9import warnings 

10import weakref 

11from collections.abc import Callable, Iterable, Iterator, Sequence 

12from types import TracebackType 

13from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload 

14 

15import pycares 

16 

17from . import error 

18from .compat import ( 

19 AresHostResult, 

20 AresQueryAAAAResult, 

21 AresQueryAResult, 

22 AresQueryCAAResult, 

23 AresQueryCNAMEResult, 

24 AresQueryMXResult, 

25 AresQueryNAPTRResult, 

26 AresQueryNSResult, 

27 AresQueryPTRResult, 

28 AresQuerySOAResult, 

29 AresQuerySRVResult, 

30 AresQueryTXTResult, 

31 QueryResult, 

32 convert_result, 

33) 

34 

__version__ = '4.0.4'

# Public API of the package.
__all__ = ('DNSResolver', 'error')

_T = TypeVar('_T')

# Message of the RuntimeError raised when the socket-state fallback is
# needed on Windows but the loop is neither a SelectorEventLoop nor winloop.
WINDOWS_SELECTOR_ERR_MSG = (
    'aiodns needs a SelectorEventLoop on Windows. See more: '
    'https://github.com/aio-libs/aiodns#note-for-windows-users'
)

_LOGGER = logging.getLogger(__name__)

# Record-type name -> pycares QUERY_TYPE_* constant. Every key is exactly
# the suffix of the constant it maps to, so the table is built by name.
query_type_map: dict[str, int] = {
    name: getattr(pycares, f'QUERY_TYPE_{name}')
    for name in (
        'A', 'AAAA', 'ANY', 'CAA', 'CNAME', 'MX',
        'NAPTR', 'NS', 'PTR', 'SOA', 'SRV', 'TXT',
    )
}

# Query-class name -> pycares QUERY_CLASS_* constant, built the same way.
query_class_map: dict[str, int] = {
    name: getattr(pycares, f'QUERY_CLASS_{name}')
    for name in ('IN', 'CHAOS', 'HS', 'NONE', 'ANY')
}

73 

74 

class DNSResolver:
    """Asyncio front-end for a pycares (c-ares) resolver channel.

    Every lookup method returns an :class:`asyncio.Future` created on
    ``self.loop``. When pycares runs its own event thread, results are
    marshalled onto the loop with ``call_soon_threadsafe``. Otherwise the
    resolver drives c-ares itself through the loop's reader/writer
    callbacks plus a periodic timer.
    """

    def __init__(
        self,
        nameservers: Sequence[str] | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
        **kwargs: Any,
    ) -> None:  # TODO(PY311): Use Unpack for kwargs.
        """Create the resolver and its underlying pycares channel.

        :param nameservers: servers to use instead of the system defaults.
        :param loop: loop the returned futures belong to; defaults to
            ``asyncio.get_event_loop()``.
        :param kwargs: forwarded to ``pycares.Channel``. ``sock_state_cb``
            is discarded (the resolver installs its own when needed) and
            ``timeout`` is also used to pace the fallback timer.
        :raises RuntimeError: on Windows when the socket-state fallback is
            required but the loop is not selector based (nor winloop).
        """
        # Stay "closed" until construction finishes so _cleanup() (called
        # from __del__) skips a partially built instance.
        self._closed = True
        self.loop = loop or asyncio.get_event_loop()
        if TYPE_CHECKING:
            assert self.loop is not None
        # Bookkeeping for the socket-state fallback is created before the
        # channel so the callback can never see missing attributes.
        self._read_fds: set[int] = set()
        self._write_fds: set[int] = set()
        self._timer: asyncio.TimerHandle | None = None
        kwargs.pop('sock_state_cb', None)
        self._timeout = kwargs.pop('timeout', None)
        self._event_thread, self._channel = self._make_channel(**kwargs)
        if nameservers:
            self.nameservers = nameservers
        self._closed = False

    def _make_channel(self, **kwargs: Any) -> tuple[bool, pycares.Channel]:
        """Create the channel, preferring pycares' own event thread.

        :returns: ``(uses_event_thread, channel)``.
        """
        # pycares 5+ uses event_thread by default when sock_state_cb
        # is not provided
        try:
            return True, pycares.Channel(timeout=self._timeout, **kwargs)
        except pycares.AresError as e:
            if sys.platform == 'linux':
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. This '
                    'usually means the system ran out of inotify watches. '
                    'Falling back to socket state callback. Consider '
                    'increasing the system inotify watch limit: %s',
                    e,
                )
            else:
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. '
                    'Falling back to socket state callback: %s',
                    e,
                )
        # Fall back to sock_state_cb (needs SelectorEventLoop on Windows)
        if sys.platform == 'win32' and not isinstance(
            self.loop, asyncio.SelectorEventLoop
        ):
            try:
                import winloop

                if not isinstance(self.loop, winloop.Loop):
                    raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
            except ModuleNotFoundError as ex:
                raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
        # Use weak reference for deterministic cleanup. Without it there's a
        # reference cycle (DNSResolver -> _channel -> callback -> DNSResolver).
        # Python 3.4+ can handle cycles with __del__, but weak ref ensures
        # cleanup happens immediately when last reference is dropped.
        weak_self = weakref.ref(self)

        def sock_state_cb_wrapper(
            fd: int, readable: bool, writable: bool
        ) -> None:
            this = weak_self()
            if this is not None:
                this._sock_state_cb(fd, readable, writable)

        return False, pycares.Channel(
            sock_state_cb=sock_state_cb_wrapper,
            timeout=self._timeout,
            **kwargs,
        )

    @staticmethod
    def _strip_port(server: str) -> str:
        """Drop a trailing ``:port`` from a server string.

        ``'8.8.8.8:53'`` -> ``'8.8.8.8'`` and ``'[2001:db8::1]:53'`` ->
        ``'2001:db8::1'``. A bare IPv6 address (several colons, no
        brackets) or a port-less IPv4 address is returned unchanged.
        """
        if server.startswith('['):
            host, sep, _ = server[1:].partition(']')
            return host if sep else server
        if server.count(':') == 1:
            return server.rsplit(':', 1)[0]
        return server

    @property
    def nameservers(self) -> Sequence[str]:
        """Configured servers, without ports (pycares 4.x compatible)."""
        # pycares 5.x returns servers with port (e.g., '8.8.8.8:53' or
        # '[::1]:53'); strip it for backward compatibility with pycares 4.x.
        return [self._strip_port(s) for s in self._channel.servers]

    @nameservers.setter
    def nameservers(self, value: Iterable[str | bytes]) -> None:
        self._channel.servers = value

    def _callback(
        self, fut: asyncio.Future[_T], result: _T, errorno: int | None
    ) -> None:
        """Resolve *fut* with *result*, or with a DNSError for *errorno*."""
        # The future can already be done if pycares raised synchronously
        # and _capture_ares_error set the exception before c-ares delivered
        # the same error through this callback.
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno))
            )
        else:
            fut.set_result(result)

    def _get_future_callback(
        self,
    ) -> tuple[asyncio.Future[_T], Callable[[_T, int | None], None]]:
        """Return a future and a callback to set the result of the future."""
        cb: Callable[[_T, int | None], None]
        future: asyncio.Future[_T] = self.loop.create_future()
        if self._event_thread:
            # c-ares calls back on its own thread; hop onto the loop.
            cb = functools.partial(  # type: ignore[assignment]
                self.loop.call_soon_threadsafe,
                self._callback,  # type: ignore[arg-type]
                future,
            )
        else:
            cb = functools.partial(self._callback, future)
        return future, cb

    def _query_callback(
        self,
        fut: asyncio.Future[QueryResult],
        qtype: int,
        result: pycares.DNSResult,
        errorno: int | None,
    ) -> None:
        """Callback for query that converts results to compatible format."""
        # See _callback for why we guard on done() rather than cancelled().
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno))
            )
            return
        try:
            converted = convert_result(result, qtype)
        except error.DNSError as exc:
            fut.set_exception(exc)
        else:
            fut.set_result(converted)

    def _get_query_future_callback(
        self, qtype: int
    ) -> tuple[asyncio.Future[QueryResult], Callable[..., None]]:
        """Return a future and callback for query with result conversion."""
        future: asyncio.Future[QueryResult] = self.loop.create_future()
        cb: Callable[..., None]
        if self._event_thread:
            cb = functools.partial(  # type: ignore[assignment]
                self.loop.call_soon_threadsafe,
                self._query_callback,  # type: ignore[arg-type]
                future,
                qtype,
            )
        else:
            cb = functools.partial(self._query_callback, future, qtype)
        return future, cb

    @contextlib.contextmanager
    def _capture_ares_error(self, fut: asyncio.Future[_T]) -> Iterator[None]:
        """Turn a synchronous ``pycares.AresError`` into a failed *fut*.

        The exception is swallowed so callers only see it via ``await``.
        """
        # When pycares raises synchronously (e.g. ARES_EBADNAME for a
        # malformed hostname), c-ares may also invoke the callback first,
        # leaving the future already done. Route the error through the
        # future so callers can rely on `await` to raise.
        try:
            yield
        except pycares.AresError as exc:
            if fut.done():
                return
            # pycares always raises (errno, message), but be defensive:
            # an args-less AresError should still resolve the future to
            # avoid an indefinite hang on `await`.
            errno = exc.args[0] if exc.args else error.ARES_EFORMERR
            fut.set_exception(
                error.DNSError(errno, pycares.errno.strerror(errno))
            )

    @staticmethod
    def _resolve_query_args(
        qtype: str, qclass: str | None
    ) -> tuple[int, int | None]:
        """Map record type/class names to pycares constants.

        :raises ValueError: if either name is unknown.
        """
        try:
            qtype_int = query_type_map[qtype]
        except KeyError as e:
            raise ValueError(f'invalid query type: {qtype}') from e
        if qclass is None:
            return qtype_int, None
        try:
            return qtype_int, query_class_map[qclass]
        except KeyError as e:
            raise ValueError(f'invalid query class: {qclass}') from e

    def _send_query(
        self,
        host: str,
        qtype_int: int,
        qclass_int: int | None,
        fut: asyncio.Future[Any],
        cb: Callable[..., None],
    ) -> None:
        """Submit a query; synchronous c-ares errors land on *fut*."""
        with self._capture_ares_error(fut):
            if qclass_int is not None:
                self._channel.query(
                    host, qtype_int, query_class=qclass_int, callback=cb
                )
            else:
                self._channel.query(host, qtype_int, callback=cb)

    @overload
    def query(
        self, host: str, qtype: Literal['A'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['AAAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAAAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryCAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CNAME'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryCNAMEResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['MX'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryMXResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NAPTR'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNAPTRResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NS'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNSResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['PTR'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryPTRResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SOA'], qclass: str | None = ...
    ) -> asyncio.Future[AresQuerySOAResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SRV'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQuerySRVResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['TXT'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryTXTResult]]: ...

    def query(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[list[Any]] | asyncio.Future[Any]:
        """Query DNS records (deprecated, use query_dns instead).

        Results are converted with ``convert_result`` into the legacy
        result objects.

        :raises ValueError: if *qtype* or *qclass* is not a known name.
        """
        warnings.warn(
            'query() is deprecated, use query_dns() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        qtype_int, qclass_int = self._resolve_query_args(qtype, qclass)
        fut, cb = self._get_query_future_callback(qtype_int)
        self._send_query(host, qtype_int, qclass_int, fut, cb)
        return fut

    def query_dns(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[pycares.DNSResult]:
        """Query DNS records, returning native pycares 5.x DNSResult.

        :raises ValueError: if *qtype* or *qclass* is not a known name.
        """
        qtype_int, qclass_int = self._resolve_query_args(qtype, qclass)
        fut: asyncio.Future[pycares.DNSResult]
        fut, cb = self._get_future_callback()
        self._send_query(host, qtype_int, qclass_int, fut, cb)
        return fut

    def _gethostbyname_callback(
        self,
        fut: asyncio.Future[AresHostResult],
        host: str,
        result: pycares.AddrInfoResult | None,
        errorno: int | None,
    ) -> None:
        """Callback for gethostbyname that converts AddrInfoResult."""
        # See _callback for why we guard on done() rather than cancelled().
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno))
            )
        else:
            assert result is not None  # noqa: S101
            # node.addr is (address_bytes, port) - extract and decode
            addresses = [node.addr[0].decode() for node in result.nodes]
            # Get canonical name from cnames if available
            name = result.cnames[0].name if result.cnames else host
            fut.set_result(
                AresHostResult(name=name, aliases=[], addresses=addresses)
            )

    def gethostbyname(
        self, host: str, family: socket.AddressFamily
    ) -> asyncio.Future[AresHostResult]:
        """
        Resolve hostname to addresses.

        Deprecated: Use getaddrinfo() instead. This is implemented using
        getaddrinfo as pycares 5.x removed the gethostbyname method.
        """
        warnings.warn(
            'gethostbyname() is deprecated, use getaddrinfo() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        fut: asyncio.Future[AresHostResult] = self.loop.create_future()
        cb: Callable[..., None]
        if self._event_thread:
            cb = functools.partial(  # type: ignore[assignment]
                self.loop.call_soon_threadsafe,
                self._gethostbyname_callback,  # type: ignore[arg-type]
                fut,
                host,
            )
        else:
            cb = functools.partial(self._gethostbyname_callback, fut, host)
        with self._capture_ares_error(fut):
            self._channel.getaddrinfo(host, None, family=family, callback=cb)
        return fut

    def getaddrinfo(
        self,
        host: str,
        family: socket.AddressFamily = socket.AF_UNSPEC,
        port: int | None = None,
        proto: int = 0,
        type: int = 0,
        flags: int = 0,
    ) -> asyncio.Future[pycares.AddrInfoResult]:
        """Resolve *host* via c-ares getaddrinfo; arguments pass through."""
        fut: asyncio.Future[pycares.AddrInfoResult]
        fut, cb = self._get_future_callback()
        with self._capture_ares_error(fut):
            self._channel.getaddrinfo(
                host,
                port,
                family=family,
                type=type,
                proto=proto,
                flags=flags,
                callback=cb,
            )
        return fut

    def getnameinfo(
        self,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
        flags: int = 0,
    ) -> asyncio.Future[pycares.NameInfoResult]:
        """Reverse-resolve *sockaddr* via c-ares getnameinfo."""
        fut: asyncio.Future[pycares.NameInfoResult]
        fut, cb = self._get_future_callback()
        with self._capture_ares_error(fut):
            self._channel.getnameinfo(sockaddr, flags, callback=cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]:
        """Look up the host entry for address *name*."""
        fut: asyncio.Future[pycares.HostResult]
        fut, cb = self._get_future_callback()
        with self._capture_ares_error(fut):
            self._channel.gethostbyaddr(name, callback=cb)
        return fut

    def cancel(self) -> None:
        """Cancel all outstanding queries on the channel."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Sync loop reader/writer registrations with c-ares' wishes for *fd*.

        c-ares reports the complete desired state on every call, so a
        direction that is no longer wanted must be unregistered. Otherwise,
        for example, a TCP socket whose writes have drained stays in the
        writer set, and the loop spins calling process_fd on an
        always-writable socket.
        """
        if readable:
            self.loop.add_reader(
                fd, self._channel.process_fd, fd, pycares.ARES_SOCKET_BAD
            )
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(
                fd, self._channel.process_fd, pycares.ARES_SOCKET_BAD, fd
            )
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            if self._timer is None:
                self._start_timer()
        elif (
            not self._read_fds
            and not self._write_fds
            and self._timer is not None
        ):
            # socket is now closed and nothing else is being watched
            self._timer.cancel()
            self._timer = None

    def _timer_cb(self) -> None:
        """Let c-ares process timeouts; re-arm while sockets are active."""
        if self._read_fds or self._write_fds:
            self._channel.process_fd(
                pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD
            )
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule _timer_cb after min(timeout, 1s); 0 maps to 0.1s."""
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)

    def _cleanup(self) -> None:
        """Cleanup timers and file descriptors when closing resolver."""
        if self._closed:
            return
        # Mark as closed first to prevent double cleanup
        self._closed = True
        # Cancel timer if running
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None

        # Remove all file descriptors
        for fd in self._read_fds:
            self.loop.remove_reader(fd)
        for fd in self._write_fds:
            self.loop.remove_writer(fd)

        self._read_fds.clear()
        self._write_fds.clear()
        self._channel.close()

    async def close(self) -> None:
        """
        Cleanly close the DNS resolver.

        This should be called to ensure all resources are properly released.
        After calling close(), the resolver should not be used again.
        """
        if not self._closed:
            self._channel.cancel()
            self._cleanup()

    async def __aenter__(self) -> DNSResolver:
        """Enter the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager."""
        await self.close()

    def __del__(self) -> None:
        """Handle cleanup when the resolver is garbage collected."""
        # Check if we have a channel to clean up
        # This can happen if an exception occurs during __init__ before
        # _channel is created (e.g., RuntimeError on Windows
        # without proper loop)
        if hasattr(self, '_channel'):
            self._cleanup()