Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiodns/__init__.py: 29%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1from __future__ import annotations
3import asyncio
4import contextlib
5import functools
6import logging
7import socket
8import sys
9import warnings
10import weakref
11from collections.abc import Callable, Iterable, Iterator, Sequence
12from types import TracebackType
13from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
15import pycares
17from . import error
18from .compat import (
19 AresHostResult,
20 AresQueryAAAAResult,
21 AresQueryAResult,
22 AresQueryCAAResult,
23 AresQueryCNAMEResult,
24 AresQueryMXResult,
25 AresQueryNAPTRResult,
26 AresQueryNSResult,
27 AresQueryPTRResult,
28 AresQuerySOAResult,
29 AresQuerySRVResult,
30 AresQueryTXTResult,
31 QueryResult,
32 convert_result,
33)
__version__ = '4.0.4'

# Names exported by ``from aiodns import *``.
__all__ = ('DNSResolver', 'error')

_T = TypeVar('_T')

# Error text used when the socket-state fallback cannot run on the given
# Windows event loop.
WINDOWS_SELECTOR_ERR_MSG = (
    'aiodns needs a SelectorEventLoop on Windows. '
    'See more: https://github.com/aio-libs/aiodns#note-for-windows-users'
)

_LOGGER = logging.getLogger(__name__)

# Record-type names accepted by query()/query_dns(), mapped to the
# corresponding pycares constants.
query_type_map = dict(
    A=pycares.QUERY_TYPE_A,
    AAAA=pycares.QUERY_TYPE_AAAA,
    ANY=pycares.QUERY_TYPE_ANY,
    CAA=pycares.QUERY_TYPE_CAA,
    CNAME=pycares.QUERY_TYPE_CNAME,
    MX=pycares.QUERY_TYPE_MX,
    NAPTR=pycares.QUERY_TYPE_NAPTR,
    NS=pycares.QUERY_TYPE_NS,
    PTR=pycares.QUERY_TYPE_PTR,
    SOA=pycares.QUERY_TYPE_SOA,
    SRV=pycares.QUERY_TYPE_SRV,
    TXT=pycares.QUERY_TYPE_TXT,
)

# DNS class names accepted by query()/query_dns(), mapped to the
# corresponding pycares constants.
query_class_map = dict(
    IN=pycares.QUERY_CLASS_IN,
    CHAOS=pycares.QUERY_CLASS_CHAOS,
    HS=pycares.QUERY_CLASS_HS,
    NONE=pycares.QUERY_CLASS_NONE,
    ANY=pycares.QUERY_CLASS_ANY,
)
class DNSResolver:
    """Asyncio DNS resolver built on a pycares (c-ares) channel.

    When pycares can create a channel without a socket-state callback, it
    drives I/O on its own event thread and every completion handler is
    marshalled back onto ``loop`` with ``call_soon_threadsafe``. Otherwise
    the resolver falls back to a socket-state callback that registers the
    channel's sockets with ``loop`` itself and polls c-ares from a timer.
    """

    def __init__(
        self,
        nameservers: Sequence[str] | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
        **kwargs: Any,
    ) -> None:  # TODO(PY311): Use Unpack for kwargs.
        """Create the resolver.

        Args:
            nameservers: Optional DNS servers to use instead of the system
                configuration.
            loop: Event loop to bind to; defaults to
                ``asyncio.get_event_loop()``.
            **kwargs: Extra options forwarded to ``pycares.Channel``.
                ``timeout`` is also kept for the fallback timer, and any
                ``sock_state_cb`` is discarded because the resolver installs
                its own when it needs one.

        Raises:
            RuntimeError: On Windows, when the socket-state fallback is
                needed and ``loop`` is neither a SelectorEventLoop nor a
                winloop loop.
        """
        # Stays True until construction completes, so __del__ on a
        # half-built instance does not try to tear down missing state.
        self._closed = True
        self.loop = loop or asyncio.get_event_loop()
        if TYPE_CHECKING:
            assert self.loop is not None
        # Bookkeeping for the socket-state fallback. It is created before the
        # channel so the callback can never observe missing attributes.
        self._read_fds: set[int] = set()
        self._write_fds: set[int] = set()
        self._timer: asyncio.TimerHandle | None = None
        kwargs.pop('sock_state_cb', None)
        self._timeout = kwargs.pop('timeout', None)
        self._event_thread, self._channel = self._make_channel(**kwargs)
        if nameservers:
            try:
                self.nameservers = nameservers
            except BaseException:
                # Release the channel now instead of leaving it for GC; the
                # instance is unusable if the server list was rejected.
                self._channel.close()
                raise
        self._closed = False

    def _make_channel(self, **kwargs: Any) -> tuple[bool, pycares.Channel]:
        """Create the pycares channel.

        Returns:
            ``(event_thread, channel)``. ``event_thread`` is True when the
            channel was created without a socket-state callback, and False
            when the socket-state fallback is in use.

        Raises:
            RuntimeError: See ``__init__``.
        """
        # pycares 5+ uses event_thread by default when sock_state_cb
        # is not provided
        try:
            return True, pycares.Channel(timeout=self._timeout, **kwargs)
        except pycares.AresError as e:
            if sys.platform == 'linux':
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. This '
                    'usually means the system ran out of inotify watches. '
                    'Falling back to socket state callback. Consider '
                    'increasing the system inotify watch limit: %s',
                    e,
                )
            else:
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. '
                    'Falling back to socket state callback: %s',
                    e,
                )
        # Fall back to sock_state_cb (needs SelectorEventLoop on Windows)
        if sys.platform == 'win32' and not isinstance(
            self.loop, asyncio.SelectorEventLoop
        ):
            try:
                import winloop

                if not isinstance(self.loop, winloop.Loop):
                    raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
            except ModuleNotFoundError as ex:
                raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
        # Use a weak reference for deterministic cleanup. Without it there is
        # a reference cycle (DNSResolver -> _channel -> callback ->
        # DNSResolver), which delays collection until the cycle collector
        # runs.
        weak_self = weakref.ref(self)

        def sock_state_cb_wrapper(
            fd: int, readable: bool, writable: bool
        ) -> None:
            this = weak_self()
            if this is not None:
                this._sock_state_cb(fd, readable, writable)

        return False, pycares.Channel(
            sock_state_cb=sock_state_cb_wrapper,
            timeout=self._timeout,
            **kwargs,
        )

    @staticmethod
    def _strip_port(server: str) -> str:
        """Return the address part of a server string, without any port.

        Handles ``a.b.c.d:port``, ``[ipv6]:port``, and port-less forms,
        including bare IPv6 addresses whose colons are not a port separator.
        """
        if server.startswith('['):
            end = server.find(']')
            return server[1:end] if end != -1 else server
        if server.count(':') == 1:
            return server.partition(':')[0]
        # No port (IPv4/hostname) or a bare IPv6 address: leave untouched.
        return server

    @property
    def nameservers(self) -> Sequence[str]:
        """Configured DNS servers, as bare addresses."""
        # pycares 5.x returns servers with port (e.g., '8.8.8.8:53').
        # Strip the port for backward compatibility with pycares 4.x.
        return [self._strip_port(s) for s in self._channel.servers]

    @nameservers.setter
    def nameservers(self, value: Iterable[str | bytes]) -> None:
        self._channel.servers = value

    def _bind_callback(
        self, func: Callable[..., None], *args: Any
    ) -> Callable[..., None]:
        """Turn *func* plus leading *args* into a pycares completion callback.

        With the event thread the call is scheduled onto ``self.loop`` via
        ``call_soon_threadsafe``; otherwise *func* is invoked directly.
        """
        if self._event_thread:
            return functools.partial(
                self.loop.call_soon_threadsafe, func, *args
            )
        return functools.partial(func, *args)

    def _callback(
        self, fut: asyncio.Future[_T], result: _T, errorno: int | None
    ) -> None:
        """Complete *fut* from a pycares ``(result, errorno)`` callback."""
        # The future can already be done if pycares raised synchronously
        # and _capture_ares_error set the exception before c-ares delivered
        # the same error through this callback.
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno))
            )
        else:
            fut.set_result(result)

    def _get_future_callback(
        self,
    ) -> tuple[asyncio.Future[_T], Callable[[_T, int | None], None]]:
        """Return a future and a callback to set the result of the future."""
        future: asyncio.Future[_T] = self.loop.create_future()
        return future, self._bind_callback(self._callback, future)

    def _query_callback(
        self,
        fut: asyncio.Future[QueryResult],
        qtype: int,
        result: pycares.DNSResult,
        errorno: int | None,
    ) -> None:
        """Callback for query that converts results to compatible format."""
        # See _callback for why we guard on done() rather than cancelled().
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno))
            )
            return
        try:
            converted = convert_result(result, qtype)
        except error.DNSError as exc:
            fut.set_exception(exc)
        else:
            fut.set_result(converted)

    def _get_query_future_callback(
        self, qtype: int
    ) -> tuple[asyncio.Future[QueryResult], Callable[..., None]]:
        """Return a future and callback for query with result conversion."""
        future: asyncio.Future[QueryResult] = self.loop.create_future()
        return future, self._bind_callback(self._query_callback, future, qtype)

    @contextlib.contextmanager
    def _capture_ares_error(self, fut: asyncio.Future[_T]) -> Iterator[None]:
        """Route a synchronous ``pycares.AresError`` into *fut*."""
        # When pycares raises synchronously (e.g. ARES_EBADNAME for a
        # malformed hostname), c-ares may also invoke the callback first,
        # leaving the future already done. Route the error through the
        # future so callers can rely on `await` to raise.
        try:
            yield
        except pycares.AresError as exc:
            if fut.done():
                return
            # pycares always raises (errno, message), but be defensive:
            # an args-less AresError should still resolve the future to
            # avoid an indefinite hang on `await`.
            errno = exc.args[0] if exc.args else error.ARES_EFORMERR
            fut.set_exception(
                error.DNSError(errno, pycares.errno.strerror(errno))
            )

    @staticmethod
    def _parse_query_args(
        qtype: str, qclass: str | None
    ) -> tuple[int, int | None]:
        """Map record type/class names to pycares constants.

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        try:
            qtype_int = query_type_map[qtype]
        except KeyError as e:
            raise ValueError(f'invalid query type: {qtype}') from e
        if qclass is None:
            return qtype_int, None
        try:
            return qtype_int, query_class_map[qclass]
        except KeyError as e:
            raise ValueError(f'invalid query class: {qclass}') from e

    def _send_query(
        self,
        host: str,
        qtype_int: int,
        qclass_int: int | None,
        fut: asyncio.Future[Any],
        cb: Callable[..., None],
    ) -> None:
        """Submit a channel query; synchronous c-ares errors land in *fut*."""
        with self._capture_ares_error(fut):
            if qclass_int is None:
                self._channel.query(host, qtype_int, callback=cb)
            else:
                self._channel.query(
                    host, qtype_int, query_class=qclass_int, callback=cb
                )

    @overload
    def query(
        self, host: str, qtype: Literal['A'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['AAAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAAAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['ANY'], qclass: str | None = ...
    ) -> asyncio.Future[Any]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryCAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CNAME'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryCNAMEResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['MX'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryMXResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NAPTR'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNAPTRResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NS'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNSResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['PTR'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryPTRResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SOA'], qclass: str | None = ...
    ) -> asyncio.Future[AresQuerySOAResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SRV'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQuerySRVResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['TXT'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryTXTResult]]: ...

    def query(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[list[Any]] | asyncio.Future[Any]:
        """Query DNS records (deprecated, use query_dns instead).

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        warnings.warn(
            'query() is deprecated, use query_dns() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        qtype_int, qclass_int = self._parse_query_args(qtype, qclass)
        fut, cb = self._get_query_future_callback(qtype_int)
        self._send_query(host, qtype_int, qclass_int, fut, cb)
        return fut

    def query_dns(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[pycares.DNSResult]:
        """Query DNS records, returning native pycares 5.x DNSResult.

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        qtype_int, qclass_int = self._parse_query_args(qtype, qclass)
        fut: asyncio.Future[pycares.DNSResult]
        fut, cb = self._get_future_callback()
        self._send_query(host, qtype_int, qclass_int, fut, cb)
        return fut

    def _gethostbyname_callback(
        self,
        fut: asyncio.Future[AresHostResult],
        host: str,
        result: pycares.AddrInfoResult | None,
        errorno: int | None,
    ) -> None:
        """Callback for gethostbyname that converts AddrInfoResult."""
        # See _callback for why we guard on done() rather than cancelled().
        if fut.done():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno))
            )
        else:
            assert result is not None  # noqa: S101
            # node.addr is (address_bytes, port) - extract and decode
            addresses = [node.addr[0].decode() for node in result.nodes]
            # Get canonical name from cnames if available
            name = result.cnames[0].name if result.cnames else host
            fut.set_result(
                AresHostResult(name=name, aliases=[], addresses=addresses)
            )

    def gethostbyname(
        self, host: str, family: socket.AddressFamily
    ) -> asyncio.Future[AresHostResult]:
        """
        Resolve hostname to addresses.

        Deprecated: Use getaddrinfo() instead. This is implemented using
        getaddrinfo as pycares 5.x removed the gethostbyname method.
        """
        warnings.warn(
            'gethostbyname() is deprecated, use getaddrinfo() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        fut: asyncio.Future[AresHostResult] = self.loop.create_future()
        cb = self._bind_callback(self._gethostbyname_callback, fut, host)
        with self._capture_ares_error(fut):
            self._channel.getaddrinfo(host, None, family=family, callback=cb)
        return fut

    def getaddrinfo(
        self,
        host: str,
        family: socket.AddressFamily = socket.AF_UNSPEC,
        port: int | None = None,
        proto: int = 0,
        type: int = 0,
        flags: int = 0,
    ) -> asyncio.Future[pycares.AddrInfoResult]:
        """Resolve *host* via the channel's getaddrinfo."""
        fut: asyncio.Future[pycares.AddrInfoResult]
        fut, cb = self._get_future_callback()
        with self._capture_ares_error(fut):
            self._channel.getaddrinfo(
                host,
                port,
                family=family,
                type=type,
                proto=proto,
                flags=flags,
                callback=cb,
            )
        return fut

    def getnameinfo(
        self,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
        flags: int = 0,
    ) -> asyncio.Future[pycares.NameInfoResult]:
        """Reverse-resolve *sockaddr* via the channel's getnameinfo."""
        fut: asyncio.Future[pycares.NameInfoResult]
        fut, cb = self._get_future_callback()
        with self._capture_ares_error(fut):
            self._channel.getnameinfo(sockaddr, flags, callback=cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]:
        """Look up host information for the address *name*."""
        fut: asyncio.Future[pycares.HostResult]
        fut, cb = self._get_future_callback()
        with self._capture_ares_error(fut):
            self._channel.gethostbyaddr(name, callback=cb)
        return fut

    def cancel(self) -> None:
        """Cancel all outstanding queries on the channel."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Sync the loop's I/O registrations for *fd* with c-ares' request.

        Each call carries the full desired state, so a flag that turns off
        drops the matching registration. Both flags off means the socket
        was closed.
        """
        if readable:
            self.loop.add_reader(
                fd, self._channel.process_fd, fd, pycares.ARES_SOCKET_BAD
            )
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(
                fd, self._channel.process_fd, pycares.ARES_SOCKET_BAD, fd
            )
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            # A connected socket is normally writable, so a writer left
            # registered after c-ares stopped asking for write events would
            # fire on every loop iteration.
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            if self._timer is None:
                self._start_timer()
        elif (
            not self._read_fds
            and not self._write_fds
            and self._timer is not None
        ):
            self._timer.cancel()
            self._timer = None

    def _timer_cb(self) -> None:
        """Let c-ares process timeouts while any socket is registered."""
        if self._read_fds or self._write_fds:
            self._channel.process_fd(
                pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD
            )
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule _timer_cb after the timeout, clamped to (0, 1] seconds."""
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)

    def _cleanup(self) -> None:
        """Cleanup timers and file descriptors when closing resolver."""
        if self._closed:
            return
        # Mark as closed first to prevent double cleanup
        self._closed = True
        # Cancel timer if running
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None

        # Remove all file descriptors
        for fd in self._read_fds:
            self.loop.remove_reader(fd)
        for fd in self._write_fds:
            self.loop.remove_writer(fd)

        self._read_fds.clear()
        self._write_fds.clear()
        self._channel.close()

    async def close(self) -> None:
        """
        Cleanly close the DNS resolver.

        This should be called to ensure all resources are properly released.
        After calling close(), the resolver should not be used again.
        """
        if not self._closed:
            self._channel.cancel()
            self._cleanup()

    async def __aenter__(self) -> DNSResolver:
        """Enter the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager."""
        await self.close()

    def __del__(self) -> None:
        """Handle cleanup when the resolver is garbage collected."""
        # Check if we have a channel to clean up
        # This can happen if an exception occurs during __init__ before
        # _channel is created (e.g., RuntimeError on Windows
        # without proper loop)
        if hasattr(self, '_channel'):
            self._cleanup()