/src/hickory-dns/crates/proto/src/xfer/mod.rs
Line | Count | Source (jump to first uncovered line) |
1 | | //! DNS high level transit implementations. |
2 | | //! |
3 | | //! Primarily there are two types in this module of interest, the `DnsMultiplexer` type and the `DnsHandle` type. `DnsMultiplexer` can be thought of as the state machine responsible for sending and receiving DNS messages. `DnsHandle` is the type given to API users of the `hickory-proto` library to send messages into the `DnsMultiplexer` for delivery. Finally there is the `DnsRequest` type. This allows for customizations, through `DnsRequestOptions`, to the delivery of messages via a `DnsMultiplexer`. |
4 | | //! |
5 | | //! TODO: this module needs some serious refactoring and normalization. |
6 | | |
7 | | #[cfg(feature = "std")] |
8 | | use core::fmt::Display; |
9 | | use core::fmt::{self, Debug}; |
10 | | use core::future::Future; |
11 | | use core::pin::Pin; |
12 | | use core::task::{Context, Poll}; |
13 | | use core::time::Duration; |
14 | | #[cfg(feature = "std")] |
15 | | use std::net::SocketAddr; |
16 | | |
17 | | #[cfg(feature = "std")] |
18 | | use futures_channel::mpsc; |
19 | | #[cfg(feature = "std")] |
20 | | use futures_channel::oneshot; |
21 | | use futures_util::ready; |
22 | | #[cfg(feature = "std")] |
23 | | use futures_util::stream::{Fuse, Peekable}; |
24 | | use futures_util::stream::{Stream, StreamExt}; |
25 | | #[cfg(feature = "serde")] |
26 | | use serde::{Deserialize, Serialize}; |
27 | | #[cfg(feature = "std")] |
28 | | use tracing::{debug, warn}; |
29 | | |
30 | | use crate::error::{ProtoError, ProtoErrorKind}; |
31 | | #[cfg(feature = "std")] |
32 | | use crate::runtime::Time; |
33 | | |
34 | | #[cfg(feature = "std")] |
35 | | mod dns_exchange; |
36 | | pub mod dns_handle; |
37 | | #[cfg(feature = "std")] |
38 | | pub mod dns_multiplexer; |
39 | | pub mod dns_request; |
40 | | pub mod dns_response; |
41 | | pub mod retry_dns_handle; |
42 | | mod serial_message; |
43 | | |
44 | | #[cfg(feature = "std")] |
45 | | pub use self::dns_exchange::{ |
46 | | Connecting, DnsExchange, DnsExchangeBackground, DnsExchangeConnect, DnsExchangeSend, |
47 | | }; |
48 | | pub use self::dns_handle::{DnsHandle, DnsStreamHandle}; |
49 | | #[cfg(feature = "std")] |
50 | | pub use self::dns_multiplexer::{DnsMultiplexer, DnsMultiplexerConnect}; |
51 | | pub use self::dns_request::{DnsRequest, DnsRequestOptions}; |
52 | | pub use self::dns_response::DnsResponse; |
53 | | #[cfg(feature = "std")] |
54 | | pub use self::dns_response::DnsResponseStream; |
55 | | pub use self::retry_dns_handle::RetryDnsHandle; |
56 | | pub use self::serial_message::SerialMessage; |
57 | | |
/// Consumes the outcome of a non-blocking channel send, logging any failure instead of
/// propagating it.
///
/// A disconnected receiver is expected during teardown and is only logged at debug level;
/// any other failure is logged as a warning because it may leave a waiting future unresolved.
#[cfg(feature = "std")]
fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
    match result {
        Ok(_) => {}
        Err(error) if error.is_disconnected() => {
            debug!("ignoring send error on disconnected stream")
        }
        Err(error) => warn!("error notifying wait, possible future leak: {:?}", error),
    }
}
70 | | |
/// A non-multiplexed stream of serialized DNS messages.
///
/// Implementors yield `SerialMessage`s (or `ProtoError`s) and can report the address of the
/// remote name server they talk to.
#[cfg(feature = "std")]
pub trait DnsClientStream:
    Stream<Item = Result<SerialMessage, ProtoError>> + Display + Send
{
    /// Address of the remote name server on the other end of this stream
    fn name_server_addr(&self) -> SocketAddr;

    /// The `Time` implementation used by this stream
    type Time: Time;
}
82 | | |
/// Capacity argument passed to `mpsc::channel` when creating buffered stream handles
#[cfg(feature = "std")]
const CHANNEL_BUFFER_SIZE: usize = 32;

/// Receiving side of a `SerialMessage` channel, fused and wrapped so the next message can be
/// peeked at without consuming it
#[cfg(feature = "std")]
pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;
89 | | |
/// A buffering stream handle bound to a single `SocketAddr`.
///
/// Every message sent through this handle gets `remote_addr` as its destination before it is
/// queued on the underlying `mpsc` channel.
#[derive(Clone)]
#[cfg(feature = "std")]
pub struct BufDnsStreamHandle {
    /// Destination applied to every outgoing message
    remote_addr: SocketAddr,
    /// Queue into which outgoing messages are pushed
    sender: mpsc::Sender<SerialMessage>,
}
99 | | |
#[cfg(feature = "std")]
impl BufDnsStreamHandle {
    /// Constructs a new buffered stream handle, used for sending data to the DNS peer.
    ///
    /// # Arguments
    ///
    /// * `remote_addr` - the address of the remote DNS system (client or server); every message
    ///   sent through the returned handle is addressed to it
    ///
    /// # Returns
    ///
    /// The handle together with the [`StreamReceiver`] from which the queued messages can be
    /// read. The channel connecting them is created with `mpsc::channel(CHANNEL_BUFFER_SIZE)`.
    pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let receiver = receiver.fuse().peekable();

        let this = Self {
            remote_addr,
            sender,
        };

        (this, receiver)
    }

    /// Returns a handle that shares this handle's channel but addresses messages to
    /// `remote_addr` instead.
    ///
    /// This is mainly useful in server use cases where the incoming address is only known after
    /// receiving a packet.
    pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr,
            sender: self.sender.clone(),
        }
    }
}
130 | | |
#[cfg(feature = "std")]
impl DnsStreamHandle for BufDnsStreamHandle {
    /// Re-addresses `buffer` to this handle's `remote_addr` and queues it without blocking.
    ///
    /// Returns a `ProtoError` describing the failure if the message could not be queued
    /// (e.g. the channel is full or disconnected).
    fn send(&mut self, buffer: SerialMessage) -> Result<(), ProtoError> {
        let bytes = buffer.into_parts().0;
        let readdressed = SerialMessage::new(bytes, self.remote_addr);
        self.sender
            .try_send(readdressed)
            .map_err(|err| ProtoError::from(format!("mpsc::SendError {err}")))
    }
}
140 | | |
/// Types that implement this are capable of sending a serialized DNS message on a stream
///
/// The underlying `Stream` implementation should yield `Some(Ok(()))` whenever it is ready to
/// send a message, return `Poll::Pending` if it is not ready to send a message, and yield
/// `Some(Err(_))` or `None` in the case that the stream is done and should be shut down.
#[cfg(feature = "std")]
pub trait DnsRequestSender: Stream<Item = Result<(), ProtoError>> + Send + Unpin + 'static {
    /// Send a message, and return a stream of responses
    ///
    /// # Return
    ///
    /// A stream which will resolve to `Result<DnsResponse, ProtoError>` items for this request
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;

    /// Allows the upstream user to inform the underlying stream that it should shut down.
    ///
    /// After this is called, the next time `poll_next` is called on the stream it would be
    /// correct for the stream to end (`Poll::Ready(None)`). This is not required though: if
    /// there are outstanding requests that are not yet complete, it would be correct to first
    /// wait for those results.
    fn shutdown(&mut self);

    /// Returns true if the stream has been shut down with `shutdown`
    fn is_shutdown(&self) -> bool;
}
163 | | |
/// A cloneable request handle that submits each request by queueing an `OneshotDnsRequest`
/// on an `mpsc` channel
#[derive(Clone)]
#[cfg(feature = "std")]
pub struct BufDnsRequestStreamHandle {
    /// Sending half of the request queue
    sender: mpsc::Sender<OneshotDnsRequest>,
}
170 | | |
// Unwraps a `Result`, or returns early from the enclosing function with
// `DnsResponseReceiver::Err(Some(..))` holding the error converted into a `ProtoError`.
//
// The enclosing function must return `DnsResponseReceiver`. The optional trailing comma is
// accepted by the same arm, so `try_oneshot!(x)` and `try_oneshot!(x,)` behave identically.
// Previously the comma form expanded to `$expr?`, which bypassed the `DnsResponseReceiver::Err`
// conversion entirely.
#[cfg(feature = "std")]
macro_rules! try_oneshot {
    ($expr:expr $(,)?) => {{
        use core::result::Result;

        match $expr {
            Result::Ok(val) => val,
            Result::Err(err) => return DnsResponseReceiver::Err(Some(ProtoError::from(err))),
        }
    }};
}
185 | | |
#[cfg(feature = "std")]
impl DnsHandle for BufDnsRequestStreamHandle {
    type Response = DnsResponseReceiver;

    /// Queues `request` without blocking and returns a receiver for its responses.
    ///
    /// If the request cannot be queued, the returned receiver yields a
    /// `ProtoErrorKind::Busy` error.
    fn send(&self, request: DnsRequest) -> Self::Response {
        debug!(
            "enqueueing message:{}:{:?}",
            request.op_code(),
            request.queries()
        );

        let (queued_request, response_rx) = OneshotDnsRequest::oneshot(request);

        // `try_send` needs `&mut`, so go through a clone of the shared sender.
        let mut tx = self.sender.clone();
        let enqueue_result = tx.try_send(queued_request).map_err(|_| {
            debug!("unable to enqueue message");
            ProtoError::from(ProtoErrorKind::Busy)
        });
        try_oneshot!(enqueue_result);

        DnsResponseReceiver::Receiver(response_rx)
    }
}
208 | | |
// TODO: this future should return the origin message in the response on errors
/// A `DnsRequest` paired with the sending half of a oneshot channel on which the
/// `DnsResponseStream` for that request is delivered
#[cfg(feature = "std")]
pub struct OneshotDnsRequest {
    /// The request to be sent
    dns_request: DnsRequest,
    /// Channel on which the response stream for `dns_request` is handed back
    sender_for_response: oneshot::Sender<DnsResponseStream>,
}
216 | | |
#[cfg(feature = "std")]
impl OneshotDnsRequest {
    /// Wraps `dns_request` together with a freshly created oneshot channel.
    ///
    /// Returns the wrapped request and the receiving half on which the request's
    /// `DnsResponseStream` can later be received.
    // The enclosing impl is already gated on `std`, so the former inner
    // `any(feature = "std", feature = "no-std-rand")` gate was always true here.
    fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
        let (sender_for_response, receiver) = oneshot::channel();

        (
            Self {
                dns_request,
                sender_for_response,
            },
            receiver,
        )
    }

    /// Splits this value into the request and the `OneshotDnsResponse` used to hand the
    /// response stream back to the waiting receiver.
    fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
        (
            self.dns_request,
            OneshotDnsResponse(self.sender_for_response),
        )
    }
}
239 | | |
/// Sending half used to hand a `DnsResponseStream` back to whoever issued the matching
/// `OneshotDnsRequest`
#[cfg(feature = "std")]
struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);

#[cfg(feature = "std")]
impl OneshotDnsResponse {
    /// Delivers `serial_response` through the oneshot channel.
    ///
    /// If the receiving side has already been dropped, the stream is handed back as `Err`.
    fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
        let Self(sender) = self;
        sender.send(serial_response)
    }
}
249 | | |
/// A [`Stream`] over the responses to a request whose [`DnsResponseStream`] arrives later
/// through a [`oneshot::Receiver`]; once it arrives, its items are forwarded
#[cfg(feature = "std")]
pub enum DnsResponseReceiver {
    /// Still waiting for the response stream on the oneshot channel
    Receiver(oneshot::Receiver<DnsResponseStream>),
    /// The response stream has arrived and is being forwarded
    Received(DnsResponseStream),
    /// The send operation failed; holds the error until it has been reported
    Err(Option<ProtoError>),
}
260 | | |
#[cfg(feature = "std")]
impl Stream for DnsResponseReceiver {
    type Item = Result<DnsResponse, ProtoError>;

    /// Drives the receiver state machine.
    ///
    /// * `Receiver`: waits for the response stream and switches to `Received` once it arrives.
    ///   If the sending half was dropped, switches to `Err` so the cancellation error is
    ///   yielded exactly once and the stream then ends.
    /// * `Received`: forwards items from the inner stream.
    /// * `Err`: yields the stored error once, then `None` on every later poll.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            *self = match &mut *self {
                Self::Receiver(receiver) => match ready!(Pin::new(receiver).poll(cx)) {
                    Ok(stream) => Self::Received(stream),
                    // A canceled oneshot receiver keeps reporting `Canceled` when polled again.
                    // Park the error in `Err` so it is reported once and the stream terminates,
                    // rather than yielding the same error forever.
                    Err(_) => Self::Err(Some(ProtoError::from("receiver was canceled"))),
                },
                Self::Received(stream) => return stream.poll_next_unpin(cx),
                Self::Err(err) => return Poll::Ready(err.take().map(Err)),
            };
        }
    }
}
285 | | |
286 | | /// Helper trait to convert a Stream of dns response into a Future |
287 | | pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized { |
288 | | /// Convert a Stream of dns response into a Future yielding the first answer, |
289 | | /// discarding others if any. |
290 | 0 | fn first_answer(self) -> FirstAnswerFuture<Self> { |
291 | 0 | FirstAnswerFuture { stream: Some(self) } |
292 | 0 | } |
293 | | } |
294 | | |
295 | | impl<E, S, T> FirstAnswer<T, E> for S |
296 | | where |
297 | | S: Stream<Item = Result<T, E>> + Unpin + Sized, |
298 | | E: From<ProtoError>, |
299 | | { |
300 | | } |
301 | | |
/// Future returned by [`FirstAnswer::first_answer`]; resolves to the first item of the
/// wrapped stream.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FirstAnswerFuture<S> {
    /// The wrapped stream; `None` once the future has produced its output
    stream: Option<S>,
}
308 | | |
309 | | impl<E, S: Stream<Item = Result<T, E>> + Unpin, T> Future for FirstAnswerFuture<S> |
310 | | where |
311 | | S: Stream<Item = Result<T, E>> + Unpin + Sized, |
312 | | E: From<ProtoError>, |
313 | | { |
314 | | type Output = S::Item; |
315 | | |
316 | 0 | fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { |
317 | 0 | let s = self |
318 | 0 | .stream |
319 | 0 | .as_mut() |
320 | 0 | .expect("polling FirstAnswerFuture twice"); |
321 | 0 | let item = match ready!(s.poll_next_unpin(cx)) { |
322 | 0 | Some(r) => r, |
323 | 0 | None => Err(ProtoError::from(ProtoErrorKind::Timeout).into()), |
324 | | }; |
325 | 0 | self.stream.take(); |
326 | 0 | Poll::Ready(item) |
327 | 0 | } |
328 | | } |
329 | | |
/// The protocol on which a NameServer should be communicated with
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[non_exhaustive]
pub enum Protocol {
    /// UDP is the traditional DNS transport; this is generally the correct choice
    Udp,
    /// TCP can be used for large queries, but not all NameServers support it
    Tcp,
    /// Tls for DNS over TLS
    #[cfg(feature = "__tls")]
    Tls,
    /// Https for DNS over HTTPS
    #[cfg(feature = "__https")]
    Https,
    /// QUIC for DNS over QUIC
    #[cfg(feature = "__quic")]
    Quic,
    /// HTTP/3 for DNS over HTTP/3
    #[cfg(feature = "__h3")]
    H3,
}

impl Protocol {
    /// Returns true if this is a datagram oriented protocol, which is only plain UDP.
    ///
    /// The QUIC-based variants (`Quic`, `H3`) are not counted as datagram oriented here,
    /// even though QUIC itself runs over UDP.
    pub fn is_datagram(self) -> bool {
        matches!(self, Self::Udp)
    }

    /// Returns true if this is a stream oriented protocol, i.e. anything that is not
    /// [`is_datagram`](Self::is_datagram)
    pub fn is_stream(self) -> bool {
        !self.is_datagram()
    }

    /// Returns true if this is an encrypted protocol: TLS, HTTPS, QUIC or HTTP/3 (everything
    /// except plain UDP and TCP)
    pub fn is_encrypted(self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Udp => false,
            Self::Tcp => false,
            #[cfg(feature = "__tls")]
            Self::Tls => true,
            #[cfg(feature = "__https")]
            Self::Https => true,
            #[cfg(feature = "__quic")]
            Self::Quic => true,
            #[cfg(feature = "__h3")]
            Self::H3 => true,
        }
    }
}

impl fmt::Display for Protocol {
    /// Writes the protocol's lowercase short name, e.g. `udp` or `tcp`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Udp => "udp",
            Self::Tcp => "tcp",
            #[cfg(feature = "__tls")]
            Self::Tls => "tls",
            #[cfg(feature = "__https")]
            Self::Https => "https",
            #[cfg(feature = "__quic")]
            Self::Quic => "quic",
            #[cfg(feature = "__h3")]
            Self::H3 => "h3",
        })
    }
}
401 | | |
/// Five-second connection timeout, shared within the crate.
#[allow(unused)] // Only referenced under some feature combinations.
pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);