1import asyncio
2import socket
3import weakref
4from typing import Any, Dict, Final, List, Optional, Tuple, Type, Union
5
6from .abc import AbstractResolver, ResolveResult
7
# Public resolver API exported by this module.
__all__ = (
    "ThreadedResolver",
    "AsyncResolver",
    "DefaultResolver",
)
9
10
# aiodns is optional: AsyncResolver requires it, and it is only picked as the
# default resolver when its DNSResolver exposes getaddrinfo().
try:
    import aiodns
except ImportError:  # pragma: no cover
    aiodns = None  # type: ignore[assignment]
    aiodns_default = False
else:
    aiodns_default = hasattr(aiodns.DNSResolver, "getaddrinfo")


# Flags stored on every ResolveResult: host and port are already numeric.
_NUMERIC_SOCKET_FLAGS = socket.AI_NUMERICHOST | socket.AI_NUMERICSERV
# getnameinfo() flags requesting numeric host and service strings.
_NAME_SOCKET_FLAGS = socket.NI_NUMERICHOST | socket.NI_NUMERICSERV
# AI_ADDRCONFIG, masked with AI_MASK on platforms that define it.
_AI_ADDRCONFIG = (
    socket.AI_ADDRCONFIG & socket.AI_MASK
    if hasattr(socket, "AI_MASK")
    else socket.AI_ADDRCONFIG
)
25
26
class ThreadedResolver(AbstractResolver):
    """Resolver built on the event loop's own ``getaddrinfo()``.

    The loop offloads the blocking ``getaddrinfo()`` call to an executor;
    by default that is a ``concurrent.futures.ThreadPoolExecutor``.
    """

    def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
        """Bind to *loop*, or to the currently running loop when omitted."""
        self._loop = loop if loop else asyncio.get_running_loop()

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Look up stream-socket (``SOCK_STREAM``) addresses for *host*.

        Args:
            host: Name to resolve.
            port: Port passed through to ``getaddrinfo()``.
            family: Address family to request.

        Returns:
            One ``ResolveResult`` per usable address, flagged as numeric.
        """
        addrinfos = await self._loop.getaddrinfo(
            host, port, type=socket.SOCK_STREAM, family=family, flags=_AI_ADDRCONFIG
        )

        results: List[ResolveResult] = []
        for info_family, _socktype, info_proto, _canonname, sockaddr in addrinfos:
            if info_family != socket.AF_INET6:
                # Anything that is not IPv6 is expected to be IPv4 here.
                assert info_family == socket.AF_INET
                numeric_host, numeric_port = sockaddr  # type: ignore[misc]
            elif len(sockaddr) < 3:
                # IPv6 is not supported by the Python build, or not enabled
                # on the host: the sockaddr is not a full IPv6 tuple.
                continue
            elif sockaddr[3]:
                # Non-zero scope id: essential for link-local IPv6 addresses.
                # This is a rare case, so getnameinfo() is only paid for here
                # rather than unconditionally.
                numeric_host, service = await self._loop.getnameinfo(
                    sockaddr, _NAME_SOCKET_FLAGS
                )
                numeric_port = int(service)
            else:
                numeric_host, numeric_port = sockaddr[:2]

            results.append(
                ResolveResult(
                    hostname=host,
                    host=numeric_host,
                    port=numeric_port,
                    family=info_family,
                    proto=info_proto,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        return results

    async def close(self) -> None:
        """Release nothing: this resolver holds no resources of its own."""
83
84
class AsyncResolver(AbstractResolver):
    """Use the `aiodns` package to make asynchronous DNS lookups.

    Instances constructed without extra arguments borrow a per-event-loop
    ``aiodns.DNSResolver`` shared through ``_DNSResolverManager``. Any extra
    ``*args``/``**kwargs`` are forwarded to a dedicated ``aiodns.DNSResolver``
    owned (and cancelled on close) by this instance.
    """

    def __init__(
        self,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create the resolver.

        Args:
            loop: Event loop to resolve on; defaults to the running loop.
            *args: Positional arguments for a dedicated ``aiodns.DNSResolver``.
            **kwargs: Keyword arguments for a dedicated ``aiodns.DNSResolver``.

        Raises:
            RuntimeError: If the ``aiodns`` package is not installed.
        """
        if aiodns is None:
            raise RuntimeError("Resolver requires aiodns library")

        self._loop = loop or asyncio.get_running_loop()
        self._manager: Optional[_DNSResolverManager] = None
        if args or kwargs:
            # Custom arguments: each such AsyncResolver gets its own
            # aiodns.DNSResolver instance, cancelled in close().
            self._resolver = aiodns.DNSResolver(*args, **kwargs)
        else:
            # Default arguments: use the shared resolver for this loop.
            self._manager = _DNSResolverManager()
            self._resolver = self._manager.get_resolver(self, self._loop)

        # Applied to dedicated and shared resolvers alike.
        # NOTE(review): resolve() calls getaddrinfo(), which this check does
        # not verify; confirm the minimum supported aiodns version.
        if not hasattr(self._resolver, "gethostbyname"):
            # aiodns 1.1 is not available, fallback to DNSResolver.query
            self.resolve = self._resolve_with_query  # type: ignore

    @staticmethod
    def _lookup_error(exc: Exception) -> OSError:
        """Translate an aiodns error into the ``OSError`` raised to callers.

        The human-readable message is expected at ``exc.args[1]``; when the
        error carries fewer than two args a generic message is used instead
        of failing with ``IndexError``.
        """
        msg = exc.args[1] if len(exc.args) >= 2 else "DNS lookup failed"
        return OSError(None, msg)

    async def resolve(
        self, host: str, port: int = 0, family: socket.AddressFamily = socket.AF_INET
    ) -> List[ResolveResult]:
        """Resolve *host* to stream-socket addresses using aiodns.

        Args:
            host: Name to resolve.
            port: Port passed to ``getaddrinfo()``.
            family: Address family to request.

        Returns:
            One ``ResolveResult`` per returned node.

        Raises:
            OSError: If the lookup fails or returns no addresses.
        """
        try:
            resp = await self._resolver.getaddrinfo(
                host,
                port=port,
                type=socket.SOCK_STREAM,
                family=family,
                flags=_AI_ADDRCONFIG,
            )
        except aiodns.error.DNSError as exc:
            raise self._lookup_error(exc) from exc

        hosts: List[ResolveResult] = []
        for node in resp.nodes:
            address: Union[Tuple[bytes, int], Tuple[bytes, int, int, int]] = node.addr
            node_family = node.family
            # Taken from the node itself on every path (including link-local
            # IPv6) instead of reusing a value from a previous iteration.
            node_port = address[1]
            if node_family == socket.AF_INET6:
                if len(address) > 3 and address[3]:
                    # This is essential for link-local IPv6 addresses.
                    # LL IPv6 is a VERY rare case. Strictly speaking, we should
                    # use getnameinfo() unconditionally, but performance matters.
                    result = await self._resolver.getnameinfo(
                        (address[0].decode("ascii"), *address[1:]),
                        _NAME_SOCKET_FLAGS,
                    )
                    resolved_host = result.node
                else:
                    resolved_host = address[0].decode("ascii")
            else:  # IPv4
                assert node_family == socket.AF_INET
                resolved_host = address[0].decode("ascii")
            hosts.append(
                ResolveResult(
                    hostname=host,
                    host=resolved_host,
                    port=node_port,
                    family=node_family,
                    proto=0,
                    flags=_NUMERIC_SOCKET_FLAGS,
                )
            )

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def _resolve_with_query(
        self, host: str, port: int = 0, family: int = socket.AF_INET
    ) -> List[Dict[str, Any]]:
        """Fallback lookup via ``DNSResolver.query`` (A or AAAA records).

        Raises:
            OSError: If the query fails or returns no records.
        """
        qtype: Final = "AAAA" if family == socket.AF_INET6 else "A"

        try:
            resp = await self._resolver.query(host, qtype)
        except aiodns.error.DNSError as exc:
            raise self._lookup_error(exc) from exc

        hosts = [
            {
                "hostname": host,
                "host": rr.host,
                "port": port,
                "family": family,
                "proto": 0,
                "flags": socket.AI_NUMERICHOST,
            }
            for rr in resp
        ]

        if not hosts:
            raise OSError(None, "DNS lookup failed")

        return hosts

    async def close(self) -> None:
        """Release the underlying aiodns resolver.

        A shared resolver is returned to the manager (which cancels it once
        no clients remain); a dedicated resolver is cancelled directly.
        """
        if self._manager:
            # Release the resolver from the manager if using the shared resolver
            self._manager.release_resolver(self, self._loop)
            self._manager = None  # Clear reference to manager
            self._resolver = None  # type: ignore[assignment] # Clear reference to resolver
            return
        # Otherwise cancel our dedicated resolver
        if self._resolver is not None:
            self._resolver.cancel()
            self._resolver = None  # type: ignore[assignment] # Clear reference
204
205
class _DNSResolverManager:
    """Process-wide registry of shared ``aiodns.DNSResolver`` objects.

    One resolver, created with only ``loop=``, is kept per event loop and
    handed to every ``AsyncResolver`` built without custom arguments. The
    resolver is cancelled when the last registered client releases it.
    """

    # The single instance returned by every constructor call.
    _instance: Optional["_DNSResolverManager"] = None

    def __new__(cls) -> "_DNSResolverManager":
        # Build the singleton lazily on first use; afterwards reuse it.
        if cls._instance is not None:
            return cls._instance
        cls._instance = super().__new__(cls)
        cls._instance._init()
        return cls._instance

    def _init(self) -> None:
        # Maps loop -> (shared resolver, WeakSet of AsyncResolver clients).
        # Keys are held weakly so loops can in principle be garbage collected.
        # NOTE(review): each resolver is created with ``loop=loop``; if it keeps
        # a strong reference to that loop, the weak key stays alive until
        # release_resolver() deletes the entry — confirm.
        self._loop_data: weakref.WeakKeyDictionary[
            asyncio.AbstractEventLoop,
            tuple["aiodns.DNSResolver", weakref.WeakSet["AsyncResolver"]],
        ] = weakref.WeakKeyDictionary()

    def get_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> "aiodns.DNSResolver":
        """Return the shared resolver for *loop*, creating it on first request.

        Args:
            client: The AsyncResolver asking for the resolver; it is recorded
                so the resolver can be cancelled once every client is gone.
            loop: The event loop the resolver is bound to.
        """
        entry = self._loop_data.get(loop)
        if entry is None:
            entry = (aiodns.DNSResolver(loop=loop), weakref.WeakSet())
            self._loop_data[loop] = entry
        shared_resolver, clients = entry
        clients.add(client)
        return shared_resolver

    def release_resolver(
        self, client: "AsyncResolver", loop: asyncio.AbstractEventLoop
    ) -> None:
        """Forget *client*; cancel and drop the loop's resolver if it was the last.

        Args:
            client: The AsyncResolver being closed.
            loop: The event loop whose resolver it was using.
        """
        entry = self._loop_data.get(loop)
        if entry is None:
            # Nothing registered for this loop.
            return
        shared_resolver, clients = entry
        clients.discard(client)
        if clients:
            # Other clients still use this resolver.
            return
        if shared_resolver is not None:
            shared_resolver.cancel()
        del self._loop_data[loop]
271
272
# Type of the resolver class exported as DefaultResolver.
_DefaultType = Type[Union[AsyncResolver, ThreadedResolver]]

# Prefer the aiodns-backed resolver only when aiodns provides getaddrinfo().
DefaultResolver: _DefaultType
if aiodns_default:
    DefaultResolver = AsyncResolver
else:
    DefaultResolver = ThreadedResolver