Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.10/site-packages/aiodns/__init__.py: 36%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

171 statements  

1import asyncio 

2import functools 

3import logging 

4import pycares 

5import socket 

6import sys 

7from collections.abc import Iterable, Sequence 

8from typing import Any, Literal, Optional, TypeVar, Union, overload 

9 

10import pycares 

11from typing import ( 

12 Any, 

13 Callable, 

14 Optional, 

15 Set, 

16 Sequence, 

17 Tuple, 

18 Union 

19) 

20 

21from . import error 

22 

23 

__version__ = '3.4.0'

__all__ = ('DNSResolver', 'error')

_T = TypeVar("_T")

# Message for the RuntimeError raised on Windows when the loop is neither a
# SelectorEventLoop nor a winloop loop (socket-state fallback mode only).
WINDOWS_SELECTOR_ERR_MSG = (
    "aiodns needs a SelectorEventLoop on Windows. See more: "
    "https://github.com/aio-libs/aiodns#note-for-windows-users"
)

_LOGGER = logging.getLogger(__name__)

# Event tags handed from the loop's reader/writer callbacks to
# DNSResolver._handle_event.
READ, WRITE = 1, 2

# Record-type names accepted by DNSResolver.query() -> pycares constants.
query_type_map = dict(
    A=pycares.QUERY_TYPE_A,
    AAAA=pycares.QUERY_TYPE_AAAA,
    ANY=pycares.QUERY_TYPE_ANY,
    CAA=pycares.QUERY_TYPE_CAA,
    CNAME=pycares.QUERY_TYPE_CNAME,
    MX=pycares.QUERY_TYPE_MX,
    NAPTR=pycares.QUERY_TYPE_NAPTR,
    NS=pycares.QUERY_TYPE_NS,
    PTR=pycares.QUERY_TYPE_PTR,
    SOA=pycares.QUERY_TYPE_SOA,
    SRV=pycares.QUERY_TYPE_SRV,
    TXT=pycares.QUERY_TYPE_TXT,
)

# Class names accepted by DNSResolver.query(qclass=...) -> pycares constants.
query_class_map = dict(
    IN=pycares.QUERY_CLASS_IN,
    CHAOS=pycares.QUERY_CLASS_CHAOS,
    HS=pycares.QUERY_CLASS_HS,
    NONE=pycares.QUERY_CLASS_NONE,
    ANY=pycares.QUERY_CLASS_ANY,
)

60 

class DNSResolver:
    """Asynchronous DNS resolver backed by a pycares (c-ares) channel.

    Every lookup method returns an ``asyncio.Future`` created on ``self.loop``.
    When ``pycares.ares_threadsafety()`` is available and true, the channel is
    created with its own event thread and results are handed back to the loop
    with ``call_soon_threadsafe``. Otherwise the channel reports its sockets
    through ``_sock_state_cb``; they are watched with the loop's
    reader/writer API and driven by ``_handle_event`` plus a periodic timer.
    """

    def __init__(self, nameservers: Optional[Sequence[str]] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 **kwargs: Any) -> None:  # TODO(PY311): Use Unpack for kwargs.
        """Create the resolver and its underlying channel.

        :param nameservers: servers to use instead of the channel defaults;
            only applied when non-empty.
        :param loop: event loop for futures (and, in fallback mode, socket
            watchers); defaults to ``asyncio.get_event_loop()``.
        :param kwargs: forwarded to ``pycares.Channel``. A caller-supplied
            ``sock_state_cb`` is discarded (the resolver installs its own);
            ``timeout`` is stored on the instance, passed to the channel and
            also used to pick the polling interval in ``_start_timer``.
        """
        self.loop = loop or asyncio.get_event_loop()
        assert self.loop is not None
        # Socket/timer bookkeeping used by _sock_state_cb is created before
        # the channel so the callback never sees a half-initialised instance.
        self._read_fds: set[int] = set()
        self._write_fds: set[int] = set()
        self._timer: Optional[asyncio.TimerHandle] = None
        kwargs.pop('sock_state_cb', None)
        timeout = kwargs.pop('timeout', None)
        self._timeout = timeout
        self._event_thread, self._channel = self._make_channel(**kwargs)
        if nameservers:
            self.nameservers = nameservers

    def _make_channel(self, **kwargs: Any) -> Tuple[bool, pycares.Channel]:
        """Create the pycares channel, preferring c-ares' own event thread.

        Returns ``(event_thread, channel)``. ``event_thread`` is True only when
        a channel with ``event_thread=True`` was created successfully, which
        is attempted only if pycares reports thread safety. On ``AresError``
        from that attempt a warning is logged and the socket-state-callback
        channel is used instead.

        :raises RuntimeError: in fallback mode on Windows when ``self.loop``
            is not a ``SelectorEventLoop`` and is not a ``winloop.Loop`` (or
            winloop is not installed).
        """
        if hasattr(pycares, "ares_threadsafety") and pycares.ares_threadsafety():
            # pycares is thread safe
            try:
                return True, pycares.Channel(
                    event_thread=True, timeout=self._timeout, **kwargs
                )
            except pycares.AresError as e:
                if sys.platform == "linux":
                    _LOGGER.warning(
                        "Failed to create a DNS resolver channel with automatic monitoring of "
                        "resolver configuration changes, this usually means the system ran "
                        "out of inotify watches. Falling back to socket state callback. "
                        "Consider increasing the system inotify watch limit: %s",
                        e,
                    )
                else:
                    _LOGGER.warning(
                        "Failed to create a DNS resolver channel with automatic monitoring "
                        "of resolver configuration changes. Falling back to socket state "
                        "callback: %s",
                        e,
                    )
        # Fallback mode needs a loop that supports add_reader/add_writer.
        if sys.platform == "win32" and not isinstance(
            self.loop, asyncio.SelectorEventLoop
        ):
            try:
                import winloop

                if not isinstance(self.loop, winloop.Loop):
                    raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG)
            except ModuleNotFoundError as ex:
                raise RuntimeError(WINDOWS_SELECTOR_ERR_MSG) from ex
        return False, pycares.Channel(
            sock_state_cb=self._sock_state_cb, timeout=self._timeout, **kwargs
        )

    @property
    def nameservers(self) -> Sequence[str]:
        """Servers currently configured on the channel."""
        return self._channel.servers

    @nameservers.setter
    def nameservers(self, value: Iterable[Union[str, bytes]]) -> None:
        # Remove type ignore after mypy 1.16.0
        # https://github.com/python/mypy/issues/12892
        self._channel.servers = value  # type: ignore[assignment]

    @staticmethod
    def _callback(fut: asyncio.Future[_T], result: _T, errorno: Optional[int]) -> None:
        """Complete ``fut`` from a pycares result callback.

        A cancelled future is left alone. A non-None ``errorno`` becomes an
        ``error.DNSError(errorno, strerror(errorno))``; otherwise ``result``
        is set as the future's result.
        """
        if fut.cancelled():
            return
        if errorno is not None:
            fut.set_exception(error.DNSError(errorno, pycares.errno.strerror(errorno)))
        else:
            fut.set_result(result)

    def _get_future_callback(self) -> Tuple["asyncio.Future[_T]", Callable[[_T, int], None]]:
        """Return a future and a callback to set the result of the future.

        With an event-thread channel the callback is invoked off the loop
        thread, so it schedules ``_callback`` via ``call_soon_threadsafe``;
        otherwise it calls ``_callback`` directly.
        """
        cb: Callable[[_T, int], None]
        future: "asyncio.Future[_T]" = self.loop.create_future()
        if self._event_thread:
            cb = functools.partial(  # type: ignore[assignment]
                self.loop.call_soon_threadsafe,
                self._callback,  # type: ignore[arg-type]
                future
            )
        else:
            cb = functools.partial(self._callback, future)
        return future, cb

    @overload
    def query(self, host: str, qtype: Literal["A"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_a_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["AAAA"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_aaaa_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["ANY"], qclass: Optional[str] = ...) -> asyncio.Future[list[Any]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["CAA"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_caa_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["CNAME"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_cname_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["MX"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_mx_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["NAPTR"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_naptr_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["NS"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_ns_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["PTR"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_ptr_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["SOA"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_soa_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["SRV"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_srv_result]]:
        ...
    @overload
    def query(self, host: str, qtype: Literal["TXT"], qclass: Optional[str] = ...) -> asyncio.Future[list[pycares.ares_query_txt_result]]:
        ...

    def query(self, host: str, qtype: str, qclass: Optional[str] = None) -> asyncio.Future[list[Any]]:
        """Start a DNS query for ``host``.

        :param qtype: record type name, a key of ``query_type_map``.
        :param qclass: optional class name, a key of ``query_class_map``;
            ``None`` passes ``query_class=None`` to the channel.
        :raises ValueError: for an unknown ``qtype`` or ``qclass``.
        :returns: a future resolved with the pycares results, or failing with
            ``error.DNSError``.
        """
        try:
            qtype_value = query_type_map[qtype]
        except KeyError:
            raise ValueError('invalid query type: {}'.format(qtype)) from None
        qclass_value: Optional[int] = None
        if qclass is not None:
            try:
                qclass_value = query_class_map[qclass]
            except KeyError:
                raise ValueError('invalid query class: {}'.format(qclass)) from None

        fut: asyncio.Future[list[Any]]
        fut, cb = self._get_future_callback()
        self._channel.query(host, qtype_value, cb, query_class=qclass_value)
        return fut

    def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Future[pycares.ares_host_result]:
        """Resolve ``host`` for ``family`` via ``Channel.gethostbyname``."""
        fut: asyncio.Future[pycares.ares_host_result]
        fut, cb = self._get_future_callback()
        self._channel.gethostbyname(host, family, cb)
        return fut

    def getaddrinfo(self, host: str, family: socket.AddressFamily = socket.AF_UNSPEC, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future[pycares.ares_addrinfo_result]:
        """Resolve ``host``/``port`` via ``Channel.getaddrinfo``."""
        fut: asyncio.Future[pycares.ares_addrinfo_result]
        fut, cb = self._get_future_callback()
        self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags)
        return fut

    def getnameinfo(self, sockaddr: Union[tuple[str, int], tuple[str, int, int, int]], flags: int = 0) -> asyncio.Future[pycares.ares_nameinfo_result]:
        """Reverse-resolve ``sockaddr`` via ``Channel.getnameinfo``."""
        fut: asyncio.Future[pycares.ares_nameinfo_result]
        fut, cb = self._get_future_callback()
        self._channel.getnameinfo(sockaddr, flags, cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.ares_host_result]:
        """Reverse-resolve the address ``name`` via ``Channel.gethostbyaddr``."""
        fut: asyncio.Future[pycares.ares_host_result]
        fut, cb = self._get_future_callback()
        self._channel.gethostbyaddr(name, cb)
        return fut

    def cancel(self) -> None:
        """Cancel outstanding lookups via ``pycares.Channel.cancel``."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Sync the loop's watchers for ``fd`` with the state c-ares reports.

        Each flag is applied independently. True (re)registers the matching
        reader/writer; False removes one that was previously registered.
        Removing on False matters for partial transitions such as
        (readable, writable) -> (readable only). A connected socket is
        normally writable, so a stale writer would run ``_handle_event`` on
        every loop iteration.

        (False, False) means c-ares no longer watches the socket. Once no fds
        remain, the polling timer is cancelled.
        """
        if readable:
            self.loop.add_reader(fd, self._handle_event, fd, READ)
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(fd, self._handle_event, fd, WRITE)
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            if self._timer is None:
                self._start_timer()
        elif not self._read_fds and not self._write_fds and self._timer is not None:
            self._timer.cancel()
            self._timer = None

    def _handle_event(self, fd: int, event: int) -> None:
        """Forward a loop readiness event on ``fd`` to ``Channel.process_fd``."""
        read_fd = pycares.ARES_SOCKET_BAD
        write_fd = pycares.ARES_SOCKET_BAD
        if event == READ:
            read_fd = fd
        elif event == WRITE:
            write_fd = fd
        self._channel.process_fd(read_fd, write_fd)

    def _timer_cb(self) -> None:
        """Periodic tick while any socket is watched.

        Calls ``process_fd`` with no ready socket (``ARES_SOCKET_BAD`` for
        both), presumably so c-ares can run its timeout handling, then
        reschedules. When no fds are watched, the timer is dropped.
        """
        if self._read_fds or self._write_fds:
            self._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD)
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule ``_timer_cb`` after the configured timeout, kept in (0, 1] s.

        ``None``, negative, or > 1 becomes 1 second; 0 becomes 0.1 seconds.
        """
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)