Coverage Report

Created: 2025-12-12 07:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/hickory-dns/crates/proto/src/op/message.rs
Line
Count
Source
1
// Copyright 2015-2023 Benjamin Fry <benjaminfry@me.com>
2
//
3
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4
// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5
// https://opensource.org/licenses/MIT>, at your option. This file may not be
6
// copied, modified, or distributed except according to those terms.
7
8
//! Basic protocol message for DNS
9
10
use alloc::{boxed::Box, fmt, vec::Vec};
11
use core::{iter, mem, ops::Deref};
12
13
#[cfg(feature = "serde")]
14
use serde::{Deserialize, Serialize};
15
use tracing::{debug, warn};
16
17
#[cfg(feature = "__dnssec")]
18
use crate::dnssec::rdata::{DNSSECRData, SIG, TSIG};
19
#[cfg(any(feature = "std", feature = "no-std-rand"))]
20
use crate::random;
21
use crate::{
22
    error::{ProtoError, ProtoResult},
23
    op::{DnsResponse, Edns, Header, MessageType, OpCode, Query, ResponseCode},
24
    rr::{RData, Record, RecordType},
25
    serialize::binary::{BinDecodable, BinDecoder, BinEncodable, BinEncoder, EncodeMode},
26
};
27
28
/// The basic request and response data structure, used for all DNS protocols.
29
///
30
/// [RFC 1035, DOMAIN NAMES - IMPLEMENTATION AND SPECIFICATION, November 1987](https://tools.ietf.org/html/rfc1035)
31
///
32
/// ```text
33
/// 4.1. Format
34
///
35
/// All communications inside of the domain protocol are carried in a single
36
/// format called a message.  The top level format of message is divided
37
/// into 5 sections (some of which are empty in certain cases) shown below:
38
///
39
///     +--------------------------+
40
///     |        Header            |
41
///     +--------------------------+
42
///     |  Question / Zone         | the question for the name server
43
///     +--------------------------+
44
///     |   Answer  / Prerequisite | RRs answering the question
45
///     +--------------------------+
46
///     | Authority / Update       | RRs pointing toward an authority
47
///     +--------------------------+
48
///     |      Additional          | RRs holding additional information
49
///     +--------------------------+
50
///
51
/// The header section is always present.  The header includes fields that
52
/// specify which of the remaining sections are present, and also specify
53
/// whether the message is a query or a response, a standard query or some
54
/// other opcode, etc.
55
///
56
/// The names of the sections after the header are derived from their use in
57
/// standard queries.  The question section contains fields that describe a
58
/// question to a name server.  These fields are a query type (QTYPE), a
59
/// query class (QCLASS), and a query domain name (QNAME).  The last three
60
/// sections have the same format: a possibly empty list of concatenated
61
/// resource records (RRs).  The answer section contains RRs that answer the
62
/// question; the authority section contains RRs that point toward an
63
/// authoritative name server; the additional records section contains RRs
64
/// which relate to the query, but are not strictly answers for the
65
/// question.
66
/// ```
67
///
68
/// By default Message is a Query. Use the Message::as_update() to create and update, or
69
///  Message::new_update()
70
#[derive(Clone, Debug, PartialEq, Eq)]
71
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
72
pub struct Message {
73
    header: Header,
74
    queries: Vec<Query>,
75
    answers: Vec<Record>,
76
    authorities: Vec<Record>,
77
    additionals: Vec<Record>,
78
    signature: MessageSignature,
79
    edns: Option<Edns>,
80
}
81
82
impl Message {
83
    /// Returns a new "empty" Message
84
    #[cfg(any(feature = "std", feature = "no-std-rand"))]
85
0
    pub fn query() -> Self {
86
0
        Self::new(random(), MessageType::Query, OpCode::Query)
87
0
    }
88
89
    /// Returns a Message constructed with error details to return to a client
90
    ///
91
    /// # Arguments
92
    ///
93
    /// * `id` - message id should match the request message id
94
    /// * `op_code` - operation of the request
95
    /// * `response_code` - the error code for the response
96
0
    pub fn error_msg(id: u16, op_code: OpCode, response_code: ResponseCode) -> Self {
97
0
        let mut message = Self::response(id, op_code);
98
0
        message.set_response_code(response_code);
99
0
        message
100
0
    }
101
102
    /// Returns a new `Message` with `MessageType::Response` and the given header contents
103
0
    pub fn response(id: u16, op_code: OpCode) -> Self {
104
0
        Self::new(id, MessageType::Response, op_code)
105
0
    }
106
107
    /// Create a new [`Message`] with the given header contents
108
0
    pub fn new(id: u16, message_type: MessageType, op_code: OpCode) -> Self {
109
0
        Self {
110
0
            header: Header::new(id, message_type, op_code),
111
0
            queries: Vec::new(),
112
0
            answers: Vec::new(),
113
0
            authorities: Vec::new(),
114
0
            additionals: Vec::new(),
115
0
            signature: MessageSignature::default(),
116
0
            edns: None,
117
0
        }
118
0
    }
119
120
    /// Truncates a Message, this blindly removes all response fields and sets truncated to `true`
121
0
    pub fn truncate(&self) -> Self {
122
        // copy header
123
0
        let mut header = self.header;
124
0
        header.set_truncated(true);
125
0
        header
126
0
            .set_additional_count(0)
127
0
            .set_answer_count(0)
128
0
            .set_authority_count(0);
129
130
0
        let mut msg = Self::new(0, MessageType::Query, OpCode::Query);
131
0
        msg.header = header;
132
133
        // drops additional/answer/nameservers/signature
134
        // adds query/OPT
135
0
        msg.add_queries(self.queries().iter().cloned());
136
0
        if let Some(edns) = self.extensions().clone() {
137
0
            msg.set_edns(edns);
138
0
        }
139
140
        // TODO, perhaps just quickly add a few response records here? that we know would fit?
141
0
        msg
142
0
    }
143
144
    /// Sets the [`Header`]
145
0
    pub fn set_header(&mut self, header: Header) -> &mut Self {
146
0
        self.header = header;
147
0
        self
148
0
    }
149
150
    /// See [`Header::set_id()`]
151
0
    pub fn set_id(&mut self, id: u16) -> &mut Self {
152
0
        self.header.set_id(id);
153
0
        self
154
0
    }
155
156
    /// See [`Header::set_op_code()`]
157
0
    pub fn set_op_code(&mut self, op_code: OpCode) -> &mut Self {
158
0
        self.header.set_op_code(op_code);
159
0
        self
160
0
    }
161
162
    /// See [`Header::set_authoritative()`]
163
0
    pub fn set_authoritative(&mut self, authoritative: bool) -> &mut Self {
164
0
        self.header.set_authoritative(authoritative);
165
0
        self
166
0
    }
167
168
    /// See [`Header::set_truncated()`]
169
0
    pub fn set_truncated(&mut self, truncated: bool) -> &mut Self {
170
0
        self.header.set_truncated(truncated);
171
0
        self
172
0
    }
173
174
    /// See [`Header::set_recursion_desired()`]
175
0
    pub fn set_recursion_desired(&mut self, recursion_desired: bool) -> &mut Self {
176
0
        self.header.set_recursion_desired(recursion_desired);
177
0
        self
178
0
    }
179
180
    /// See [`Header::set_recursion_available()`]
181
0
    pub fn set_recursion_available(&mut self, recursion_available: bool) -> &mut Self {
182
0
        self.header.set_recursion_available(recursion_available);
183
0
        self
184
0
    }
185
186
    /// See [`Header::set_authentic_data()`]
187
0
    pub fn set_authentic_data(&mut self, authentic_data: bool) -> &mut Self {
188
0
        self.header.set_authentic_data(authentic_data);
189
0
        self
190
0
    }
191
192
    /// See [`Header::set_checking_disabled()`]
193
0
    pub fn set_checking_disabled(&mut self, checking_disabled: bool) -> &mut Self {
194
0
        self.header.set_checking_disabled(checking_disabled);
195
0
        self
196
0
    }
197
198
    /// See [`Header::set_response_code()`]
199
0
    pub fn set_response_code(&mut self, response_code: ResponseCode) -> &mut Self {
200
0
        self.header.set_response_code(response_code);
201
0
        self
202
0
    }
203
204
    /// See [`Header::set_query_count()`]
205
    ///
206
    /// this count will be ignored during serialization,
207
    /// where the length of the associated records will be used instead.
208
0
    pub fn set_query_count(&mut self, query_count: u16) -> &mut Self {
209
0
        self.header.set_query_count(query_count);
210
0
        self
211
0
    }
212
213
    /// See [`Header::set_answer_count()`]
214
    ///
215
    /// this count will be ignored during serialization,
216
    /// where the length of the associated records will be used instead.
217
0
    pub fn set_answer_count(&mut self, answer_count: u16) -> &mut Self {
218
0
        self.header.set_answer_count(answer_count);
219
0
        self
220
0
    }
221
222
    /// See [`Header::set_authority_count()`]
223
    ///
224
    /// this count will be ignored during serialization,
225
    /// where the length of the associated records will be used instead.
226
0
    pub fn set_authority_count(&mut self, authority_count: u16) -> &mut Self {
227
0
        self.header.set_authority_count(authority_count);
228
0
        self
229
0
    }
230
231
    /// See [`Header::set_additional_count()`]
232
    ///
233
    /// this count will be ignored during serialization,
234
    /// where the length of the associated records will be used instead.
235
0
    pub fn set_additional_count(&mut self, additional_count: u16) -> &mut Self {
236
0
        self.header.set_additional_count(additional_count);
237
0
        self
238
0
    }
239
240
    /// Add a query to the Message, either the query response from the server, or the request Query.
241
0
    pub fn add_query(&mut self, query: Query) -> &mut Self {
242
0
        self.queries.push(query);
243
0
        self
244
0
    }
245
246
    /// Adds an iterator over a set of Queries to be added to the message
247
0
    pub fn add_queries<Q, I>(&mut self, queries: Q) -> &mut Self
248
0
    where
249
0
        Q: IntoIterator<Item = Query, IntoIter = I>,
250
0
        I: Iterator<Item = Query>,
251
    {
252
0
        for query in queries {
253
0
            self.add_query(query);
254
0
        }
255
256
0
        self
257
0
    }
258
259
    /// Add a record to the Answer section.
260
0
    pub fn add_answer(&mut self, record: Record) -> &mut Self {
261
0
        self.answers.push(record);
262
0
        self
263
0
    }
264
265
    /// Add all the records from the iterator to the Answer section of the message.
266
0
    pub fn add_answers<R, I>(&mut self, records: R) -> &mut Self
267
0
    where
268
0
        R: IntoIterator<Item = Record, IntoIter = I>,
269
0
        I: Iterator<Item = Record>,
270
    {
271
0
        for record in records {
272
0
            self.add_answer(record);
273
0
        }
274
275
0
        self
276
0
    }
277
278
    /// Sets the Answer section to the specified set of records.
279
    ///
280
    /// # Panics
281
    ///
282
    /// Will panic if the Answer section is already non-empty.
283
0
    pub fn insert_answers(&mut self, records: Vec<Record>) {
284
0
        assert!(self.answers.is_empty());
285
0
        self.answers = records;
286
0
    }
287
288
    /// Add a record to the Authority section.
289
0
    pub fn add_authority(&mut self, record: Record) -> &mut Self {
290
0
        self.authorities.push(record);
291
0
        self
292
0
    }
293
294
    /// Add all the records from the Iterator to the Authority section of the message.
295
0
    pub fn add_authorities<R, I>(&mut self, records: R) -> &mut Self
296
0
    where
297
0
        R: IntoIterator<Item = Record, IntoIter = I>,
298
0
        I: Iterator<Item = Record>,
299
    {
300
0
        for record in records {
301
0
            self.add_authority(record);
302
0
        }
303
304
0
        self
305
0
    }
306
307
    /// Sets the Authority section to the specified set of records.
308
    ///
309
    /// # Panics
310
    ///
311
    /// Will panic if the Authority section is already non-empty.
312
0
    pub fn insert_authorities(&mut self, records: Vec<Record>) {
313
0
        assert!(self.authorities.is_empty());
314
0
        self.authorities = records;
315
0
    }
316
317
    /// Add a record to the Additional section.
318
0
    pub fn add_additional(&mut self, record: Record) -> &mut Self {
319
0
        self.additionals.push(record);
320
0
        self
321
0
    }
322
323
    /// Add all the records from the iterator to the Additional section of the message.
324
0
    pub fn add_additionals<R, I>(&mut self, records: R) -> &mut Self
325
0
    where
326
0
        R: IntoIterator<Item = Record, IntoIter = I>,
327
0
        I: Iterator<Item = Record>,
328
    {
329
0
        for record in records {
330
0
            self.add_additional(record);
331
0
        }
332
333
0
        self
334
0
    }
335
336
    /// Sets the Additional to the specified set of records.
337
    ///
338
    /// # Panics
339
    ///
340
    /// Will panic if additional records are already associated to the message.
341
0
    pub fn insert_additionals(&mut self, records: Vec<Record>) {
342
0
        assert!(self.additionals.is_empty());
343
0
        self.additionals = records;
344
0
    }
345
346
    /// Add the EDNS OPT pseudo-RR to the Message
347
0
    pub fn set_edns(&mut self, edns: Edns) -> &mut Self {
348
0
        self.edns = Some(edns);
349
0
        self
350
0
    }
351
352
    /// Set the signature record for the message.
353
    ///
354
    /// This must be used only after all records have been associated. Generally this will be
355
    /// handled by the client and not need to be used directly
356
    ///
357
    /// # Panics
358
    ///
359
    /// If the `MessageSignature` specifies a `Record` and the record type is not correct. For
360
    /// example, providing a `MessageSignature::Tsig` variant with a `Record` with a type other than
361
    /// `RecordType::TSIG` will panic.
362
    #[cfg(feature = "__dnssec")]
363
0
    pub fn set_signature(&mut self, sig: MessageSignature) -> &mut Self {
364
0
        match &sig {
365
0
            MessageSignature::Tsig(rec) => assert_eq!(RecordType::TSIG, rec.record_type()),
366
0
            MessageSignature::Sig0(rec) => assert_eq!(RecordType::SIG, rec.record_type()),
367
0
            _ => {}
368
        }
369
0
        self.signature = sig;
370
0
        self
371
0
    }
372
373
    /// Returns a clone of the `Message` with the message type set to `Response`.
374
0
    pub fn to_response(&self) -> Self {
375
0
        let mut header = self.header;
376
0
        header.set_message_type(MessageType::Response);
377
0
        Self {
378
0
            header,
379
0
            queries: self.queries.clone(),
380
0
            answers: self.answers.clone(),
381
0
            authorities: self.authorities.clone(),
382
0
            additionals: self.additionals.clone(),
383
0
            signature: self.signature.clone(),
384
0
            edns: self.edns.clone(),
385
0
        }
386
0
    }
387
388
    /// Gets the header of the Message
389
0
    pub fn header(&self) -> &Header {
390
0
        &self.header
391
0
    }
392
393
    /// See [`Header::id()`]
394
0
    pub fn id(&self) -> u16 {
395
0
        self.header.id()
396
0
    }
397
398
    /// See [`Header::message_type()`]
399
0
    pub fn message_type(&self) -> MessageType {
400
0
        self.header.message_type()
401
0
    }
402
403
    /// See [`Header::op_code()`]
404
0
    pub fn op_code(&self) -> OpCode {
405
0
        self.header.op_code()
406
0
    }
407
408
    /// See [`Header::authoritative()`]
409
0
    pub fn authoritative(&self) -> bool {
410
0
        self.header.authoritative()
411
0
    }
412
413
    /// See [`Header::truncated()`]
414
772
    pub fn truncated(&self) -> bool {
415
772
        self.header.truncated()
416
772
    }
417
418
    /// See [`Header::recursion_desired()`]
419
0
    pub fn recursion_desired(&self) -> bool {
420
0
        self.header.recursion_desired()
421
0
    }
422
423
    /// See [`Header::recursion_available()`]
424
0
    pub fn recursion_available(&self) -> bool {
425
0
        self.header.recursion_available()
426
0
    }
427
428
    /// See [`Header::authentic_data()`]
429
0
    pub fn authentic_data(&self) -> bool {
430
0
        self.header.authentic_data()
431
0
    }
432
433
    /// See [`Header::checking_disabled()`]
434
0
    pub fn checking_disabled(&self) -> bool {
435
0
        self.header.checking_disabled()
436
0
    }
437
438
    /// # Return value
439
    ///
440
    /// The `ResponseCode`, if this is an EDNS message then this will join the section from the OPT
441
    ///  record to create the EDNS `ResponseCode`
442
0
    pub fn response_code(&self) -> ResponseCode {
443
0
        self.header.response_code()
444
0
    }
445
446
    /// ```text
447
    /// Question        Carries the query name and other query parameters.
448
    /// ```
449
0
    pub fn queries(&self) -> &[Query] {
450
0
        &self.queries
451
0
    }
452
453
    /// Provides mutable access to `queries`
454
0
    pub fn queries_mut(&mut self) -> &mut Vec<Query> {
455
0
        &mut self.queries
456
0
    }
457
458
    /// Removes all the answers from the Message
459
0
    pub fn take_queries(&mut self) -> Vec<Query> {
460
0
        mem::take(&mut self.queries)
461
0
    }
462
463
    /// ```text
464
    /// Answer          Carries RRs which directly answer the query.
465
    /// ```
466
0
    pub fn answers(&self) -> &[Record] {
467
0
        &self.answers
468
0
    }
469
470
    /// Provides mutable access to `answers`
471
0
    pub fn answers_mut(&mut self) -> &mut Vec<Record> {
472
0
        &mut self.answers
473
0
    }
474
475
    /// Removes the Answer section records from the message
476
0
    pub fn take_answers(&mut self) -> Vec<Record> {
477
0
        mem::take(&mut self.answers)
478
0
    }
479
480
    /// ```text
481
    /// Authority       Carries RRs which describe other authoritative servers.
482
    ///                 May optionally carry the SOA RR for the authoritative
483
    ///                 data in the answer section.
484
    /// ```
485
0
    pub fn authorities(&self) -> &[Record] {
486
0
        &self.authorities
487
0
    }
488
489
    /// Provides mutable access to `authorities`
490
0
    pub fn authorities_mut(&mut self) -> &mut Vec<Record> {
491
0
        &mut self.authorities
492
0
    }
493
494
    /// Remove the Authority section records from the message
495
0
    pub fn take_authorities(&mut self) -> Vec<Record> {
496
0
        mem::take(&mut self.authorities)
497
0
    }
498
499
    /// ```text
500
    /// Additional      Carries RRs which may be helpful in using the RRs in the
501
    ///                 other sections.
502
    /// ```
503
0
    pub fn additionals(&self) -> &[Record] {
504
0
        &self.additionals
505
0
    }
506
507
    /// Provides mutable access to `additionals`
508
0
    pub fn additionals_mut(&mut self) -> &mut Vec<Record> {
509
0
        &mut self.additionals
510
0
    }
511
512
    /// Remove the Additional section records from the message
513
0
    pub fn take_additionals(&mut self) -> Vec<Record> {
514
0
        mem::take(&mut self.additionals)
515
0
    }
516
517
    /// All sections chained
518
0
    pub fn all_sections(&self) -> impl Iterator<Item = &Record> {
519
0
        self.answers
520
0
            .iter()
521
0
            .chain(self.authorities.iter())
522
0
            .chain(self.additionals.iter())
523
0
    }
524
525
    /// [RFC 6891, EDNS(0) Extensions, April 2013](https://tools.ietf.org/html/rfc6891#section-6.1.1)
526
    ///
527
    /// ```text
528
    /// 6.1.1.  Basic Elements
529
    ///
530
    ///  An OPT pseudo-RR (sometimes called a meta-RR) MAY be added to the
531
    ///  additional data section of a request.
532
    ///
533
    ///  The OPT RR has RR type 41.
534
    ///
535
    ///  If an OPT record is present in a received request, compliant
536
    ///  responders MUST include an OPT record in their respective responses.
537
    ///
538
    ///  An OPT record does not carry any DNS data.  It is used only to
539
    ///  contain control information pertaining to the question-and-answer
540
    ///  sequence of a specific transaction.  OPT RRs MUST NOT be cached,
541
    ///  forwarded, or stored in or loaded from Zone Files.
542
    ///
543
    ///  The OPT RR MAY be placed anywhere within the additional data section.
544
    ///  When an OPT RR is included within any DNS message, it MUST be the
545
    ///  only OPT RR in that message.  If a query message with more than one
546
    ///  OPT RR is received, a FORMERR (RCODE=1) MUST be returned.  The
547
    ///  placement flexibility for the OPT RR does not override the need for
548
    ///  the TSIG or SIG(0) RRs to be the last in the additional section
549
    ///  whenever they are present.
550
    /// ```
551
    /// # Return value
552
    ///
553
    /// Optionally returns a reference to EDNS OPT pseudo-RR
554
0
    pub fn extensions(&self) -> &Option<Edns> {
555
0
        &self.edns
556
0
    }
557
558
    /// Returns mutable reference of EDNS OPT pseudo-RR
559
0
    pub fn extensions_mut(&mut self) -> &mut Option<Edns> {
560
0
        &mut self.edns
561
0
    }
562
563
    /// # Return value
564
    ///
565
    /// the max payload value as it's defined in the EDNS OPT pseudo-RR.
566
0
    pub fn max_payload(&self) -> u16 {
567
0
        let max_size = self.edns.as_ref().map_or(512, Edns::max_payload);
568
0
        if max_size < 512 { 512 } else { max_size }
569
0
    }
570
571
    /// # Return value
572
    ///
573
    /// the version as defined in the EDNS record
574
0
    pub fn version(&self) -> u8 {
575
0
        self.edns.as_ref().map_or(0, Edns::version)
576
0
    }
577
578
    /// # Return value
579
    ///
580
    /// the signature over the message, if any
581
0
    pub fn signature(&self) -> &MessageSignature {
582
0
        &self.signature
583
0
    }
584
585
    /// Remove signatures from the Message
586
0
    pub fn take_signature(&mut self) -> MessageSignature {
587
0
        mem::take(&mut self.signature)
588
0
    }
589
590
    // TODO: only necessary in tests, should it be removed?
591
    /// this is necessary to match the counts in the header from the record sections
592
    ///  this happens implicitly on write_to, so no need to call before write_to
593
    #[cfg(test)]
594
    pub fn update_counts(&mut self) -> &mut Self {
595
        self.header = update_header_counts(
596
            &self.header,
597
            self.truncated(),
598
            HeaderCounts {
599
                query_count: self.queries.len(),
600
                answer_count: self.answers.len(),
601
                authority_count: self.authorities.len(),
602
                additional_count: self.additionals.len(),
603
            },
604
        );
605
        self
606
    }
607
608
    /// Attempts to read the specified number of `Query`s
609
0
    pub fn read_queries(decoder: &mut BinDecoder<'_>, count: usize) -> ProtoResult<Vec<Query>> {
610
0
        let mut queries = Vec::with_capacity(count);
611
0
        for _ in 0..count {
612
0
            queries.push(Query::read(decoder)?);
613
        }
614
0
        Ok(queries)
615
0
    }
616
617
    /// Attempts to read the specified number of records
618
    ///
619
    /// # Returns
620
    ///
621
    /// This returns a tuple of first standard Records, then a possibly associated Edns, and then
622
    /// finally a `MessageSignature` if applicable.
623
    ///
624
    /// `MessageSignature::Tsig` and `MessageSignature::Sig0` records are only valid when
625
    /// found in the additional data section. Further, they must always be the last record
626
    /// in that section, and are mutually exclusive. It is not possible to have multiple TSIG
627
    /// or SIG(0) records.
628
    ///
629
    /// RFC 2931 §3.1 says:
630
    ///  "Note: requests and responses can either have a single TSIG or one SIG(0) but not both a
631
    ///   TSIG and a SIG(0)."
632
    /// RFC 8945 §5.1 says:
633
    ///  "This TSIG record MUST be the only TSIG RR in the message and MUST be the last record in
634
    ///   the additional data section."
635
    #[cfg_attr(not(feature = "__dnssec"), allow(unused_mut))]
636
47.2k
    pub fn read_records(
637
47.2k
        decoder: &mut BinDecoder<'_>,
638
47.2k
        count: usize,
639
47.2k
        is_additional: bool,
640
47.2k
    ) -> ProtoResult<(Vec<Record>, Option<Edns>, MessageSignature)> {
641
47.2k
        let mut records: Vec<Record> = Vec::with_capacity(count);
642
47.2k
        let mut edns: Option<Edns> = None;
643
47.2k
        let mut sig = MessageSignature::default();
644
645
47.2k
        for _ in 0..count {
646
1.18M
            let record = Record::read(decoder)?;
647
648
            // There must be no additional records after a TSIG/SIG(0) record.
649
1.18M
            if sig != MessageSignature::Unsigned {
650
13
                return Err("TSIG or SIG(0) record must be final resource record".into());
651
1.18M
            }
652
653
            // OPT, SIG and TSIG records are only allowed in the additional section.
654
1.18M
            if !is_additional
655
838k
                && matches!(
656
839k
                    record.record_type(),
657
                    RecordType::OPT | RecordType::SIG | RecordType::TSIG
658
                )
659
            {
660
176
                return Err(format!(
661
176
                    "record type {} only allowed in additional section",
662
176
                    record.record_type()
663
176
                )
664
176
                .into());
665
1.18M
            } else if !is_additional {
666
838k
                records.push(record);
667
838k
                continue;
668
345k
            }
669
670
345k
            match record.data() {
671
                #[cfg(feature = "__dnssec")]
672
                RData::DNSSEC(DNSSECRData::SIG(_)) => {
673
167
                    sig = MessageSignature::Sig0(
674
167
                        record
675
167
                            .map(|data| match data {
676
167
                                RData::DNSSEC(DNSSECRData::SIG(sig)) => Some(sig),
677
0
                                _ => None,
678
167
                            })
679
167
                            .unwrap(), // Safe: see `match` arm above
680
                    )
681
                }
682
                #[cfg(feature = "__dnssec")]
683
                RData::DNSSEC(DNSSECRData::TSIG(_)) => {
684
556
                    sig = MessageSignature::Tsig(
685
556
                        record
686
556
                            .map(|data| match data {
687
556
                                RData::DNSSEC(DNSSECRData::TSIG(tsig)) => Some(tsig),
688
0
                                _ => None,
689
556
                            })
690
556
                            .unwrap(), // Safe: see `match` arm above
691
                    )
692
                }
693
                RData::Update0(RecordType::OPT) | RData::OPT(_) => {
694
2.11k
                    if edns.is_some() {
695
19
                        return Err("more than one edns record present".into());
696
2.09k
                    }
697
2.09k
                    edns = Some((&record).into());
698
                }
699
342k
                _ => {
700
342k
                    records.push(record);
701
342k
                }
702
            }
703
        }
704
705
43.1k
        Ok((records, edns, sig))
706
47.2k
    }
707
708
    /// Decodes a message from the buffer.
709
0
    pub fn from_vec(buffer: &[u8]) -> ProtoResult<Self> {
710
0
        let mut decoder = BinDecoder::new(buffer);
711
0
        Self::read(&mut decoder)
712
0
    }
713
714
    /// Encodes the Message into a buffer
715
0
    pub fn to_vec(&self) -> Result<Vec<u8>, ProtoError> {
716
        // TODO: this feels like the right place to verify the max packet size of the message,
717
        //  will need to update the header for truncation and the lengths if we send less than the
718
        //  full response. This needs to conform with the EDNS settings of the server...
719
0
        let mut buffer = Vec::with_capacity(512);
720
        {
721
0
            let mut encoder = BinEncoder::new(&mut buffer);
722
0
            self.emit(&mut encoder)?;
723
        }
724
725
0
        Ok(buffer)
726
0
    }
727
728
    /// Finalize the message prior to sending.
729
    ///
730
    /// Subsequent to calling this, the Message should not change.
731
0
    pub fn finalize(
732
0
        &mut self,
733
0
        finalizer: &dyn MessageSigner,
734
0
        inception_time: u64,
735
0
    ) -> ProtoResult<Option<MessageVerifier>> {
736
0
        debug!("finalizing message: {:?}", self);
737
738
        #[cfg_attr(not(feature = "__dnssec"), allow(unused_variables))]
739
0
        let (signature, verifier) = finalizer.sign_message(self, inception_time)?;
740
741
        #[cfg(feature = "__dnssec")]
742
0
        {
743
0
            self.set_signature(signature);
744
0
        }
745
746
0
        Ok(verifier)
747
0
    }
748
749
    /// Consumes `Message` and returns into components
750
0
    pub fn into_parts(self) -> MessageParts {
751
0
        self.into()
752
0
    }
753
}
754
755
impl From<MessageParts> for Message {
    /// Reassembles a `Message` from its public parts; each field is moved over unchanged.
    fn from(parts: MessageParts) -> Self {
        Self {
            header: parts.header,
            queries: parts.queries,
            answers: parts.answers,
            authorities: parts.authorities,
            additionals: parts.additionals,
            signature: parts.signature,
            edns: parts.edns,
        }
    }
}
777
778
impl Deref for Message {
779
    type Target = Header;
780
781
0
    fn deref(&self) -> &Self::Target {
782
0
        &self.header
783
0
    }
784
}
785
786
/// Consumes `Message` giving public access to fields in `Message` so they can be
787
/// destructured and taken by value
788
/// ```rust
789
/// use hickory_proto::op::{Message, MessageParts};
790
///
791
/// let msg = Message::query();
792
/// let MessageParts { queries, .. } = msg.into_parts();
793
/// ```
794
#[derive(Clone, Debug, PartialEq, Eq)]
795
pub struct MessageParts {
796
    /// message header
797
    pub header: Header,
798
    /// message queries
799
    pub queries: Vec<Query>,
800
    /// message answers
801
    pub answers: Vec<Record>,
802
    /// message authorities
803
    pub authorities: Vec<Record>,
804
    /// message additional records
805
    pub additionals: Vec<Record>,
806
    /// message signature
807
    pub signature: MessageSignature,
808
    /// optional edns records
809
    pub edns: Option<Edns>,
810
}
811
812
impl From<Message> for MessageParts {
    /// Splits a `Message` into its public parts; each field is moved over unchanged.
    fn from(message: Message) -> Self {
        Self {
            header: message.header,
            queries: message.queries,
            answers: message.answers,
            authorities: message.authorities,
            additionals: message.additionals,
            signature: message.signature,
            edns: message.edns,
        }
    }
}
834
835
/// Per-section lengths of a `Message`, used to rewrite the header counts.
///
/// Internal only: the header is brought in line with the real section sizes while
/// serializing, so caller-set counts never leak onto the wire.
#[derive(Clone, Copy, Debug)]
struct HeaderCounts {
    /// Entries in the question section
    query_count: usize,
    /// Records in the answer section
    answer_count: usize,
    /// Records in the authority section
    authority_count: usize,
    /// Records in the additional section
    additional_count: usize,
}
849
850
/// Returns a new Header with accurate counts for each Message section
851
6.63k
fn update_header_counts(
852
6.63k
    current_header: &Header,
853
6.63k
    is_truncated: bool,
854
6.63k
    counts: HeaderCounts,
855
6.63k
) -> Header {
856
6.63k
    assert!(counts.query_count <= u16::MAX as usize);
857
6.63k
    assert!(counts.answer_count <= u16::MAX as usize);
858
6.63k
    assert!(counts.authority_count <= u16::MAX as usize);
859
6.63k
    assert!(counts.additional_count <= u16::MAX as usize);
860
861
    // TODO: should the function just take by value?
862
6.63k
    let mut header = *current_header;
863
6.63k
    header
864
6.63k
        .set_query_count(counts.query_count as u16)
865
6.63k
        .set_answer_count(counts.answer_count as u16)
866
6.63k
        .set_authority_count(counts.authority_count as u16)
867
6.63k
        .set_additional_count(counts.additional_count as u16)
868
6.63k
        .set_truncated(is_truncated);
869
870
6.63k
    header
871
6.63k
}
872
873
/// Alias for a function verifying if a message is properly signed
874
pub type MessageVerifier = Box<dyn FnMut(&[u8]) -> ProtoResult<DnsResponse> + Send>;
875
876
/// A trait for adding a final `MessageSignature` to a Message before it is sent.
877
pub trait MessageSigner: Send + Sync + 'static {
878
    /// Finalize the provided `Message`, computing a `MessageSignature`, and optionally
879
    /// providing a `MessageVerifier` for response messages.
880
    ///
881
    /// # Arguments
882
    ///
883
    /// * `message` - the message to finalize
884
    /// * `current_time` - the current system time.
885
    ///
886
    /// # Return
887
    ///
888
    /// A `MessageSignature` to append to the end of the additional data, and optionally
889
    /// a `MessageVerifier` to use to verify responses provoked by the message.
890
    fn sign_message(
891
        &self,
892
        message: &Message,
893
        current_time: u64,
894
    ) -> ProtoResult<(MessageSignature, Option<MessageVerifier>)>;
895
896
    /// Return whether the message requires a signature before being sent.
897
    /// By default, returns true for AXFR and IXFR queries, and Update and Notify messages
898
0
    fn should_sign_message(&self, message: &Message) -> bool {
899
0
        [OpCode::Update, OpCode::Notify].contains(&message.op_code())
900
0
            || message
901
0
                .queries()
902
0
                .iter()
903
0
                .any(|q| [RecordType::AXFR, RecordType::IXFR].contains(&q.query_type()))
904
0
    }
905
}
906
907
/// A trait for producing a `MessageSignature` for responses
908
pub trait ResponseSigner: Send + Sync {
909
    /// sign produces a `MessageSignature` for the provided encoded, unsigned, response message.
910
    fn sign(self: Box<Self>, response: &[u8]) -> Result<MessageSignature, ProtoError>;
911
}
912
913
/// Returns the count written and a boolean if it was truncated
914
21.2k
fn count_was_truncated(result: ProtoResult<usize>) -> ProtoResult<(usize, bool)> {
915
1.46k
    match result {
916
19.7k
        Ok(count) => Ok((count, false)),
917
1.46k
        Err(ProtoError::NotAllRecordsWritten { count }) => Ok((count, true)),
918
0
        Err(e) => Err(e),
919
    }
920
21.2k
}
921
922
/// A trait that defines types which can be emitted as a set, with the associated count returned.
923
pub trait EmitAndCount {
924
    /// Emit self to the encoder and return the count of items
925
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize>;
926
}
927
928
impl<'e, I: Iterator<Item = &'e E>, E: 'e + BinEncodable> EmitAndCount for I {
929
26.5k
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
930
26.5k
        encoder.emit_all(self)
931
26.5k
    }
<core::slice::iter::Iter<hickory_proto::op::query::Query> as hickory_proto::op::message::EmitAndCount>::emit
Line
Count
Source
929
6.66k
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
930
6.66k
        encoder.emit_all(self)
931
6.66k
    }
<core::slice::iter::Iter<hickory_proto::rr::resource::Record> as hickory_proto::op::message::EmitAndCount>::emit
Line
Count
Source
929
19.9k
    fn emit(&mut self, encoder: &mut BinEncoder<'_>) -> ProtoResult<usize> {
930
19.9k
        encoder.emit_all(self)
931
19.9k
    }
932
}
933
934
/// Emits the different sections of a message properly
935
///
936
/// # Return
937
///
938
/// In the case of a successful emit, the final header (updated counts, etc) is returned for help with logging, etc.
939
#[allow(clippy::too_many_arguments)]
940
6.66k
pub fn emit_message_parts<Q, A, N, D>(
941
6.66k
    header: &Header,
942
6.66k
    queries: &mut Q,
943
6.66k
    answers: &mut A,
944
6.66k
    authorities: &mut N,
945
6.66k
    additionals: &mut D,
946
6.66k
    edns: Option<&Edns>,
947
6.66k
    signature: &MessageSignature,
948
6.66k
    encoder: &mut BinEncoder<'_>,
949
6.66k
) -> ProtoResult<Header>
950
6.66k
where
951
6.66k
    Q: EmitAndCount,
952
6.66k
    A: EmitAndCount,
953
6.66k
    N: EmitAndCount,
954
6.66k
    D: EmitAndCount,
955
{
956
6.66k
    let include_signature = encoder.mode() != EncodeMode::Signing;
957
6.66k
    let place = encoder.place::<Header>()?;
958
959
6.66k
    let query_count = queries.emit(encoder)?;
960
    // TODO: need to do something on max records
961
    //  return offset of last emitted record.
962
6.63k
    let answer_count = count_was_truncated(answers.emit(encoder))?;
963
6.63k
    let authority_count = count_was_truncated(authorities.emit(encoder))?;
964
6.63k
    let mut additional_count = count_was_truncated(additionals.emit(encoder))?;
965
966
6.63k
    if let Some(mut edns) = edns.cloned() {
967
        // need to commit the error code
968
986
        edns.set_rcode_high(header.response_code().high());
969
970
986
        let count = count_was_truncated(encoder.emit_all(iter::once(&Record::from(&edns))))?;
971
986
        additional_count.0 += count.0;
972
986
        additional_count.1 |= count.1;
973
5.65k
    } else if header.response_code().high() > 0 {
974
0
        warn!(
975
0
            "response code: {} for request: {} requires EDNS but none available",
976
0
            header.response_code(),
977
0
            header.id()
978
        );
979
5.65k
    }
980
981
    // this is a little hacky, but if we are Verifying a signature, i.e. the original Message
982
    //  then the SIG0 or TSIG record should not be encoded and the edns record (if it exists) is
983
    //  already part of the additionals section.
984
6.63k
    if include_signature {
985
6.63k
        let count = match signature {
986
            #[cfg(feature = "__dnssec")]
987
86
            MessageSignature::Sig0(rec) => count_was_truncated(encoder.emit_all(iter::once(rec)))?,
988
            #[cfg(feature = "__dnssec")]
989
277
            MessageSignature::Tsig(rec) => count_was_truncated(encoder.emit_all(iter::once(rec)))?,
990
6.27k
            MessageSignature::Unsigned => (0, false),
991
        };
992
6.63k
        additional_count.0 += count.0;
993
6.63k
        additional_count.1 |= count.1;
994
0
    }
995
996
6.63k
    let counts = HeaderCounts {
997
6.63k
        query_count,
998
6.63k
        answer_count: answer_count.0,
999
6.63k
        authority_count: authority_count.0,
1000
6.63k
        additional_count: additional_count.0,
1001
6.63k
    };
1002
6.63k
    let was_truncated =
1003
6.63k
        header.truncated() || answer_count.1 || authority_count.1 || additional_count.1;
1004
1005
6.63k
    let final_header = update_header_counts(header, was_truncated, counts);
1006
6.63k
    place.replace(encoder, final_header)?;
1007
6.63k
    Ok(final_header)
1008
6.66k
}
1009
1010
impl BinEncodable for Message {
1011
6.66k
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
1012
6.66k
        emit_message_parts(
1013
6.66k
            &self.header,
1014
6.66k
            &mut self.queries.iter(),
1015
6.66k
            &mut self.answers.iter(),
1016
6.66k
            &mut self.authorities.iter(),
1017
6.66k
            &mut self.additionals.iter(),
1018
6.66k
            self.edns.as_ref(),
1019
6.66k
            &self.signature,
1020
6.66k
            encoder,
1021
26
        )?;
1022
1023
6.63k
        Ok(())
1024
6.66k
    }
1025
}
1026
1027
impl<'r> BinDecodable<'r> for Message {
1028
17.7k
    fn read(decoder: &mut BinDecoder<'r>) -> ProtoResult<Self> {
1029
17.7k
        let mut header = Header::read(decoder)?;
1030
1031
        // TODO: return just header, and in the case of the rest of message getting an error.
1032
        //  this could improve error detection while decoding.
1033
1034
        // get the questions
1035
17.7k
        let count = header.query_count() as usize;
1036
17.7k
        let mut queries = Vec::with_capacity(count);
1037
17.7k
        for _ in 0..count {
1038
608k
            queries.push(Query::read(decoder)?);
1039
        }
1040
1041
        // get all counts before header moves
1042
17.3k
        let answer_count = header.answer_count() as usize;
1043
17.3k
        let authority_count = header.authority_count() as usize;
1044
17.3k
        let additional_count = header.additional_count() as usize;
1045
1046
17.3k
        let (answers, _, _) = Self::read_records(decoder, answer_count, false)?;
1047
15.4k
        let (authorities, _, _) = Self::read_records(decoder, authority_count, false)?;
1048
14.4k
        let (additionals, edns, signature) = Self::read_records(decoder, additional_count, true)?;
1049
1050
        // need to grab error code from EDNS (which might have a higher value)
1051
13.2k
        if let Some(edns) = &edns {
1052
1.92k
            let high_response_code = edns.rcode_high();
1053
1.92k
            header.merge_response_code(high_response_code);
1054
11.3k
        }
1055
1056
13.2k
        Ok(Self {
1057
13.2k
            header,
1058
13.2k
            queries,
1059
13.2k
            answers,
1060
13.2k
            authorities,
1061
13.2k
            additionals,
1062
13.2k
            signature,
1063
13.2k
            edns,
1064
13.2k
        })
1065
17.7k
    }
1066
}
1067
1068
impl fmt::Display for Message {
1069
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
1070
0
        let write_query = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1071
0
            for d in slice {
1072
0
                writeln!(f, ";; {d}")?;
1073
            }
1074
1075
0
            Ok(())
1076
0
        };
1077
1078
0
        let write_slice = |slice, f: &mut fmt::Formatter<'_>| -> Result<(), fmt::Error> {
1079
0
            for d in slice {
1080
0
                writeln!(f, "{d}")?;
1081
            }
1082
1083
0
            Ok(())
1084
0
        };
1085
1086
0
        writeln!(f, "; header {header}", header = self.header())?;
1087
1088
0
        if let Some(edns) = self.extensions() {
1089
0
            writeln!(f, "; edns {edns}")?;
1090
0
        }
1091
1092
0
        writeln!(f, "; query")?;
1093
0
        write_query(self.queries(), f)?;
1094
1095
0
        if self.header().message_type() == MessageType::Response
1096
0
            || self.header().op_code() == OpCode::Update
1097
        {
1098
0
            writeln!(f, "; answers {}", self.answer_count())?;
1099
0
            write_slice(self.answers(), f)?;
1100
0
            writeln!(f, "; authorities {}", self.authority_count())?;
1101
0
            write_slice(self.authorities(), f)?;
1102
0
            writeln!(f, "; additionals {}", self.additional_count())?;
1103
0
            write_slice(self.additionals(), f)?;
1104
0
        }
1105
1106
0
        Ok(())
1107
0
    }
1108
}
1109
1110
/// Indicates how a [Message] is signed.
1111
///
1112
/// Per RFC, the choice of RFC 2931 SIG(0), or RFC 8945 TSIG is mutually exclusive:
1113
/// only one or the other may be used. See [`Message::read_records()`] for more
1114
/// information.
1115
#[derive(Clone, Debug, Eq, PartialEq, Default)]
1116
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
1117
pub enum MessageSignature {
1118
    /// The message is not signed, or the dnssec crate feature is not enabled.
1119
    #[default]
1120
    Unsigned,
1121
    /// The message has an RFC 2931 SIG(0) signature [Record].
1122
    #[cfg(feature = "__dnssec")]
1123
    Sig0(Record<SIG>),
1124
    /// The message has an RFC 8945 TSIG signature [Record].
1125
    #[cfg(feature = "__dnssec")]
1126
    Tsig(Record<TSIG>),
1127
}
1128
1129
#[cfg(test)]
1130
mod tests {
1131
    use super::*;
1132
1133
    #[cfg(feature = "__dnssec")]
1134
    use crate::dnssec::Algorithm;
1135
    #[cfg(feature = "__dnssec")]
1136
    use crate::dnssec::rdata::sig::SigInput;
1137
    #[cfg(feature = "__dnssec")]
1138
    use crate::dnssec::rdata::tsig::{TSIG, TsigAlgorithm};
1139
    use crate::rr::rdata::A;
1140
    #[cfg(feature = "std")]
1141
    use crate::rr::rdata::OPT;
1142
    #[cfg(feature = "std")]
1143
    use crate::rr::rdata::opt::{ClientSubnet, EdnsCode, EdnsOption};
1144
    use crate::rr::{Name, RData};
1145
    #[cfg(feature = "__dnssec")]
1146
    use crate::rr::{RecordType, SerialNumber};
1147
    #[cfg(feature = "std")]
1148
    use crate::std::net::IpAddr;
1149
    #[cfg(feature = "std")]
1150
    use crate::std::string::ToString;
1151
1152
    #[test]
1153
    fn test_emit_and_read_header() {
1154
        let mut message = Message::response(10, OpCode::Update);
1155
        message
1156
            .set_authoritative(true)
1157
            .set_truncated(false)
1158
            .set_recursion_desired(true)
1159
            .set_recursion_available(true)
1160
            .set_response_code(ResponseCode::ServFail);
1161
1162
        test_emit_and_read(message);
1163
    }
1164
1165
    #[test]
1166
    fn test_emit_and_read_query() {
1167
        let mut message = Message::response(10, OpCode::Update);
1168
        message
1169
            .set_authoritative(true)
1170
            .set_truncated(true)
1171
            .set_recursion_desired(true)
1172
            .set_recursion_available(true)
1173
            .set_response_code(ResponseCode::ServFail)
1174
            .add_query(Query::new())
1175
            .update_counts(); // we're not testing the query parsing, just message
1176
1177
        test_emit_and_read(message);
1178
    }
1179
1180
    #[test]
1181
    fn test_emit_and_read_records() {
1182
        let mut message = Message::response(10, OpCode::Update);
1183
        message
1184
            .set_authoritative(true)
1185
            .set_truncated(true)
1186
            .set_recursion_desired(true)
1187
            .set_recursion_available(true)
1188
            .set_authentic_data(true)
1189
            .set_checking_disabled(true)
1190
            .set_response_code(ResponseCode::ServFail);
1191
1192
        message.add_answer(Record::stub());
1193
        message.add_authority(Record::stub());
1194
        message.add_additional(Record::stub());
1195
        message.update_counts();
1196
1197
        test_emit_and_read(message);
1198
    }
1199
1200
    #[cfg(test)]
1201
    fn test_emit_and_read(message: Message) {
1202
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
1203
        {
1204
            let mut encoder = BinEncoder::new(&mut byte_vec);
1205
            message.emit(&mut encoder).unwrap();
1206
        }
1207
1208
        let mut decoder = BinDecoder::new(&byte_vec);
1209
        let got = Message::read(&mut decoder).unwrap();
1210
1211
        assert_eq!(got, message);
1212
    }
1213
1214
    #[test]
1215
    fn test_header_counts_correction_after_emit_read() {
1216
        let mut message = Message::response(10, OpCode::Update);
1217
        message
1218
            .set_authoritative(true)
1219
            .set_truncated(true)
1220
            .set_recursion_desired(true)
1221
            .set_recursion_available(true)
1222
            .set_authentic_data(true)
1223
            .set_checking_disabled(true)
1224
            .set_response_code(ResponseCode::ServFail);
1225
1226
        message.add_answer(Record::stub());
1227
        message.add_authority(Record::stub());
1228
        message.add_additional(Record::stub());
1229
1230
        // at here, we don't call update_counts and we even set wrong count,
1231
        // because we are trying to test whether the counts in the header
1232
        // are correct after the message is emitted and read.
1233
        message.set_query_count(1);
1234
        message.set_answer_count(5);
1235
        message.set_authority_count(5);
1236
        // message.set_additional_count(1);
1237
1238
        let got = get_message_after_emitting_and_reading(message);
1239
1240
        // make comparison
1241
        assert_eq!(got.query_count(), 0);
1242
        assert_eq!(got.answer_count(), 1);
1243
        assert_eq!(got.authority_count(), 1);
1244
        assert_eq!(got.additional_count(), 1);
1245
    }
1246
1247
    #[cfg(test)]
1248
    fn get_message_after_emitting_and_reading(message: Message) -> Message {
1249
        let mut byte_vec: Vec<u8> = Vec::with_capacity(512);
1250
        {
1251
            let mut encoder = BinEncoder::new(&mut byte_vec);
1252
            message.emit(&mut encoder).unwrap();
1253
        }
1254
1255
        let mut decoder = BinDecoder::new(&byte_vec);
1256
1257
        Message::read(&mut decoder).unwrap()
1258
    }
1259
1260
    #[test]
1261
    fn test_legit_message() {
1262
        #[rustfmt::skip]
1263
        let buf: Vec<u8> = vec![
1264
            0x10, 0x00, 0x81,
1265
            0x80, // id = 4096, response, op=query, recursion_desired, recursion_available, no_error
1266
            0x00, 0x01, 0x00, 0x01, // 1 query, 1 answer,
1267
            0x00, 0x00, 0x00, 0x00, // 0 nameservers, 0 additional record
1268
            0x03, b'w', b'w', b'w', // query --- www.example.com
1269
            0x07, b'e', b'x', b'a', //
1270
            b'm', b'p', b'l', b'e', //
1271
            0x03, b'c', b'o', b'm', //
1272
            0x00,                   // 0 = endname
1273
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
1274
            0xC0, 0x0C,             // name pointer to www.example.com
1275
            0x00, 0x01, 0x00, 0x01, // RecordType = A, Class = IN
1276
            0x00, 0x00, 0x00, 0x02, // TTL = 2 seconds
1277
            0x00, 0x04,             // record length = 4 (ipv4 address)
1278
            0x5D, 0xB8, 0xD7, 0x0E, // address = 93.184.215.14
1279
        ];
1280
1281
        let mut decoder = BinDecoder::new(&buf);
1282
        let message = Message::read(&mut decoder).unwrap();
1283
1284
        assert_eq!(message.id(), 4_096);
1285
1286
        let mut buf: Vec<u8> = Vec::with_capacity(512);
1287
        {
1288
            let mut encoder = BinEncoder::new(&mut buf);
1289
            message.emit(&mut encoder).unwrap();
1290
        }
1291
1292
        let mut decoder = BinDecoder::new(&buf);
1293
        let message = Message::read(&mut decoder).unwrap();
1294
1295
        assert_eq!(message.id(), 4_096);
1296
    }
1297
1298
    #[test]
1299
    fn rdata_zero_roundtrip() {
1300
        let buf = &[
1301
            160, 160, 0, 13, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0,
1302
        ];
1303
1304
        assert!(Message::from_bytes(buf).is_err());
1305
    }
1306
1307
    #[test]
1308
    fn nsec_deserialization() {
1309
        const CRASHING_MESSAGE: &[u8] = &[
1310
            0, 0, 132, 0, 0, 0, 0, 1, 0, 0, 0, 1, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100,
1311
            52, 50, 52, 45, 52, 102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55,
1312
            56, 48, 102, 50, 98, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4,
1313
            192, 168, 1, 17, 36, 49, 101, 48, 101, 101, 51, 100, 51, 45, 100, 52, 50, 52, 45, 52,
1314
            102, 55, 56, 45, 57, 101, 52, 99, 45, 99, 51, 56, 51, 51, 55, 55, 56, 48, 102, 50, 98,
1315
            5, 108, 111, 99, 97, 108, 0, 0, 47, 128, 1, 0, 0, 0, 120, 0, 5, 192, 70, 0, 1, 64,
1316
        ];
1317
1318
        Message::from_vec(CRASHING_MESSAGE).expect("failed to parse message");
1319
    }
1320
1321
    #[test]
1322
    fn prior_to_pointer() {
1323
        const MESSAGE: &[u8] = include_bytes!("../../tests/test-data/fuzz-prior-to-pointer.rdata");
1324
        let message = Message::from_bytes(MESSAGE).expect("failed to parse message");
1325
        let encoded = message.to_bytes().unwrap();
1326
        Message::from_bytes(&encoded).expect("failed to parse encoded message");
1327
    }
1328
1329
    #[test]
1330
    fn test_read_records_unsigned() {
1331
        let records = vec![
1332
            Record::from_rdata(
1333
                Name::from_labels(vec!["example", "com"]).unwrap(),
1334
                300,
1335
                RData::A(A::new(127, 0, 0, 1)),
1336
            ),
1337
            Record::from_rdata(
1338
                Name::from_labels(vec!["www", "example", "com"]).unwrap(),
1339
                300,
1340
                RData::A(A::new(127, 0, 0, 1)),
1341
            ),
1342
        ];
1343
        let result = encode_and_read_records(records.clone(), false);
1344
        let (output_records, edns, signature) = result.unwrap();
1345
        assert_eq!(output_records.len(), records.len());
1346
        assert!(edns.is_none());
1347
        assert_eq!(signature, MessageSignature::Unsigned);
1348
    }
1349
1350
    #[cfg(feature = "std")]
1351
    #[test]
1352
    fn test_read_records_edns() {
1353
        let records = vec![
1354
            Record::from_rdata(
1355
                Name::from_labels(vec!["example", "com"]).unwrap(),
1356
                300,
1357
                RData::A(A::new(127, 0, 0, 1)),
1358
            ),
1359
            Record::from_rdata(
1360
                Name::new(),
1361
                0,
1362
                RData::OPT(OPT::new(vec![(
1363
                    EdnsCode::Subnet,
1364
                    EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
1365
                )])),
1366
            ),
1367
        ];
1368
        let result = encode_and_read_records(records, true);
1369
        let (output_records, edns, signature) = result.unwrap();
1370
        assert_eq!(output_records.len(), 1); // Only the A record, OPT becomes EDNS
1371
        assert!(edns.is_some());
1372
        assert_eq!(signature, MessageSignature::Unsigned);
1373
    }
1374
1375
    #[cfg(feature = "__dnssec")]
1376
    #[test]
1377
    fn test_read_records_tsig() {
1378
        let records = vec![
1379
            Record::from_rdata(
1380
                Name::from_labels(vec!["example", "com"]).unwrap(),
1381
                300,
1382
                RData::A(A::new(127, 0, 0, 1)),
1383
            ),
1384
            Record::from_rdata(
1385
                Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
1386
                0,
1387
                fake_tsig(),
1388
            ),
1389
        ];
1390
        let result = encode_and_read_records(records, true);
1391
        let (output_records, edns, signature) = result.unwrap();
1392
        assert_eq!(output_records.len(), 1); // Only the A record, TSIG becomes signature
1393
        assert!(edns.is_none());
1394
        assert!(matches!(signature, MessageSignature::Tsig(_)));
1395
    }
1396
1397
    #[cfg(feature = "__dnssec")]
1398
    #[test]
1399
    fn test_read_records_sig0() {
1400
        let records = vec![
1401
            Record::from_rdata(
1402
                Name::from_labels(vec!["example", "com"]).unwrap(),
1403
                300,
1404
                RData::A(A::new(127, 0, 0, 1)),
1405
            ),
1406
            Record::from_rdata(
1407
                Name::from_labels(vec!["sig", "example", "com"]).unwrap(),
1408
                0,
1409
                fake_sig0(),
1410
            ),
1411
        ];
1412
        let result = encode_and_read_records(records, true);
1413
        assert!(result.is_ok());
1414
        let (output_records, edns, signature) = result.unwrap();
1415
        assert_eq!(output_records.len(), 1); // Only the A record, SIG0 becomes signature
1416
        assert!(edns.is_none());
1417
        assert!(matches!(signature, MessageSignature::Sig0(_)));
1418
    }
1419
1420
    #[cfg(all(feature = "std", feature = "__dnssec"))]
1421
    #[test]
1422
    fn test_read_records_edns_tsig() {
1423
        let records = vec![
1424
            Record::from_rdata(
1425
                Name::from_labels(vec!["example", "com"]).unwrap(),
1426
                300,
1427
                RData::A(A::new(127, 0, 0, 1)),
1428
            ),
1429
            Record::from_rdata(
1430
                Name::new(),
1431
                0,
1432
                RData::OPT(OPT::new(vec![(
1433
                    EdnsCode::Subnet,
1434
                    EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
1435
                )])),
1436
            ),
1437
            Record::from_rdata(
1438
                Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
1439
                0,
1440
                fake_tsig(),
1441
            ),
1442
        ];
1443
1444
        let result = encode_and_read_records(records, true);
1445
        assert!(result.is_ok());
1446
        let (output_records, edns, signature) = result.unwrap();
1447
        assert_eq!(output_records.len(), 1); // Only the A record
1448
        assert!(edns.is_some());
1449
        assert!(matches!(signature, MessageSignature::Tsig(_)));
1450
    }
1451
1452
    #[cfg(feature = "std")]
1453
    #[test]
1454
    fn test_read_records_unsigned_multiple_edns() {
1455
        let opt_record = Record::from_rdata(
1456
            Name::new(),
1457
            0,
1458
            RData::OPT(OPT::new(vec![(
1459
                EdnsCode::Subnet,
1460
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
1461
            )])),
1462
        );
1463
        let error = encode_and_read_records(
1464
            vec![
1465
                opt_record.clone(),
1466
                Record::from_rdata(
1467
                    Name::from_labels(vec!["example", "com"]).unwrap(),
1468
                    300,
1469
                    RData::A(A::new(127, 0, 0, 1)),
1470
                ),
1471
                opt_record.clone(),
1472
            ],
1473
            true,
1474
        )
1475
        .unwrap_err();
1476
        assert!(
1477
            error
1478
                .to_string()
1479
                .contains("more than one edns record present")
1480
        );
1481
    }
1482
1483
    #[cfg(feature = "std")]
1484
    #[test]
1485
    fn test_read_records_opt_not_additional() {
1486
        let opt_record = Record::from_rdata(
1487
            Name::new(),
1488
            0,
1489
            RData::OPT(OPT::new(vec![(
1490
                EdnsCode::Subnet,
1491
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
1492
            )])),
1493
        );
1494
        let err = encode_and_read_records(
1495
            vec![
1496
                opt_record.clone(),
1497
                Record::from_rdata(
1498
                    Name::from_labels(vec!["example", "com"]).unwrap(),
1499
                    300,
1500
                    RData::A(A::new(127, 0, 0, 1)),
1501
                ),
1502
            ],
1503
            false,
1504
        )
1505
        .unwrap_err();
1506
        assert!(
1507
            err.to_string()
1508
                .contains("record type OPT only allowed in additional section")
1509
        );
1510
    }
1511
1512
    #[cfg(all(feature = "std", feature = "__dnssec"))]
1513
    #[test]
1514
    fn test_read_records_signed_multiple_edns() {
1515
        let opt_record = Record::from_rdata(
1516
            Name::new(),
1517
            0,
1518
            RData::OPT(OPT::new(vec![(
1519
                EdnsCode::Subnet,
1520
                EdnsOption::Subnet(ClientSubnet::new(IpAddr::from([127, 0, 0, 1]), 0, 24)),
1521
            )])),
1522
        );
1523
        let error = encode_and_read_records(
1524
            vec![
1525
                opt_record.clone(),
1526
                Record::from_rdata(
1527
                    Name::from_labels(vec!["example", "com"]).unwrap(),
1528
                    300,
1529
                    RData::A(A::new(127, 0, 0, 1)),
1530
                ),
1531
                opt_record.clone(),
1532
                Record::from_rdata(
1533
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
1534
                    0,
1535
                    fake_tsig(),
1536
                ),
1537
            ],
1538
            true,
1539
        )
1540
        .unwrap_err();
1541
        assert!(
1542
            error
1543
                .to_string()
1544
                .contains("more than one edns record present")
1545
        );
1546
    }
1547
1548
    #[cfg(all(feature = "std", feature = "__dnssec"))]
1549
    #[test]
1550
    fn test_read_records_tsig_not_additional() {
1551
        let err = encode_and_read_records(
1552
            vec![
1553
                Record::from_rdata(
1554
                    Name::from_labels(vec!["example", "com"]).unwrap(),
1555
                    300,
1556
                    RData::A(A::new(127, 0, 0, 1)),
1557
                ),
1558
                Record::from_rdata(
1559
                    Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
1560
                    0,
1561
                    fake_tsig(),
1562
                ),
1563
            ],
1564
            false,
1565
        )
1566
        .unwrap_err();
1567
        assert!(
1568
            err.to_string()
1569
                .contains("record type TSIG only allowed in additional section")
1570
        );
1571
    }
1572
1573
    #[cfg(all(feature = "std", feature = "__dnssec"))]
1574
    #[test]
1575
    fn test_read_records_sig0_not_additional() {
1576
        let err = encode_and_read_records(
1577
            vec![
1578
                Record::from_rdata(
1579
                    Name::from_labels(vec!["example", "com"]).unwrap(),
1580
                    300,
1581
                    RData::A(A::new(127, 0, 0, 1)),
1582
                ),
1583
                Record::from_rdata(
1584
                    Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
1585
                    0,
1586
                    fake_sig0(),
1587
                ),
1588
            ],
1589
            false,
1590
        )
1591
        .unwrap_err();
1592
        assert!(
1593
            err.to_string()
1594
                .contains("record type SIG only allowed in additional section")
1595
        );
1596
    }
1597
1598
    /// A TSIG record that is followed by another record in the additional
    /// section must be rejected: it has to be the final record.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_tsig_not_last() {
        let answer = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let tsig = Record::from_rdata(
            Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );

        // A, TSIG, A: the trailing A record makes the TSIG non-final.
        let records = vec![answer.clone(), tsig, answer];
        let message = encode_and_read_records(records, true)
            .unwrap_err()
            .to_string();
        assert!(message.contains("TSIG or SIG(0) record must be final"));
    }
1622
1623
    /// A SIG(0) record that is followed by another record in the additional
    /// section must be rejected: it has to be the final record.
    ///
    /// The middle record is built with `fake_sig0()`. It previously used
    /// `fake_tsig()`, which made this a duplicate of the TSIG test and left the
    /// SIG(0) path untested.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_sig0_not_last() {
        let a_record = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let error = encode_and_read_records(
            vec![
                a_record.clone(),
                Record::from_rdata(
                    Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
                    0,
                    fake_sig0(),
                ),
                a_record,
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("TSIG or SIG(0) record must be final"));
    }
1647
1648
    /// An additional section that carries both a SIG(0) and a TSIG record must
    /// be rejected: the SIG(0) is not the final record.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_both_sig0_tsig() {
        let answer = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );
        let sig0 = Record::from_rdata(
            Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
            0,
            fake_sig0(),
        );
        let tsig = Record::from_rdata(
            Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );

        let message = encode_and_read_records(vec![answer, sig0, tsig], true)
            .unwrap_err()
            .to_string();
        assert!(message.contains("TSIG or SIG(0) record must be final"));
    }
1675
1676
    /// Two TSIG records in the additional section must be rejected: the
    /// first one is not the final record.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_multiple_tsig() {
        let tsig = Record::from_rdata(
            Name::from_labels(vec!["tsig", "example", "com"]).unwrap(),
            0,
            fake_tsig(),
        );
        let answer = Record::from_rdata(
            Name::from_labels(vec!["example", "com"]).unwrap(),
            300,
            RData::A(A::new(127, 0, 0, 1)),
        );

        let records = vec![answer, tsig.clone(), tsig];
        let message = encode_and_read_records(records, true)
            .unwrap_err()
            .to_string();
        assert!(message.contains("TSIG or SIG(0) record must be final"));
    }
1700
1701
    /// Two SIG(0) records in the additional section must be rejected: the
    /// first one is not the final record.
    ///
    /// The repeated record is built with `fake_sig0()`. It previously used
    /// `fake_tsig()`, which made this a duplicate of the multiple-TSIG test.
    #[cfg(all(feature = "std", feature = "__dnssec"))]
    #[test]
    fn test_read_records_multiple_sig0() {
        let sig0_record = Record::from_rdata(
            Name::from_labels(vec!["sig0", "example", "com"]).unwrap(),
            0,
            fake_sig0(),
        );
        let error = encode_and_read_records(
            vec![
                Record::from_rdata(
                    Name::from_labels(vec!["example", "com"]).unwrap(),
                    300,
                    RData::A(A::new(127, 0, 0, 1)),
                ),
                sig0_record.clone(),
                sig0_record,
            ],
            true,
        )
        .unwrap_err()
        .to_string();
        assert!(error.contains("TSIG or SIG(0) record must be final"));
    }
1725
1726
    fn encode_and_read_records(
1727
        records: Vec<Record>,
1728
        is_additional: bool,
1729
    ) -> ProtoResult<(Vec<Record>, Option<Edns>, MessageSignature)> {
1730
        let mut bytes = Vec::new();
1731
        let mut encoder = BinEncoder::new(&mut bytes);
1732
        encoder.emit_all(records.iter())?;
1733
        Message::read_records(&mut BinDecoder::new(&bytes), records.len(), is_additional)
1734
    }
1735
1736
    /// Builds a placeholder TSIG rdata for tests that only care about the
    /// record type.
    ///
    /// Uses the HMAC-SHA256 algorithm, zero for every numeric argument, empty
    /// byte vectors, and `None` for the optional argument.
    #[cfg(feature = "__dnssec")]
    fn fake_tsig() -> RData {
        let tsig = TSIG::new(
            TsigAlgorithm::HmacSha256,
            0,
            0,
            Vec::new(),
            0,
            None,
            Vec::new(),
        );
        RData::DNSSEC(DNSSECRData::TSIG(tsig))
    }
1748
1749
    /// Builds a placeholder SIG rdata for tests that only care about the
    /// record type.
    ///
    /// The signature covers type A with RSASHA256. Counters, TTL and key tag
    /// are zero, both validity timestamps are serial number 0, the signer is
    /// the root name, and the signature bytes are empty.
    #[cfg(feature = "__dnssec")]
    fn fake_sig0() -> RData {
        let input = SigInput {
            signer_name: Name::root(),
            key_tag: 0,
            sig_inception: SerialNumber(0),
            sig_expiration: SerialNumber(0),
            original_ttl: 0,
            num_labels: 0,
            algorithm: Algorithm::RSASHA256,
            type_covered: RecordType::A,
        };
        let sig = SIG {
            input,
            sig: Vec::new(),
        };
        RData::DNSSEC(DNSSECRData::SIG(sig))
    }
1765
}