Coverage Report

Created: 2025-07-23 07:04

/src/hickory-dns/crates/proto/src/xfer/mod.rs
Line
Count
Source (jump to first uncovered line)
1
//! DNS high level transit implementations.
2
//!
3
//! There are two primary types of interest in this module: `DnsMultiplexer` and `DnsHandle`. `DnsMultiplexer` can be thought of as the state machine responsible for sending and receiving DNS messages. `DnsHandle` is the type given to API users of the `hickory-proto` library to send messages into the `DnsMultiplexer` for delivery. Finally, there is the `DnsRequest` type, which allows the delivery of messages via a `DnsMultiplexer` to be customized through `DnsRequestOptions`.
4
//!
5
//! TODO: this module needs some serious refactoring and normalization.
6
7
#[cfg(feature = "std")]
8
use core::fmt::Display;
9
use core::fmt::{self, Debug};
10
use core::future::Future;
11
use core::pin::Pin;
12
use core::task::{Context, Poll};
13
use core::time::Duration;
14
#[cfg(feature = "std")]
15
use std::net::SocketAddr;
16
17
#[cfg(feature = "std")]
18
use futures_channel::mpsc;
19
#[cfg(feature = "std")]
20
use futures_channel::oneshot;
21
use futures_util::ready;
22
#[cfg(feature = "std")]
23
use futures_util::stream::{Fuse, Peekable};
24
use futures_util::stream::{Stream, StreamExt};
25
#[cfg(feature = "serde")]
26
use serde::{Deserialize, Serialize};
27
#[cfg(feature = "std")]
28
use tracing::{debug, warn};
29
30
use crate::error::{ProtoError, ProtoErrorKind};
31
#[cfg(feature = "std")]
32
use crate::runtime::Time;
33
34
#[cfg(feature = "std")]
35
mod dns_exchange;
36
pub mod dns_handle;
37
#[cfg(feature = "std")]
38
pub mod dns_multiplexer;
39
pub mod dns_request;
40
pub mod dns_response;
41
pub mod retry_dns_handle;
42
mod serial_message;
43
44
#[cfg(feature = "std")]
45
pub use self::dns_exchange::{
46
    Connecting, DnsExchange, DnsExchangeBackground, DnsExchangeConnect, DnsExchangeSend,
47
};
48
pub use self::dns_handle::{DnsHandle, DnsStreamHandle};
49
#[cfg(feature = "std")]
50
pub use self::dns_multiplexer::{DnsMultiplexer, DnsMultiplexerConnect};
51
pub use self::dns_request::{DnsRequest, DnsRequestOptions};
52
pub use self::dns_response::DnsResponse;
53
#[cfg(feature = "std")]
54
pub use self::dns_response::DnsResponseStream;
55
pub use self::retry_dns_handle::RetryDnsHandle;
56
pub use self::serial_message::SerialMessage;
57
58
/// Discards the outcome of a `try_send`, logging any failure instead of propagating it.
///
/// A failure caused by a disconnected receiver is only logged at debug level; any
/// other failure is logged as a warning, since it may indicate a leaked future.
#[cfg(feature = "std")]
fn ignore_send<M, T>(result: Result<M, mpsc::TrySendError<T>>) {
    match result {
        Ok(_) => {}
        Err(failure) if failure.is_disconnected() => {
            debug!("ignoring send error on disconnected stream");
        }
        Err(failure) => {
            warn!("error notifying wait, possible future leak: {:?}", failure);
        }
    }
}
70
71
/// A non-multiplexed stream of serialized DNS messages exchanged with one name server.
///
/// Items are `SerialMessage`s read from the connection, or the `ProtoError` that
/// interrupted reading.
#[cfg(feature = "std")]
pub trait DnsClientStream:
    Stream<Item = Result<SerialMessage, ProtoError>> + Display + Send
{
    /// The `Time` implementation this stream relies on.
    type Time: Time;

    /// Returns the socket address of the remote name server.
    fn name_server_addr(&self) -> SocketAddr;
}
82
83
/// Receiving half of a `SerialMessage` channel, wrapped so it is both fused and peekable.
#[cfg(feature = "std")]
pub type StreamReceiver = Peekable<Fuse<mpsc::Receiver<SerialMessage>>>;

/// Buffer size handed to `mpsc::channel` for channels created in this module
/// (e.g. by `BufDnsStreamHandle::new`).
#[cfg(feature = "std")]
const CHANNEL_BUFFER_SIZE: usize = 32;
89
90
/// A buffered stream handle that is bound to one remote `SocketAddr`.
///
/// Every message sent through this handle is re-addressed so that `remote_addr`
/// becomes the packet's destination, regardless of the address it carried before.
#[cfg(feature = "std")]
#[derive(Clone)]
pub struct BufDnsStreamHandle {
    remote_addr: SocketAddr,
    sender: mpsc::Sender<SerialMessage>,
}
99
100
#[cfg(feature = "std")]
impl BufDnsStreamHandle {
    /// Constructs a new buffered stream handle, used for sending data to the DNS peer.
    ///
    /// # Arguments
    ///
    /// * `remote_addr` - the address of the remote DNS system (client or server); every
    ///   message sent through the returned handle is addressed to it
    ///
    /// # Returns
    ///
    /// The handle, together with the fused, peekable receiving end of the internal
    /// channel (created with a buffer size of `CHANNEL_BUFFER_SIZE`), from which the
    /// queued messages can be read.
    pub fn new(remote_addr: SocketAddr) -> (Self, StreamReceiver) {
        let (sender, receiver) = mpsc::channel(CHANNEL_BUFFER_SIZE);
        let receiver = receiver.fuse().peekable();

        (
            Self {
                remote_addr,
                sender,
            },
            receiver,
        )
    }

    /// Associates a different remote address for any responses.
    ///
    /// This is mainly useful in server use cases where the incoming address is only known
    /// after receiving a packet. The returned handle shares this handle's channel; only
    /// the destination address differs.
    pub fn with_remote_addr(&self, remote_addr: SocketAddr) -> Self {
        Self {
            remote_addr,
            sender: self.sender.clone(),
        }
    }
}
130
131
#[cfg(feature = "std")]
impl DnsStreamHandle for BufDnsStreamHandle {
    /// Queues `buffer` for delivery, re-addressed to this handle's `remote_addr`.
    ///
    /// The address originally carried by `buffer` is discarded. Fails if the channel
    /// does not accept the message immediately.
    fn send(&mut self, buffer: SerialMessage) -> Result<(), ProtoError> {
        let (bytes, _) = buffer.into_parts();
        let readdressed = SerialMessage::new(bytes, self.remote_addr);

        self.sender
            .try_send(readdressed)
            .map_err(|e| ProtoError::from(format!("mpsc::SendError {e}")))
    }
}
140
141
/// Types that implement this are capable of sending a serialized DNS message on a stream
///
/// The underlying `Stream` implementation should yield `Some(Ok(()))` whenever it is ready to
/// send a message, return `Poll::Pending` if it is not ready to send a message, and yield
/// `Some(Err(_))` or `None` in the case that the stream is done and should be shut down.
#[cfg(feature = "std")]
pub trait DnsRequestSender: Stream<Item = Result<(), ProtoError>> + Send + Unpin + 'static {
    /// Send a message, and return a stream of responses
    ///
    /// # Return
    ///
    /// A stream which will resolve to the `DnsResponse` results for `request`
    fn send_message(&mut self, request: DnsRequest) -> DnsResponseStream;

    /// Allows the upstream user to inform the underlying stream that it should shut down.
    ///
    /// After this is called, the next time the stream is polled it would be correct for it to
    /// return `Poll::Ready(None)` (end of stream). This is not required though: if there are,
    /// say, outstanding requests that are not yet complete, then it would be correct to first
    /// wait for those results.
    fn shutdown(&mut self);

    /// Returns true if the stream has been shut down with `shutdown`
    fn is_shutdown(&self) -> bool;
}
163
164
/// A cloneable [`DnsHandle`] that forwards each request, paired with a oneshot channel for
/// its response, into an `mpsc` channel.
///
/// Whatever owns the receiving end of that channel is responsible for actually sending the
/// request and delivering the response stream back through the oneshot channel.
#[cfg(feature = "std")]
#[derive(Clone)]
pub struct BufDnsRequestStreamHandle {
    sender: mpsc::Sender<OneshotDnsRequest>,
}
170
171
// Unwraps a `Result`, or returns early from the enclosing function with a
// `DnsResponseReceiver::Err` carrying the error converted into a `ProtoError`.
//
// The enclosing function must return `DnsResponseReceiver`. The trailing-comma form
// delegates to the main arm so both spellings behave identically (previously it expanded
// to `$expr?`, which cannot be used in a function returning `DnsResponseReceiver`).
#[cfg(feature = "std")]
macro_rules! try_oneshot {
    ($expr:expr) => {{
        use core::result::Result;

        match $expr {
            Result::Ok(val) => val,
            Result::Err(err) => return DnsResponseReceiver::Err(Some(ProtoError::from(err))),
        }
    }};
    ($expr:expr,) => {
        try_oneshot!($expr)
    };
}
185
186
#[cfg(feature = "std")]
impl DnsHandle for BufDnsRequestStreamHandle {
    type Response = DnsResponseReceiver;

    /// Enqueues `request` without waiting for channel capacity.
    ///
    /// On success the returned receiver resolves to the responses once they arrive; if the
    /// channel rejects the request, the receiver yields a single `Busy` error instead.
    fn send(&self, request: DnsRequest) -> Self::Response {
        debug!("enqueueing message:{}:{:?}", request.op_code(), request.queries());

        let (wrapped_request, response_rx) = OneshotDnsRequest::oneshot(request);
        let enqueued = self.sender.clone().try_send(wrapped_request).map_err(|_| {
            debug!("unable to enqueue message");
            ProtoError::from(ProtoErrorKind::Busy)
        });
        try_oneshot!(enqueued);

        DnsResponseReceiver::Receiver(response_rx)
    }
}
208
209
// TODO: this future should return the original message in the response on errors
/// A `DnsRequest` paired with the sending half of a oneshot channel, over which the stream
/// of responses to that request is to be delivered.
#[cfg(feature = "std")]
pub struct OneshotDnsRequest {
    dns_request: DnsRequest,
    sender_for_response: oneshot::Sender<DnsResponseStream>,
}
216
217
#[cfg(feature = "std")]
impl OneshotDnsRequest {
    /// Wraps `dns_request` together with a freshly created oneshot channel.
    ///
    /// Returns the wrapped request, which keeps the sending half, and the receiving half,
    /// on which the response stream for this request will be delivered.
    ///
    /// The whole impl is already gated on the `std` feature, so this method needs no
    /// additional `cfg` of its own.
    fn oneshot(dns_request: DnsRequest) -> (Self, oneshot::Receiver<DnsResponseStream>) {
        let (sender_for_response, receiver) = oneshot::channel();

        (
            Self {
                dns_request,
                sender_for_response,
            },
            receiver,
        )
    }

    /// Splits this value into the original request and the handle used to send its response.
    fn into_parts(self) -> (DnsRequest, OneshotDnsResponse) {
        (
            self.dns_request,
            OneshotDnsResponse(self.sender_for_response),
        )
    }
}
239
240
/// The sending half used to deliver the response stream for a `OneshotDnsRequest`.
#[cfg(feature = "std")]
struct OneshotDnsResponse(oneshot::Sender<DnsResponseStream>);
242
243
#[cfg(feature = "std")]
impl OneshotDnsResponse {
    /// Hands `serial_response` to the waiting receiver, consuming this sender.
    ///
    /// If the receiver is gone, the stream is returned unchanged as the `Err` value.
    fn send_response(self, serial_response: DnsResponseStream) -> Result<(), DnsResponseStream> {
        let Self(response_tx) = self;
        response_tx.send(serial_response)
    }
}
249
250
/// A `Stream` that first waits on a `oneshot::Receiver` for a `DnsResponseStream`, and then
/// yields the items of that inner stream.
#[cfg(feature = "std")]
pub enum DnsResponseReceiver {
    /// Still waiting for the response stream to arrive
    Receiver(oneshot::Receiver<DnsResponseStream>),
    /// The response stream has arrived and is being forwarded
    Received(DnsResponseStream),
    /// Sending the request failed; the error is yielded once, after which the stream ends
    Err(Option<ProtoError>),
}
260
261
#[cfg(feature = "std")]
impl Stream for DnsResponseReceiver {
    type Item = Result<DnsResponse, ProtoError>;

    /// Drives the receiver's state machine.
    ///
    /// - `Receiver` moves to `Received` once the oneshot resolves.
    /// - A cancelled oneshot yields a "receiver was canceled" error.
    /// - `Received` delegates to the inner stream.
    /// - `Err` yields its stored error once and then `None`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let next_state = match &mut *self {
                Self::Err(pending) => return Poll::Ready(pending.take().map(Err)),
                Self::Received(inner) => return inner.poll_next_unpin(cx),
                Self::Receiver(rx) => match ready!(Pin::new(rx).poll(cx)) {
                    Ok(inner) => Self::Received(inner),
                    Err(_) => {
                        return Poll::Ready(Some(Err(ProtoError::from(
                            "receiver was canceled",
                        ))));
                    }
                },
            };
            *self = next_state;
        }
    }
}
285
286
/// Helper trait to convert a Stream of dns response into a Future
287
pub trait FirstAnswer<T, E: From<ProtoError>>: Stream<Item = Result<T, E>> + Unpin + Sized {
288
    /// Convert a Stream of dns response into a Future yielding the first answer,
289
    /// discarding others if any.
290
0
    fn first_answer(self) -> FirstAnswerFuture<Self> {
291
0
        FirstAnswerFuture { stream: Some(self) }
292
0
    }
293
}
294
295
impl<E, S, T> FirstAnswer<T, E> for S
296
where
297
    S: Stream<Item = Result<T, E>> + Unpin + Sized,
298
    E: From<ProtoError>,
299
{
300
}
301
302
/// Future produced by [`FirstAnswer::first_answer`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
#[derive(Debug)]
pub struct FirstAnswerFuture<S> {
    // The wrapped stream; taken (set to `None`) once the first answer has been produced.
    stream: Option<S>,
}
308
309
impl<E, S: Stream<Item = Result<T, E>> + Unpin, T> Future for FirstAnswerFuture<S>
310
where
311
    S: Stream<Item = Result<T, E>> + Unpin + Sized,
312
    E: From<ProtoError>,
313
{
314
    type Output = S::Item;
315
316
0
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
317
0
        let s = self
318
0
            .stream
319
0
            .as_mut()
320
0
            .expect("polling FirstAnswerFuture twice");
321
0
        let item = match ready!(s.poll_next_unpin(cx)) {
322
0
            Some(r) => r,
323
0
            None => Err(ProtoError::from(ProtoErrorKind::Timeout).into()),
324
        };
325
0
        self.stream.take();
326
0
        Poll::Ready(item)
327
0
    }
328
}
329
330
/// The protocol on which a NameServer should be communicated with
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
#[non_exhaustive]
pub enum Protocol {
    /// UDP is the traditional DNS transport; this is generally the correct choice
    Udp,
    /// TCP can be used for large messages, but not all NameServers support it
    Tcp,
    /// DNS over TLS
    #[cfg(feature = "__tls")]
    Tls,
    /// DNS over HTTPS
    #[cfg(feature = "__https")]
    Https,
    /// DNS over QUIC
    #[cfg(feature = "__quic")]
    Quic,
    /// DNS over HTTP/3
    #[cfg(feature = "__h3")]
    H3,
}
356
357
impl Protocol {
358
    /// Returns true if this is a datagram oriented protocol, e.g. UDP
359
0
    pub fn is_datagram(self) -> bool {
360
0
        matches!(self, Self::Udp)
361
0
    }
362
363
    /// Returns true if this is a stream oriented protocol, e.g. TCP
364
0
    pub fn is_stream(self) -> bool {
365
0
        !self.is_datagram()
366
0
    }
367
368
    /// Is this an encrypted protocol, i.e. TLS or HTTPS
369
0
    pub fn is_encrypted(self) -> bool {
370
0
        match self {
371
0
            Self::Udp => false,
372
0
            Self::Tcp => false,
373
            #[cfg(feature = "__tls")]
374
            Self::Tls => true,
375
            #[cfg(feature = "__https")]
376
            Self::Https => true,
377
            #[cfg(feature = "__quic")]
378
            Self::Quic => true,
379
            #[cfg(feature = "__h3")]
380
            Self::H3 => true,
381
        }
382
0
    }
383
}
384
385
impl fmt::Display for Protocol {
386
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
387
0
        f.write_str(match self {
388
0
            Self::Udp => "udp",
389
0
            Self::Tcp => "tcp",
390
            #[cfg(feature = "__tls")]
391
            Self::Tls => "tls",
392
            #[cfg(feature = "__https")]
393
            Self::Https => "https",
394
            #[cfg(feature = "__quic")]
395
            Self::Quic => "quic",
396
            #[cfg(feature = "__h3")]
397
            Self::H3 => "h3",
398
        })
399
0
    }
400
}
401
402
/// Timeout for establishing a connection: five seconds.
#[allow(unused)] // May be unused depending on features
pub(crate) const CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);