Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/aiodns/__init__.py: 27%

107 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-09 06:47 +0000

1 

2import asyncio 

3import functools 

4import pycares 

5import socket 

6import sys 

7 

8from typing import ( 

9 Any, 

10 Optional, 

11 Set, 

12 Sequence 

13) 

14 

15from . import error 

16 

17 

# Release identifier of this package.
__version__ = '3.1.1'

# Names exported by ``from aiodns import *``.
__all__ = ('DNSResolver', 'error')

21 

22 

# Event tags that DNSResolver._sock_state_cb hands to the event loop so that
# DNSResolver._handle_event can tell which direction became ready on an fd.
READ = 1
WRITE = 2

25 

# Record-type names accepted by DNSResolver.query, mapped to the pycares
# constant that is passed on to the channel.
query_type_map = {
    'A': pycares.QUERY_TYPE_A,
    'AAAA': pycares.QUERY_TYPE_AAAA,
    'ANY': pycares.QUERY_TYPE_ANY,
    'CAA': pycares.QUERY_TYPE_CAA,
    'CNAME': pycares.QUERY_TYPE_CNAME,
    'MX': pycares.QUERY_TYPE_MX,
    'NAPTR': pycares.QUERY_TYPE_NAPTR,
    'NS': pycares.QUERY_TYPE_NS,
    'PTR': pycares.QUERY_TYPE_PTR,
    'SOA': pycares.QUERY_TYPE_SOA,
    'SRV': pycares.QUERY_TYPE_SRV,
    'TXT': pycares.QUERY_TYPE_TXT,
}

39 

# Query-class names accepted by DNSResolver.query, mapped to the pycares
# constant that is passed on to the channel.
query_class_map = {
    'IN': pycares.QUERY_CLASS_IN,
    'CHAOS': pycares.QUERY_CLASS_CHAOS,
    'HS': pycares.QUERY_CLASS_HS,
    'NONE': pycares.QUERY_CLASS_NONE,
    'ANY': pycares.QUERY_CLASS_ANY,
}

46 

class DNSResolver:
    """Resolve DNS names through a ``pycares.Channel`` driven by an asyncio loop.

    The resolver registers the channel's sockets with the event loop as
    readers/writers, forwards readiness events to the channel, and keeps a
    periodic timer running while any socket is being watched.  Every lookup
    method returns a future that is completed from the channel's callback.
    """

    def __init__(self, nameservers: Optional[Sequence[str]] = None,
                 loop: Optional[asyncio.AbstractEventLoop] = None,
                 **kwargs: Any) -> None:
        """Create the resolver and its underlying channel.

        Args:
            nameservers: Servers to assign to the channel; when empty or
                ``None`` the channel's own defaults are left untouched.
            loop: Event loop to use; defaults to ``asyncio.get_event_loop()``.
            **kwargs: Forwarded to ``pycares.Channel``.  A caller-supplied
                ``sock_state_cb`` is discarded because the resolver installs
                its own; ``timeout`` is also remembered for the poll timer.

        Raises:
            RuntimeError: On Windows when the loop is not a
                ``SelectorEventLoop``.
        """
        self.loop = loop or asyncio.get_event_loop()
        assert self.loop is not None
        if sys.platform == 'win32':
            if not isinstance(self.loop, asyncio.SelectorEventLoop):
                raise RuntimeError(
                    'aiodns needs a SelectorEventLoop on Windows. '
                    'See more: https://github.com/saghul/aiodns/issues/86')

        # State consulted by _sock_state_cb.  It is created before the channel
        # (which receives that callback) so the callback can never run against
        # a resolver that lacks these attributes.
        self._read_fds = set()  # type: Set[int]
        self._write_fds = set()  # type: Set[int]
        self._timer = None  # type: Optional[asyncio.TimerHandle]

        # The resolver always supplies its own socket-state callback.
        kwargs.pop('sock_state_cb', None)
        timeout = kwargs.pop('timeout', None)
        self._timeout = timeout
        self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb,
                                        timeout=timeout,
                                        **kwargs)
        if nameservers:
            self.nameservers = nameservers

    @property
    def nameservers(self) -> Sequence[str]:
        """Servers currently configured on the underlying channel."""
        return self._channel.servers

    @nameservers.setter
    def nameservers(self, value: Sequence[str]) -> None:
        self._channel.servers = value

    @staticmethod
    def _callback(fut: asyncio.Future, result: Any,
                  errorno: Optional[int]) -> None:
        """Complete *fut* from a channel callback.

        A cancelled future is left alone.  Otherwise a non-``None`` *errorno*
        becomes an ``error.DNSError`` on the future, and ``None`` means
        *result* is delivered as the future's result.
        """
        if fut.cancelled():
            return
        if errorno is not None:
            fut.set_exception(
                error.DNSError(errorno, pycares.errno.strerror(errorno)))
        else:
            fut.set_result(result)

    def query(self, host: str, qtype: str,
              qclass: Optional[str] = None) -> asyncio.Future:
        """Start a DNS query and return a future for its outcome.

        Args:
            host: Name to query.
            qtype: A key of ``query_type_map`` (for example ``'A'``).
            qclass: Optional key of ``query_class_map``; ``None`` is passed
                through to the channel unchanged.

        Returns:
            A future completed by ``_callback``.

        Raises:
            ValueError: If *qtype* or *qclass* is not a known name.
        """
        try:
            type_code = query_type_map[qtype]
        except KeyError:
            # The KeyError adds nothing for the caller; hide it from the chain.
            raise ValueError('invalid query type: {}'.format(qtype)) from None

        class_code = None
        if qclass is not None:
            try:
                class_code = query_class_map[qclass]
            except KeyError:
                raise ValueError(
                    'invalid query class: {}'.format(qclass)) from None

        # create_future() lets the loop supply its own Future implementation.
        fut = self.loop.create_future()  # type: asyncio.Future
        cb = functools.partial(self._callback, fut)
        self._channel.query(host, type_code, cb, query_class=class_code)
        return fut

    def gethostbyname(self, host: str,
                      family: socket.AddressFamily) -> asyncio.Future:
        """Start ``Channel.gethostbyname`` for *host*/*family*; return a future."""
        fut = self.loop.create_future()  # type: asyncio.Future
        cb = functools.partial(self._callback, fut)
        self._channel.gethostbyname(host, family, cb)
        return fut

    def gethostbyaddr(self, name: str) -> asyncio.Future:
        """Start ``Channel.gethostbyaddr`` for *name*; return a future."""
        fut = self.loop.create_future()  # type: asyncio.Future
        cb = functools.partial(self._callback, fut)
        self._channel.gethostbyaddr(name, cb)
        return fut

    def cancel(self) -> None:
        """Ask the underlying channel to cancel its outstanding requests."""
        self._channel.cancel()

    def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None:
        """Synchronise the loop's watchers for *fd* with what the channel wants.

        *readable* / *writable* say which events the channel currently wants
        for *fd*; both false means the socket is closed.  A direction that is
        no longer wanted is unregistered, so a socket that only needed a write
        watcher temporarily does not keep waking the loop afterwards.
        """
        if readable:
            self.loop.add_reader(fd, self._handle_event, fd, READ)
            self._read_fds.add(fd)
        elif fd in self._read_fds:
            self._read_fds.discard(fd)
            self.loop.remove_reader(fd)

        if writable:
            self.loop.add_writer(fd, self._handle_event, fd, WRITE)
            self._write_fds.add(fd)
        elif fd in self._write_fds:
            self._write_fds.discard(fd)
            self.loop.remove_writer(fd)

        if readable or writable:
            # At least one socket is live: make sure the poll timer is running.
            if self._timer is None:
                self._start_timer()
        elif (not self._read_fds and not self._write_fds
              and self._timer is not None):
            # The last watched socket was closed: stop polling.
            self._timer.cancel()
            self._timer = None

    def _handle_event(self, fd: int, event: Any) -> None:
        """Tell the channel that *fd* is ready for *event* (READ or WRITE)."""
        read_fd = pycares.ARES_SOCKET_BAD
        write_fd = pycares.ARES_SOCKET_BAD
        if event == READ:
            read_fd = fd
        elif event == WRITE:
            write_fd = fd
        self._channel.process_fd(read_fd, write_fd)

    def _timer_cb(self) -> None:
        """Periodic tick: poke the channel and re-arm while sockets are watched."""
        if self._read_fds or self._write_fds:
            # No fd is ready here; presumably this lets the channel handle its
            # own timeouts -- confirm against the pycares documentation.
            self._channel.process_fd(pycares.ARES_SOCKET_BAD,
                                     pycares.ARES_SOCKET_BAD)
            self._start_timer()
        else:
            self._timer = None

    def _start_timer(self) -> None:
        """Schedule ``_timer_cb`` on the loop and remember the handle.

        The delay is the configured timeout, except that ``None``, negative
        values and values above one second become 1, and exactly 0 becomes 0.1.
        """
        timeout = self._timeout
        if timeout is None or timeout < 0 or timeout > 1:
            timeout = 1
        elif timeout == 0:
            timeout = 0.1

        self._timer = self.loop.call_later(timeout, self._timer_cb)

163 

164 self._timer = self.loop.call_later(timeout, self._timer_cb)