Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiodns/__init__.py: 31%
Shortcuts on this page
r, m, x: toggle line displays
j, k: next/previous highlighted chunk
0 (zero): top of page
1 (one): first highlighted chunk
1from __future__ import annotations
3import asyncio
4import functools
5import logging
6import socket
7import sys
8import warnings
9import weakref
10from collections.abc import Callable, Iterable, Sequence
11from types import TracebackType
12from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
14import pycares
16from . import error
17from .compat import (
18 AresHostResult,
19 AresQueryAAAAResult,
20 AresQueryAResult,
21 AresQueryCAAResult,
22 AresQueryCNAMEResult,
23 AresQueryMXResult,
24 AresQueryNAPTRResult,
25 AresQueryNSResult,
26 AresQueryPTRResult,
27 AresQuerySOAResult,
28 AresQuerySRVResult,
29 AresQueryTXTResult,
30 QueryResult,
31 convert_result,
32)
__version__ = '4.0.0'

# Names exported by ``from aiodns import *``.
__all__ = (
    'DNSResolver',
    'error',
)

# Generic result type shared by the resolver's future/callback helpers.
_T = TypeVar('_T')

# Error text raised by DNSResolver when it must fall back to the socket
# state callback on Windows and the loop is neither a SelectorEventLoop
# nor a winloop loop.
WINDOWS_SELECTOR_ERR_MSG = (
    'aiodns needs a SelectorEventLoop on Windows. See more: '
    'https://github.com/aio-libs/aiodns#note-for-windows-users'
)

_LOGGER = logging.getLogger(__name__)

# Record-type names accepted by DNSResolver.query() / query_dns(),
# translated to the numeric pycares query-type constants.
query_type_map: dict[str, int] = {
    'A': pycares.QUERY_TYPE_A,
    'AAAA': pycares.QUERY_TYPE_AAAA,
    'ANY': pycares.QUERY_TYPE_ANY,
    'CAA': pycares.QUERY_TYPE_CAA,
    'CNAME': pycares.QUERY_TYPE_CNAME,
    'MX': pycares.QUERY_TYPE_MX,
    'NAPTR': pycares.QUERY_TYPE_NAPTR,
    'NS': pycares.QUERY_TYPE_NS,
    'PTR': pycares.QUERY_TYPE_PTR,
    'SOA': pycares.QUERY_TYPE_SOA,
    'SRV': pycares.QUERY_TYPE_SRV,
    'TXT': pycares.QUERY_TYPE_TXT,
}

# DNS class names accepted by the optional ``qclass`` argument of
# DNSResolver.query() / query_dns().
query_class_map: dict[str, int] = {
    'IN': pycares.QUERY_CLASS_IN,
    'CHAOS': pycares.QUERY_CLASS_CHAOS,
    'HS': pycares.QUERY_CLASS_HS,
    'NONE': pycares.QUERY_CLASS_NONE,
    'ANY': pycares.QUERY_CLASS_ANY,
}
class DNSResolver:
    """Asyncio DNS resolver backed by a ``pycares.Channel``.

    Every lookup method returns an ``asyncio.Future`` created on
    ``self.loop``. The channel normally uses pycares' own event thread.
    If that channel cannot be created, the resolver falls back to a
    socket-state callback and drives c-ares from the event loop itself,
    using readers/writers plus a periodic timer.
    """

    def __init__(
        self,
        nameservers: Sequence[str] | None = None,
        loop: asyncio.AbstractEventLoop | None = None,
        **kwargs: Any,
    ) -> None:  # TODO(PY311): Use Unpack for kwargs.
        """Create the resolver and its pycares channel.

        Args:
            nameservers: Optional servers to use instead of the system
                configuration.
            loop: Loop that futures are bound to. Defaults to
                ``asyncio.get_event_loop()``.
            **kwargs: Options forwarded to ``pycares.Channel``. A caller
                supplied ``sock_state_cb`` is discarded because the
                resolver installs its own when needed. ``timeout`` is also
                kept to pace the fallback timer.
        """
        # Stay "closed" until construction fully succeeds, so that
        # __del__/_cleanup never tear down a half-built resolver.
        self._closed = True
        self.loop = loop or asyncio.get_event_loop()
        if TYPE_CHECKING:
            assert self.loop is not None
        # Socket bookkeeping must exist before the channel does, because
        # the fallback socket-state callback writes into it.
        self._read_fds: set[int] = set()
        self._write_fds: set[int] = set()
        self._timer: asyncio.TimerHandle | None = None
        kwargs.pop('sock_state_cb', None)
        self._timeout = kwargs.pop('timeout', None)
        self._event_thread, self._channel = self._make_channel(**kwargs)
        if nameservers:
            try:
                self.nameservers = nameservers
            except BaseException:
                # _cleanup() skips resolvers still marked closed, so the
                # channel has to be released here or it would leak.
                self._channel.close()
                raise
        self._closed = False

    def _make_channel(self, **kwargs: Any) -> tuple[bool, pycares.Channel]:
        """Create the pycares channel.

        Returns:
            ``(event_thread, channel)``. ``event_thread`` is True when the
            channel was created without a socket-state callback. It is
            False when the resolver fell back to driving sockets from
            ``self.loop``.

        Raises:
            RuntimeError: When falling back on Windows and ``self.loop`` is
                neither a SelectorEventLoop nor a winloop loop.
        """
        # pycares 5+ uses event_thread by default when sock_state_cb
        # is not provided
        try:
            return True, pycares.Channel(timeout=self._timeout, **kwargs)
        except pycares.AresError as e:
            if sys.platform == 'linux':
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. This '
                    'usually means the system ran out of inotify watches. '
                    'Falling back to socket state callback. Consider '
                    'increasing the system inotify watch limit: %s',
                    e,
                )
            else:
                _LOGGER.warning(
                    'Failed to create DNS resolver channel with automatic '
                    'monitoring of resolver configuration changes. '
                    'Falling back to socket state callback: %s',
                    e,
                )
            # Fall back to sock_state_cb (needs SelectorEventLoop on Windows)
            if sys.platform == 'win32' and not isinstance(
                self.loop, asyncio.SelectorEventLoop
            ):
                try:
                    import winloop

                    if not isinstance(self.loop, winloop.Loop):
                        raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
                except ModuleNotFoundError as ex:
                    raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
            # A weak reference avoids the cycle
            # DNSResolver -> _channel -> callback -> DNSResolver, so the
            # resolver is released as soon as its last reference is
            # dropped.
            weak_self = weakref.ref(self)

            def sock_state_cb_wrapper(
                fd: int, readable: bool, writable: bool
            ) -> None:
                this = weak_self()
                if this is not None:
                    this._sock_state_cb(fd, readable, writable)

            return False, pycares.Channel(
                sock_state_cb=sock_state_cb_wrapper,
                timeout=self._timeout,
                **kwargs,
            )

    @property
    def nameservers(self) -> Sequence[str]:
        """Configured nameserver addresses, without ports.

        pycares 5.x returns servers with a port (e.g. ``8.8.8.8:53``). The
        port is stripped for backward compatibility with pycares 4.x.
        Bracketed IPv6 entries such as ``[::1]:53`` are also unwrapped.
        """
        return [self._strip_port(s) for s in self._channel.servers]

    @nameservers.setter
    def nameservers(self, value: Iterable[str | bytes]) -> None:
        self._channel.servers = value

    @staticmethod
    def _strip_port(server: str) -> str:
        """Return the host part of a ``host[:port]`` server string."""
        if server.startswith('['):
            # Bracketed IPv6, either "[addr]:port" or "[addr]".
            host, sep, _ = server[1:].partition(']')
            return host if sep else server
        if server.count(':') == 1:
            # IPv4 address or hostname followed by ":port".
            return server.rsplit(':', 1)[0]
        # No port at all, or a bare IPv6 address whose colons are not a
        # port separator.
        return server

    @staticmethod
    def _set_dns_error(fut: asyncio.Future[Any], errorno: int) -> None:
        """Fail *fut* with the DNSError for the c-ares code *errorno*."""
        fut.set_exception(
            error.DNSError(errorno, pycares.errno.strerror(errorno))
        )

    def _bind_callback(
        self, func: Callable[..., None], *args: Any
    ) -> Callable[..., None]:
        """Pre-bind *args* to *func* to form a pycares completion callback.

        In event-thread mode pycares may invoke callbacks from its own
        thread, so the call is marshalled onto ``self.loop`` with
        ``call_soon_threadsafe``. Otherwise it runs inline.
        """
        if self._event_thread:
            return functools.partial(
                self.loop.call_soon_threadsafe, func, *args
            )
        return functools.partial(func, *args)

    def _callback(
        self, fut: asyncio.Future[_T], result: _T, errorno: int | None
    ) -> None:
        """Resolve *fut* with *result*, or fail it when *errorno* is set."""
        # done() also covers cancellation. A finished future must never
        # be set again, because that raises InvalidStateError in the loop.
        if fut.done():
            return
        if errorno is not None:
            self._set_dns_error(fut, errorno)
        else:
            fut.set_result(result)

    def _get_future_callback(
        self,
    ) -> tuple[asyncio.Future[_T], Callable[[_T, int | None], None]]:
        """Return a future and a callback to set the result of the future."""
        future: asyncio.Future[_T] = self.loop.create_future()
        return future, self._bind_callback(self._callback, future)

    def _query_callback(
        self,
        fut: asyncio.Future[QueryResult],
        qtype: int,
        result: pycares.DNSResult,
        errorno: int | None,
    ) -> None:
        """Callback for query that converts results to compatible format."""
        if fut.done():
            return
        if errorno is not None:
            self._set_dns_error(fut, errorno)
        else:
            fut.set_result(convert_result(result, qtype))

    def _get_query_future_callback(
        self, qtype: int
    ) -> tuple[asyncio.Future[QueryResult], Callable[..., None]]:
        """Return a future and callback for query with result conversion."""
        future: asyncio.Future[QueryResult] = self.loop.create_future()
        return future, self._bind_callback(
            self._query_callback, future, qtype
        )

    @staticmethod
    def _query_codes(qtype: str, qclass: str | None) -> tuple[int, int | None]:
        """Translate record-type/class names into pycares constants.

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        try:
            qtype_int = query_type_map[qtype]
        except KeyError as e:
            raise ValueError(f'invalid query type: {qtype}') from e
        if qclass is None:
            return qtype_int, None
        try:
            return qtype_int, query_class_map[qclass]
        except KeyError as e:
            raise ValueError(f'invalid query class: {qclass}') from e

    def _send_query(
        self,
        host: str,
        qtype_int: int,
        qclass_int: int | None,
        cb: Callable[..., None],
    ) -> None:
        """Submit a query, passing a query class only when one was given."""
        if qclass_int is not None:
            self._channel.query(
                host, qtype_int, query_class=qclass_int, callback=cb
            )
        else:
            self._channel.query(host, qtype_int, callback=cb)

    @overload
    def query(
        self, host: str, qtype: Literal['A'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['AAAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryAAAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CAA'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryCAAResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['CNAME'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryCNAMEResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['MX'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryMXResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NAPTR'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNAPTRResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['NS'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryNSResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['PTR'], qclass: str | None = ...
    ) -> asyncio.Future[AresQueryPTRResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SOA'], qclass: str | None = ...
    ) -> asyncio.Future[AresQuerySOAResult]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['SRV'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQuerySRVResult]]: ...
    @overload
    def query(
        self, host: str, qtype: Literal['TXT'], qclass: str | None = ...
    ) -> asyncio.Future[list[AresQueryTXTResult]]: ...

    def query(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[list[Any]] | asyncio.Future[Any]:
        """Query DNS records (deprecated, use query_dns instead).

        Results are converted to the legacy pycares 4.x-style objects.

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        warnings.warn(
            'query() is deprecated, use query_dns() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        qtype_int, qclass_int = self._query_codes(qtype, qclass)
        fut, cb = self._get_query_future_callback(qtype_int)
        self._send_query(host, qtype_int, qclass_int, cb)
        return fut

    def query_dns(
        self, host: str, qtype: str, qclass: str | None = None
    ) -> asyncio.Future[pycares.DNSResult]:
        """Query DNS records, returning native pycares 5.x DNSResult.

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        qtype_int, qclass_int = self._query_codes(qtype, qclass)
        fut: asyncio.Future[pycares.DNSResult]
        fut, cb = self._get_future_callback()
        self._send_query(host, qtype_int, qclass_int, cb)
        return fut

    def _gethostbyname_callback(
        self,
        fut: asyncio.Future[AresHostResult],
        host: str,
        result: pycares.AddrInfoResult | None,
        errorno: int | None,
    ) -> None:
        """Callback for gethostbyname that converts AddrInfoResult."""
        if fut.done():
            return
        if errorno is not None:
            self._set_dns_error(fut, errorno)
            return
        assert result is not None  # noqa: S101
        # node.addr is (address_bytes, port) - extract and decode
        addresses = [node.addr[0].decode() for node in result.nodes]
        # Get canonical name from cnames if available
        name = result.cnames[0].name if result.cnames else host
        fut.set_result(
            AresHostResult(name=name, aliases=[], addresses=addresses)
        )

    def gethostbyname(
        self, host: str, family: socket.AddressFamily
    ) -> asyncio.Future[AresHostResult]:
        """
        Resolve hostname to addresses.

        Deprecated: Use getaddrinfo() instead. This is implemented using
        getaddrinfo as pycares 5.x removed the gethostbyname method.
        """
        warnings.warn(
            'gethostbyname() is deprecated, use getaddrinfo() instead',
            DeprecationWarning,
            stacklevel=2,
        )
        fut: asyncio.Future[AresHostResult] = self.loop.create_future()
        cb = self._bind_callback(self._gethostbyname_callback, fut, host)
        self._channel.getaddrinfo(host, None, family=family, callback=cb)
        return fut

    def getaddrinfo(
        self,
        host: str,
        family: socket.AddressFamily = socket.AF_UNSPEC,
        port: int | None = None,
        proto: int = 0,
        type: int = 0,
        flags: int = 0,
    ) -> asyncio.Future[pycares.AddrInfoResult]:
        """Resolve *host* with the channel's getaddrinfo; returns a future."""
        fut: asyncio.Future[pycares.AddrInfoResult]
        fut, cb = self._get_future_callback()
        self._channel.getaddrinfo(
            host,
            port,
            family=family,
            type=type,
            proto=proto,
            flags=flags,
            callback=cb,
        )
        return fut

    def getnameinfo(
        self,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
        flags: int = 0,
    ) -> asyncio.Future[pycares.NameInfoResult]:
        """Reverse-resolve *sockaddr* with the channel's getnameinfo."""
        fut: asyncio.Future[pycares.NameInfoResult]
        fut, cb = self._get_future_callback()
        self._channel.getnameinfo(sockaddr, flags, callback=cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]:
        """Reverse-resolve the address *name* with gethostbyaddr."""
        fut: asyncio.Future[pycares.HostResult]
        fut, cb = self._get_future_callback()
        self._channel.gethostbyaddr(name, callback=cb)
        return fut

    def cancel(self) -> None:
        """Cancel all outstanding queries on the channel."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Sync loop readers/writers with the events c-ares wants on *fd*.

        Each call carries the full set of events to watch. A direction
        that is no longer requested is therefore unregistered; leaving a
        stale writer installed would make the loop spin on a writable
        socket. Both flags False means the socket was closed.
        """
        if readable:
            self.loop.add_reader(
                fd, self._channel.process_fd, fd, pycares.ARES_SOCKET_BAD
            )
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(
                fd, self._channel.process_fd, pycares.ARES_SOCKET_BAD, fd
            )
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            if self._timer is None:
                self._start_timer()
        elif (
            not self._read_fds
            and not self._write_fds
            and self._timer is not None
        ):
            # Last socket closed: nothing left for the timer to service.
            self._timer.cancel()
            self._timer = None

    def _timer_cb(self) -> None:
        """Periodic tick while sockets are registered.

        Calls ``process_fd`` with no ready socket (ARES_SOCKET_BAD for both
        directions), presumably so c-ares can expire timed-out queries,
        then re-arms itself. It stops once no sockets remain.
        """
        if self._read_fds or self._write_fds:
            self._channel.process_fd(
                pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD
            )
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule ``_timer_cb`` after the channel timeout clamped to 1s.

        None, negative or >1 values use 1 second, and 0 uses 0.1 seconds.
        """
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)

    def _cleanup(self) -> None:
        """Cleanup timers and file descriptors when closing resolver."""
        if self._closed:
            return
        # Mark as closed first to prevent double cleanup
        self._closed = True
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None

        for fd in self._read_fds:
            self.loop.remove_reader(fd)
        for fd in self._write_fds:
            self.loop.remove_writer(fd)

        self._read_fds.clear()
        self._write_fds.clear()
        self._channel.close()

    async def close(self) -> None:
        """
        Cleanly close the DNS resolver.

        This should be called to ensure all resources are properly released.
        After calling close(), the resolver should not be used again.
        """
        if not self._closed:
            self._channel.cancel()
            self._cleanup()

    async def __aenter__(self) -> DNSResolver:
        """Enter the async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager."""
        await self.close()

    def __del__(self) -> None:
        """Handle cleanup when the resolver is garbage collected."""
        # _channel may be missing if __init__ raised before creating it
        # (e.g. RuntimeError on Windows without a suitable loop).
        if hasattr(self, '_channel'):
            self._cleanup()