Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/aiodns/__init__.py: 36%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import asyncio
2import functools
3import logging
4import pycares
5import socket
6import sys
7from collections.abc import Iterable, Sequence
8from typing import Any, Literal, Optional, TypeVar, Union, overload
10import pycares
11from typing import (
12 Any,
13 Callable,
14 Optional,
15 Set,
16 Sequence,
17 Tuple,
18 Union
19)
21from . import error
# Package version string.
__version__ = '3.4.0'

# Names exported by ``from aiodns import *``.
__all__ = (
    'DNSResolver',
    'error',
)
# Generic result type used by the future/callback helpers.
_T = TypeVar("_T")

# Error text raised on Windows when the event loop cannot be used with the
# socket-state-callback fallback.
WINDOWS_SELECTOR_ERR_MSG = (
    "aiodns needs a SelectorEventLoop on Windows. "
    "See more: https://github.com/aio-libs/aiodns#note-for-windows-users"
)

_LOGGER = logging.getLogger(__name__)

# Event kinds handed to the resolver's fd event handler.
READ, WRITE = 1, 2
37READ = 1
38WRITE = 2
40query_type_map = {'A' : pycares.QUERY_TYPE_A,
41 'AAAA' : pycares.QUERY_TYPE_AAAA,
42 'ANY' : pycares.QUERY_TYPE_ANY,
43 'CAA' : pycares.QUERY_TYPE_CAA,
44 'CNAME' : pycares.QUERY_TYPE_CNAME,
45 'MX' : pycares.QUERY_TYPE_MX,
46 'NAPTR' : pycares.QUERY_TYPE_NAPTR,
47 'NS' : pycares.QUERY_TYPE_NS,
48 'PTR' : pycares.QUERY_TYPE_PTR,
49 'SOA' : pycares.QUERY_TYPE_SOA,
50 'SRV' : pycares.QUERY_TYPE_SRV,
51 'TXT' : pycares.QUERY_TYPE_TXT
52 }
54query_class_map = {'IN' : pycares.QUERY_CLASS_IN,
55 'CHAOS' : pycares.QUERY_CLASS_CHAOS,
56 'HS' : pycares.QUERY_CLASS_HS,
57 'NONE' : pycares.QUERY_CLASS_NONE,
58 'ANY' : pycares.QUERY_CLASS_ANY
59 }
class DNSResolver:
    """Asynchronous DNS resolver bridging pycares (c-ares) to asyncio.

    Each lookup method returns an ``asyncio.Future`` bound to ``self.loop``.
    The future is resolved with the pycares result object, or failed with
    ``error.DNSError`` when pycares reports an error number.

    ``_make_channel`` picks one of two I/O strategies:

    * When pycares reports that c-ares is thread safe, the channel is created
      with ``event_thread=True``.  Callbacks then arrive on a foreign thread
      and are marshalled onto ``self.loop`` with ``call_soon_threadsafe``.
    * Otherwise, or if creating such a channel fails, the channel is created
      with ``sock_state_cb=self._sock_state_cb``.  c-ares sockets are then
      registered on ``self.loop`` via ``add_reader``/``add_writer``, and a
      periodic timer keeps calling into the channel while sockets are open.
    """

    def __init__(self, nameservers: Optional[Sequence[str]] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 **kwargs: Any) -> None:  # TODO(PY311): Use Unpack for kwargs.
        """Create the resolver and its pycares channel.

        :param nameservers: servers to use instead of the system
            configuration; applied only when truthy.
        :param loop: event loop to bind to; defaults to
            ``asyncio.get_event_loop()``.
        :param kwargs: extra options forwarded to ``pycares.Channel``.  Any
            ``sock_state_cb`` is discarded because the resolver installs its
            own.  ``timeout`` is stored on the resolver and passed to the
            channel explicitly.
        """
        self.loop = loop or asyncio.get_event_loop()
        assert self.loop is not None
        kwargs.pop('sock_state_cb', None)
        self._timeout = kwargs.pop('timeout', None)
        # The socket bookkeeping read by ``_sock_state_cb`` must exist before
        # that callback is handed to a channel.
        self._read_fds: set[int] = set()
        self._write_fds: set[int] = set()
        self._timer: Optional[asyncio.TimerHandle] = None
        self._event_thread, self._channel = self._make_channel(**kwargs)
        if nameservers:
            self.nameservers = nameservers

    def _make_channel(self, **kwargs: Any) -> Tuple[bool, pycares.Channel]:
        """Create the pycares channel.

        :returns: ``(event_thread, channel)``.  ``event_thread`` is True when
            the channel was created with ``event_thread=True``.  It is False
            when the channel uses ``_sock_state_cb`` to integrate with
            ``self.loop``.
        :raises RuntimeError: on Windows, when falling back to socket-state
            mode with a loop that is neither an ``asyncio.SelectorEventLoop``
            nor a ``winloop.Loop``.  Socket-state mode registers sockets with
            ``add_reader``/``add_writer``.
        """
        if hasattr(pycares, "ares_threadsafety") and pycares.ares_threadsafety():
            # pycares is thread safe
            try:
                return True, pycares.Channel(
                    event_thread=True, timeout=self._timeout, **kwargs
                )
            except pycares.AresError as e:
                if sys.platform == "linux":
                    _LOGGER.warning(
                        "Failed to create a DNS resolver channel with automatic monitoring of "
                        "resolver configuration changes, this usually means the system ran "
                        "out of inotify watches. Falling back to socket state callback. "
                        "Consider increasing the system inotify watch limit: %s",
                        e,
                    )
                else:
                    _LOGGER.warning(
                        "Failed to create a DNS resolver channel with automatic monitoring "
                        "of resolver configuration changes. Falling back to socket state "
                        "callback: %s",
                        e,
                    )
        if sys.platform == "win32" and not isinstance(
            self.loop, asyncio.SelectorEventLoop
        ):
            try:
                import winloop

                if not isinstance(self.loop, winloop.Loop):
                    raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
            except ModuleNotFoundError as ex:
                raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
        return False, pycares.Channel(
            sock_state_cb=self._sock_state_cb, timeout=self._timeout, **kwargs
        )

    @property
    def nameservers(self) -> Sequence[str]:
        """Servers currently configured on the underlying channel."""
        return self._channel.servers

    @nameservers.setter
    def nameservers(self, value: Iterable[Union[str, bytes]]) -> None:
        # Remove type ignore after mypy 1.16.0
        # https://github.com/python/mypy/issues/12892
        self._channel.servers = value  # type: ignore[assignment]

    @staticmethod
    def _callback(fut: asyncio.Future[_T], result: _T, errorno: Optional[int]) -> None:
        """Complete ``fut`` from a pycares callback.

        A cancelled future is left untouched.  If ``errorno`` is not None the
        future fails with ``error.DNSError``; otherwise it is resolved with
        ``result``.
        """
        if fut.cancelled():
            return
        if errorno is not None:
            fut.set_exception(error.DNSError(errorno, pycares.errno.strerror(errorno)))
        else:
            fut.set_result(result)

    def _get_future_callback(self) -> Tuple["asyncio.Future[_T]", Callable[[_T, int], None]]:
        """Return a future and a callback to set the result of the future.

        In event-thread mode pycares invokes the callback off the loop thread,
        so the completion is scheduled with ``call_soon_threadsafe``.
        """
        cb: Callable[[_T, int], None]
        future: "asyncio.Future[_T]" = self.loop.create_future()
        if self._event_thread:
            cb = functools.partial(  # type: ignore[assignment]
                self.loop.call_soon_threadsafe,
                self._callback,  # type: ignore[arg-type]
                future
            )
        else:
            cb = functools.partial(self._callback, future)
        return future, cb

    @overload
    def query(self, host: str, qtype: Literal["A"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_a_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["AAAA"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_aaaa_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["CAA"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_caa_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["CNAME"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_cname_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["MX"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_mx_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["NAPTR"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_naptr_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["NS"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_ns_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["PTR"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_ptr_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["SOA"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_soa_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["SRV"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_srv_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["TXT"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_txt_result]]:
        ...

    def query(self, host: str, qtype: str, qclass: Optional[str] = None) -> asyncio.Future[list[Any]]:
        """Start a DNS query for ``host``.

        :param qtype: record type name; must be a key of ``query_type_map``.
        :param qclass: optional class name; must be a key of
            ``query_class_map``.  When None, ``None`` is passed through to
            pycares as ``query_class``.
        :raises ValueError: for an unknown ``qtype`` or ``qclass``.
        """
        try:
            qtype_code = query_type_map[qtype]
        except KeyError:
            raise ValueError('invalid query type: {}'.format(qtype)) from None
        qclass_code: Optional[int] = None
        if qclass is not None:
            try:
                qclass_code = query_class_map[qclass]
            except KeyError:
                raise ValueError('invalid query class: {}'.format(qclass)) from None

        fut: asyncio.Future[list[Any]]
        fut, cb = self._get_future_callback()
        self._channel.query(host, qtype_code, cb, query_class=qclass_code)
        return fut

    def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Future[pycares.ares_host_result]:
        """Resolve ``host`` for ``family`` via pycares ``gethostbyname``."""
        fut: asyncio.Future[pycares.ares_host_result]
        fut, cb = self._get_future_callback()
        self._channel.gethostbyname(host, family, cb)
        return fut

    def getaddrinfo(self, host: str, family: socket.AddressFamily = socket.AF_UNSPEC, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future[pycares.ares_addrinfo_result]:
        """Resolve ``host`` via pycares ``getaddrinfo``, forwarding all options."""
        fut: asyncio.Future[pycares.ares_addrinfo_result]
        fut, cb = self._get_future_callback()
        self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags)
        return fut

    def getnameinfo(self, sockaddr: Union[tuple[str, int], tuple[str, int, int, int]], flags: int = 0) -> asyncio.Future[pycares.ares_nameinfo_result]:
        """Reverse-resolve ``sockaddr`` via pycares ``getnameinfo``."""
        fut: asyncio.Future[pycares.ares_nameinfo_result]
        fut, cb = self._get_future_callback()
        self._channel.getnameinfo(sockaddr, flags, cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.ares_host_result]:
        """Reverse-resolve the address ``name`` via pycares ``gethostbyaddr``."""
        fut: asyncio.Future[pycares.ares_host_result]
        fut, cb = self._get_future_callback()
        self._channel.gethostbyaddr(name, cb)
        return fut

    def cancel(self) -> None:
        """Delegate to the channel's ``cancel()``."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Sync the loop's registrations for ``fd`` with the state c-ares reports.

        ``readable``/``writable`` describe the complete set of events c-ares
        now wants for ``fd``.  A direction that turns off must therefore be
        unregistered.  Otherwise, for example, a still-open socket that no
        longer needs write events would keep an ``add_writer`` registration
        that fires continuously.
        """
        if readable:
            self.loop.add_reader(fd, self._handle_event, fd, READ)
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(fd, self._handle_event, fd, WRITE)
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        # The periodic timer runs exactly while some socket is registered.
        if self._read_fds or self._write_fds:
            if self._timer is None:
                self._start_timer()
        elif self._timer is not None:
            self._timer.cancel()
            self._timer = None

    def _handle_event(self, fd: int, event: int) -> None:
        """Forward a readiness event on ``fd`` from the loop to c-ares.

        :param event: ``READ`` or ``WRITE``.  The other direction is passed
            to ``process_fd`` as ``pycares.ARES_SOCKET_BAD``.
        """
        read_fd = pycares.ARES_SOCKET_BAD
        write_fd = pycares.ARES_SOCKET_BAD
        if event == READ:
            read_fd = fd
        elif event == WRITE:
            write_fd = fd
        self._channel.process_fd(read_fd, write_fd)

    def _timer_cb(self) -> None:
        """Periodic tick used in socket-state mode.

        While any socket is registered, call ``process_fd`` with no ready
        sockets and re-arm the timer.  Per the c-ares ``ares_process_fd``
        docs, this lets c-ares handle query timeouts.  Once no sockets
        remain, clear ``self._timer``.
        """
        if self._read_fds or self._write_fds:
            self._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD)
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule ``_timer_cb`` on ``self.loop``.

        The delay is derived from ``self._timeout``:

        * None, negative, or greater than 1: use 1 second.
        * Exactly 0: use 0.1 seconds.
        * Otherwise: use the timeout value itself.
        """
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)