Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiodns/__init__.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

229 statements  

1from __future__ import annotations 

2 

3import asyncio 

4import functools 

5import logging 

6import socket 

7import sys 

8import warnings 

9import weakref 

10from collections.abc import Callable, Iterable, Sequence 

11from types import TracebackType 

12from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload 

13 

14import pycares 

15 

16from . import error 

17from .compat import ( 

18 AresHostResult, 

19 AresQueryAAAAResult, 

20 AresQueryAResult, 

21 AresQueryCAAResult, 

22 AresQueryCNAMEResult, 

23 AresQueryMXResult, 

24 AresQueryNAPTRResult, 

25 AresQueryNSResult, 

26 AresQueryPTRResult, 

27 AresQuerySOAResult, 

28 AresQuerySRVResult, 

29 AresQueryTXTResult, 

30 QueryResult, 

31 convert_result, 

32) 

33 

# Version string of this package.
__version__ = '4.0.0'

# Public names exported by ``from aiodns import *``.
__all__ = ('DNSResolver', 'error')

# Generic result type carried by the futures handed back to callers.
_T = TypeVar('_T')

# Message of the RuntimeError raised when the socket-state fallback is needed
# on Windows but the running loop cannot watch file descriptors.
WINDOWS_SELECTOR_ERR_MSG = (
    'aiodns needs a SelectorEventLoop on Windows. '
    'See more: https://github.com/aio-libs/aiodns#note-for-windows-users'
)

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

49 

# Record-type mnemonics accepted by the query methods.  Each one maps to the
# ``pycares.QUERY_TYPE_<name>`` constant; tuple order is the dict order.
_QUERY_TYPE_NAMES = (
    'A',
    'AAAA',
    'ANY',
    'CAA',
    'CNAME',
    'MX',
    'NAPTR',
    'NS',
    'PTR',
    'SOA',
    'SRV',
    'TXT',
)
query_type_map = {
    name: getattr(pycares, f'QUERY_TYPE_{name}') for name in _QUERY_TYPE_NAMES
}

# Query-class mnemonics, mapped to ``pycares.QUERY_CLASS_<name>``.
_QUERY_CLASS_NAMES = ('IN', 'CHAOS', 'HS', 'NONE', 'ANY')
query_class_map = {
    name: getattr(pycares, f'QUERY_CLASS_{name}')
    for name in _QUERY_CLASS_NAMES
}

72 

73 

class DNSResolver:
    """Asyncio front end for a ``pycares.Channel``.

    Every lookup method returns an ``asyncio.Future`` created on
    ``self.loop`` and resolved from the pycares completion callback.  The
    resolver can be used as an async context manager; leaving the context
    closes it.
    """

    def __init__(
        self,
        nameservers: Sequence[str] | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
        **kwargs: Any,
    ) -> None:  # TODO(PY311): Use Unpack for kwargs.
        """Create the resolver and its underlying channel.

        Args:
            nameservers: Servers to install on the channel.  A falsy value
                leaves the channel's own server list untouched.
            loop: Event loop used for futures, fd watchers and timers.
                Defaults to ``asyncio.get_event_loop()``.
            **kwargs: Forwarded to ``pycares.Channel``.  ``timeout`` is also
                remembered for the polling timer, and any ``sock_state_cb``
                is discarded because the resolver installs its own when it
                needs one.

        Raises:
            RuntimeError: On Windows, when the socket-state fallback is
                required and the loop is not a supported type.
        """
        # Stay "closed" until a channel exists so that __del__ on a
        # half-built instance has nothing to tear down.
        self._closed = True
        self.loop = loop or asyncio.get_event_loop()
        if TYPE_CHECKING:
            assert self.loop is not None
        kwargs.pop('sock_state_cb', None)
        self._timeout = kwargs.pop('timeout', None)
        # Bookkeeping for the socket-state fallback.  It is created before
        # the channel so the callback can never see a partially built object.
        self._read_fds: set[int] = set()
        self._write_fds: set[int] = set()
        self._timer: asyncio.TimerHandle | None = None
        self._event_thread, self._channel = self._make_channel(**kwargs)
        # A channel now exists; mark the resolver open *before* applying the
        # nameservers so a failure there still gets the channel closed.
        self._closed = False
        if nameservers:
            self.nameservers = nameservers

    def _make_channel(self, **kwargs: Any) -> tuple[bool, pycares.Channel]:
        """Build the pycares channel, preferring its event-thread mode.

        Returns:
            ``(event_thread, channel)``.  ``event_thread`` is ``True`` when
            the channel was created without a socket-state callback and
            ``False`` when the resolver fell back to driving the channel's
            sockets from the asyncio loop.

        Raises:
            RuntimeError: On Windows, when the fallback is needed and the
                loop is neither a ``SelectorEventLoop`` nor a
                ``winloop.Loop``.
        """
        # pycares 5+ uses event_thread by default when sock_state_cb
        # is not provided
        try:
            return True, pycares.Channel(timeout=self._timeout, **kwargs)
        except pycares.AresError as e:
            if sys.platform == 'linux':
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. This '
                    'usually means the system ran out of inotify watches. '
                    'Falling back to socket state callback. Consider '
                    'increasing the system inotify watch limit: %s',
                    e,
                )
            else:
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. '
                    'Falling back to socket state callback: %s',
                    e,
                )

        # Fall back to sock_state_cb (needs SelectorEventLoop on Windows)
        if sys.platform == 'win32' and not isinstance(
            self.loop, asyncio.SelectorEventLoop
        ):
            try:
                import winloop

                if not isinstance(self.loop, winloop.Loop):
                    raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
            except ModuleNotFoundError as ex:
                raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex

        # Use weak reference for deterministic cleanup. Without it there's a
        # reference cycle (DNSResolver -> _channel -> callback -> DNSResolver).
        # Python 3.4+ can handle cycles with __del__, but weak ref ensures
        # cleanup happens immediately when last reference is dropped.
        weak_self = weakref.ref(self)

        def sock_state_cb_wrapper(
            fd: int, readable: bool, writable: bool
        ) -> None:
            this = weak_self()
            if this is not None:
                this._sock_state_cb(fd, readable, writable)

        return False, pycares.Channel(
            sock_state_cb=sock_state_cb_wrapper,
            timeout=self._timeout,
            **kwargs,
        )

    @staticmethod
    def _strip_port(server: str) -> str:
        """Return the address part of a ``host[:port]`` server string.

        Handles ``'1.2.3.4:53'``, the bracketed IPv6 form ``'[::1]:53'`` and
        a bare IPv6 address (several colons, no brackets), which carries no
        port and is returned unchanged.
        """
        if server.startswith('['):
            host, closing, _ = server[1:].partition(']')
            # Without a closing bracket the string is not the expected
            # form; hand it back untouched rather than guessing.
            return host if closing else server
        if server.count(':') == 1:
            return server.partition(':')[0]
        return server

    @property
    def nameservers(self) -> Sequence[str]:
        """Addresses of the channel's servers, without port numbers."""
        # pycares 5.x returns servers with port (e.g., '8.8.8.8:53')
        # Strip port for backward compatibility with pycares 4.x
        return [self._strip_port(s) for s in self._channel.servers]

    @nameservers.setter
    def nameservers(self, value: Iterable[str | bytes]) -> None:
        self._channel.servers = value

    @staticmethod
    def _dns_error(errorno: int) -> Exception:
        """Build the ``DNSError`` for a pycares error number."""
        return error.DNSError(errorno, pycares.errno.strerror(errorno))

    def _bind_callback(
        self, func: Callable[..., None], *bound: Any
    ) -> Callable[..., Any]:
        """Pre-bind ``bound`` to ``func`` for use as a pycares callback.

        In event-thread mode the call is routed through
        ``loop.call_soon_threadsafe`` so ``func`` runs on the event loop;
        otherwise ``func`` is invoked directly.
        """
        if self._event_thread:
            return functools.partial(
                self.loop.call_soon_threadsafe, func, *bound
            )
        return functools.partial(func, *bound)

    def _callback(
        self, fut: asyncio.Future[_T], result: _T, errorno: int | None
    ) -> None:
        """Complete ``fut`` with ``result``, or a ``DNSError`` on failure."""
        if fut.cancelled():
            return
        if errorno is not None:
            fut.set_exception(self._dns_error(errorno))
        else:
            fut.set_result(result)

    def _get_future_callback(
        self,
    ) -> tuple[asyncio.Future[_T], Callable[[_T, int | None], None]]:
        """Return a future and a callback to set the result of the future."""
        future: asyncio.Future[_T] = self.loop.create_future()
        return future, self._bind_callback(self._callback, future)

    def _query_callback(
        self,
        fut: asyncio.Future[QueryResult],
        qtype: int,
        result: pycares.DNSResult,
        errorno: int | None,
    ) -> None:
        """Callback for query that converts results to compatible format."""
        if fut.cancelled():
            return
        if errorno is not None:
            fut.set_exception(self._dns_error(errorno))
        else:
            fut.set_result(convert_result(result, qtype))

    def _get_query_future_callback(
        self, qtype: int
    ) -> tuple[asyncio.Future[QueryResult], Callable[..., None]]:
        """Return a future and callback for query with result conversion."""
        future: asyncio.Future[QueryResult] = self.loop.create_future()
        return future, self._bind_callback(
            self._query_callback, future, qtype
        )

    @staticmethod
    def _resolve_query_args(
        qtype: str, qclass: str | None
    ) -> tuple[int, int | None]:
        """Translate query type/class mnemonics to pycares constants.

        Raises:
            ValueError: If ``qtype`` or a non-``None`` ``qclass`` is unknown.
        """
        try:
            qtype_int = query_type_map[qtype]
        except KeyError as e:
            raise ValueError(f'invalid query type: {qtype}') from e
        if qclass is None:
            return qtype_int, None
        try:
            return qtype_int, query_class_map[qclass]
        except KeyError as e:
            raise ValueError(f'invalid query class: {qclass}') from e

    def _send_query(
        self,
        host: str,
        qtype_int: int,
        qclass_int: int | None,
        cb: Callable[..., Any],
    ) -> None:
        """Submit a query, passing ``query_class`` only when one was given."""
        if qclass_int is not None:
            self._channel.query(
                host, qtype_int, query_class=qclass_int, callback=cb
            )
        else:
            self._channel.query(host, qtype_int, callback=cb)

    @overload
    def query(
        self, host: str, qtype: Literal['A'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['AAAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAAAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryCAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CNAME'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryCNAMEResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['MX'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryMXResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NAPTR'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNAPTRResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NS'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNSResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['PTR'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryPTRResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SOA'], qclass: str | None = ...
    ) -> asyncio.Future[AresQuerySOAResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SRV'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQuerySRVResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['TXT'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryTXTResult]]: ...

    def query(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[list[Any]] | asyncio.Future[Any]:
        """Query DNS records (deprecated, use query_dns instead).

        Raises:
            ValueError: If ``qtype`` or ``qclass`` is not a known mnemonic.
        """
        # Warn from this frame so stacklevel=2 points at the caller.
        warnings.warn(
            'query() is deprecated, use query_dns() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        qtype_int, qclass_int = self._resolve_query_args(qtype, qclass)
        fut, cb = self._get_query_future_callback(qtype_int)
        self._send_query(host, qtype_int, qclass_int, cb)
        return fut

    def query_dns(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[pycares.DNSResult]:
        """Query DNS records, returning native pycares 5.x DNSResult.

        Raises:
            ValueError: If ``qtype`` or ``qclass`` is not a known mnemonic.
        """
        qtype_int, qclass_int = self._resolve_query_args(qtype, qclass)
        fut: asyncio.Future[pycares.DNSResult]
        fut, cb = self._get_future_callback()
        self._send_query(host, qtype_int, qclass_int, cb)
        return fut

    def _gethostbyname_callback(
        self,
        fut: asyncio.Future[AresHostResult],
        host: str,
        result: pycares.AddrInfoResult | None,
        errorno: int | None,
    ) -> None:
        """Callback for gethostbyname that converts AddrInfoResult."""
        if fut.cancelled():
            return
        if errorno is not None:
            fut.set_exception(self._dns_error(errorno))
            return
        assert result is not None  # noqa: S101
        # node.addr is (address_bytes, port) - extract and decode
        addresses = [node.addr[0].decode() for node in result.nodes]
        # Get canonical name from cnames if available
        name = result.cnames[0].name if result.cnames else host
        fut.set_result(
            AresHostResult(name=name, aliases=[], addresses=addresses)
        )

    def gethostbyname(
        self, host: str, family: socket.AddressFamily
    ) -> asyncio.Future[AresHostResult]:
        """
        Resolve hostname to addresses.

        Deprecated: Use getaddrinfo() instead. This is implemented using
        getaddrinfo as pycares 5.x removed the gethostbyname method.
        """
        warnings.warn(
            'gethostbyname() is deprecated, use getaddrinfo() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        fut: asyncio.Future[AresHostResult] = self.loop.create_future()
        cb = self._bind_callback(self._gethostbyname_callback, fut, host)
        self._channel.getaddrinfo(host, None, family=family, callback=cb)
        return fut

    def getaddrinfo(
        self,
        host: str,
        family: socket.AddressFamily = socket.AF_UNSPEC,
        port: int | None = None,
        proto: int = 0,
        type: int = 0,
        flags: int = 0,
    ) -> asyncio.Future[pycares.AddrInfoResult]:
        """Resolve ``host`` through the channel's ``getaddrinfo``.

        All arguments are forwarded to ``pycares.Channel.getaddrinfo``; the
        returned future yields its ``AddrInfoResult`` or raises ``DNSError``.
        """
        fut: asyncio.Future[pycares.AddrInfoResult]
        fut, cb = self._get_future_callback()
        self._channel.getaddrinfo(
            host,
            port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
            callback=cb,
        )
        return fut

    def getnameinfo(
        self,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
        flags: int = 0,
    ) -> asyncio.Future[pycares.NameInfoResult]:
        """Look up ``sockaddr`` through the channel's ``getnameinfo``."""
        fut: asyncio.Future[pycares.NameInfoResult]
        fut, cb = self._get_future_callback()
        self._channel.getnameinfo(sockaddr, flags, callback=cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]:
        """Look up ``name`` through the channel's ``gethostbyaddr``."""
        fut: asyncio.Future[pycares.HostResult]
        fut, cb = self._get_future_callback()
        self._channel.gethostbyaddr(name, callback=cb)
        return fut

    def cancel(self) -> None:
        """Cancel the channel's outstanding requests."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Mirror the channel's wanted socket state onto the event loop.

        Used only in the socket-state fallback.  Each call carries the
        complete desired watch state for ``fd``: a requested direction is
        (re)registered and a direction that is no longer requested is
        unregistered.  Without the latter a writer left over from an earlier
        call would keep firing for as long as the socket stays writable.
        ``readable`` and ``writable`` both false means the socket is closed.
        """
        if readable:
            self.loop.add_reader(
                fd, self._channel.process_fd, fd, pycares.ARES_SOCKET_BAD
            )
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(
                fd, self._channel.process_fd, pycares.ARES_SOCKET_BAD, fd
            )
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            # At least one socket is live: make sure the poll timer runs.
            if self._timer is None:
                self._start_timer()
        elif (
            not self._read_fds
            and not self._write_fds
            and self._timer is not None
        ):
            # That was the last watched socket; the timer has nothing to do.
            self._timer.cancel()
            self._timer = None

    def _timer_cb(self) -> None:
        """Poll the channel and re-arm while any socket is being watched."""
        if self._read_fds or self._write_fds:
            self._channel.process_fd(
                pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD
            )
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule ``_timer_cb`` after a delay derived from the timeout.

        The delay is the configured timeout when it lies in ``(0, 1]``,
        ``0.1`` when the timeout is exactly ``0``, and ``1`` otherwise
        (unset, negative or larger than one).
        """
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)

    def _cleanup(self) -> None:
        """Cleanup timers and file descriptors when closing resolver."""
        if self._closed:
            return
        # Mark as closed first to prevent double cleanup
        self._closed = True
        # Cancel timer if running
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None

        # Remove all file descriptors
        for fd in self._read_fds:
            self.loop.remove_reader(fd)
        for fd in self._write_fds:
            self.loop.remove_writer(fd)

        self._read_fds.clear()
        self._write_fds.clear()
        self._channel.close()

    async def close(self) -> None:
        """
        Cleanly close the DNS resolver.

        This should be called to ensure all resources are properly released.
        After calling close(), the resolver should not be used again.
        """
        if not self._closed:
            self._channel.cancel()
        self._cleanup()

    async def __aenter__(self) -> DNSResolver:
        """Enter the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager."""
        await self.close()

    def __del__(self) -> None:
        """Handle cleanup when the resolver is garbage collected."""
        # Check if we have a channel to clean up
        # This can happen if an exception occurs during __init__ before
        # _channel is created (e.g., RuntimeError on Windows
        # without proper loop)
        if hasattr(self, '_channel'):
            self._cleanup()

514 self._cleanup()