Coverage Report

Created: 2025-05-07 06:59

/rust/registry/src/index.crates.io-6f17d22bba15001f/rustls-0.23.26/src/msgs/handshake.rs
Line
Count
Source (jump to first uncovered line)
1
use alloc::collections::BTreeSet;
2
#[cfg(feature = "logging")]
3
use alloc::string::String;
4
use alloc::vec;
5
use alloc::vec::Vec;
6
use core::ops::Deref;
7
use core::{fmt, iter};
8
9
use pki_types::{CertificateDer, DnsName};
10
11
#[cfg(feature = "tls12")]
12
use crate::crypto::ActiveKeyExchange;
13
use crate::crypto::SecureRandom;
14
use crate::enums::{
15
    CertificateCompressionAlgorithm, CipherSuite, EchClientHelloType, HandshakeType,
16
    ProtocolVersion, SignatureScheme,
17
};
18
use crate::error::InvalidMessage;
19
#[cfg(feature = "tls12")]
20
use crate::ffdhe_groups::FfdheGroup;
21
use crate::log::warn;
22
use crate::msgs::base::{Payload, PayloadU8, PayloadU16, PayloadU24};
23
use crate::msgs::codec::{self, Codec, LengthPrefixedBuffer, ListLength, Reader, TlsListElement};
24
use crate::msgs::enums::{
25
    CertificateStatusType, CertificateType, ClientCertificateType, Compression, ECCurveType,
26
    ECPointFormat, EchVersion, ExtensionType, HpkeAead, HpkeKdf, HpkeKem, KeyUpdateRequest,
27
    NamedGroup, PSKKeyExchangeMode, ServerNameType,
28
};
29
use crate::rand;
30
use crate::sync::Arc;
31
use crate::verify::DigitallySignedStruct;
32
use crate::x509::wrap_in_sequence;
33
34
/// Create a newtype wrapper around a given type.
35
///
36
/// This is used to create newtypes for the various TLS message types which is used to wrap
37
/// the `PayloadU8` or `PayloadU16` types. This is typically used for types where we don't need
38
/// anything other than access to the underlying bytes.
39
macro_rules! wrapped_payload(
40
  ($(#[$comment:meta])* $vis:vis struct $name:ident, $inner:ident,) => {
41
    $(#[$comment])*
42
    #[derive(Clone, Debug)]
43
    $vis struct $name($inner);
44
45
    impl From<Vec<u8>> for $name {
46
0
        fn from(v: Vec<u8>) -> Self {
47
0
            Self($inner::new(v))
48
0
        }
Unexecuted instantiation: <rustls::msgs::handshake::ProtocolName as core::convert::From<alloc::vec::Vec<u8>>>::from
Unexecuted instantiation: <rustls::msgs::handshake::PresharedKeyBinder as core::convert::From<alloc::vec::Vec<u8>>>::from
Unexecuted instantiation: <rustls::msgs::handshake::ResponderId as core::convert::From<alloc::vec::Vec<u8>>>::from
Unexecuted instantiation: <rustls::msgs::handshake::DistinguishedName as core::convert::From<alloc::vec::Vec<u8>>>::from
49
    }
50
51
    impl AsRef<[u8]> for $name {
52
0
        fn as_ref(&self) -> &[u8] {
53
0
            self.0.0.as_slice()
54
0
        }
Unexecuted instantiation: <rustls::msgs::handshake::ProtocolName as core::convert::AsRef<[u8]>>::as_ref
Unexecuted instantiation: <rustls::msgs::handshake::PresharedKeyBinder as core::convert::AsRef<[u8]>>::as_ref
Unexecuted instantiation: <rustls::msgs::handshake::ResponderId as core::convert::AsRef<[u8]>>::as_ref
Unexecuted instantiation: <rustls::msgs::handshake::DistinguishedName as core::convert::AsRef<[u8]>>::as_ref
55
    }
56
57
    impl Codec<'_> for $name {
58
0
        fn encode(&self, bytes: &mut Vec<u8>) {
59
0
            self.0.encode(bytes);
60
0
        }
Unexecuted instantiation: <rustls::msgs::handshake::ProtocolName as rustls::msgs::codec::Codec>::encode
Unexecuted instantiation: <rustls::msgs::handshake::PresharedKeyBinder as rustls::msgs::codec::Codec>::encode
Unexecuted instantiation: <rustls::msgs::handshake::ResponderId as rustls::msgs::codec::Codec>::encode
Unexecuted instantiation: <rustls::msgs::handshake::DistinguishedName as rustls::msgs::codec::Codec>::encode
61
62
0
        fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
63
0
            Ok(Self($inner::read(r)?))
64
0
        }
Unexecuted instantiation: <rustls::msgs::handshake::ProtocolName as rustls::msgs::codec::Codec>::read
Unexecuted instantiation: <rustls::msgs::handshake::PresharedKeyBinder as rustls::msgs::codec::Codec>::read
Unexecuted instantiation: <rustls::msgs::handshake::ResponderId as rustls::msgs::codec::Codec>::read
Unexecuted instantiation: <rustls::msgs::handshake::DistinguishedName as rustls::msgs::codec::Codec>::read
65
    }
66
  }
67
);
68
69
#[derive(Clone, Copy, Eq, PartialEq)]
70
pub struct Random(pub(crate) [u8; 32]);
71
72
impl fmt::Debug for Random {
73
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
74
0
        super::base::hex(f, &self.0)
75
0
    }
76
}
77
78
static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
79
    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
80
    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
81
]);
82
83
static ZERO_RANDOM: Random = Random([0u8; 32]);
84
85
impl Codec<'_> for Random {
86
0
    fn encode(&self, bytes: &mut Vec<u8>) {
87
0
        bytes.extend_from_slice(&self.0);
88
0
    }
89
90
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
91
0
        let Some(bytes) = r.take(32) else {
92
0
            return Err(InvalidMessage::MissingData("Random"));
93
        };
94
95
0
        let mut opaque = [0; 32];
96
0
        opaque.clone_from_slice(bytes);
97
0
        Ok(Self(opaque))
98
0
    }
99
}
100
101
impl Random {
102
0
    pub(crate) fn new(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
103
0
        let mut data = [0u8; 32];
104
0
        secure_random.fill(&mut data)?;
105
0
        Ok(Self(data))
106
0
    }
107
}
108
109
impl From<[u8; 32]> for Random {
110
    #[inline]
111
0
    fn from(bytes: [u8; 32]) -> Self {
112
0
        Self(bytes)
113
0
    }
114
}
115
116
#[derive(Copy, Clone)]
117
pub struct SessionId {
118
    len: usize,
119
    data: [u8; 32],
120
}
121
122
impl fmt::Debug for SessionId {
123
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124
0
        super::base::hex(f, &self.data[..self.len])
125
0
    }
126
}
127
128
impl PartialEq for SessionId {
129
0
    fn eq(&self, other: &Self) -> bool {
130
0
        if self.len != other.len {
131
0
            return false;
132
0
        }
133
0
134
0
        let mut diff = 0u8;
135
0
        for i in 0..self.len {
136
0
            diff |= self.data[i] ^ other.data[i];
137
0
        }
138
139
0
        diff == 0u8
140
0
    }
141
}
142
143
impl Codec<'_> for SessionId {
144
0
    fn encode(&self, bytes: &mut Vec<u8>) {
145
0
        debug_assert!(self.len <= 32);
146
0
        bytes.push(self.len as u8);
147
0
        bytes.extend_from_slice(self.as_ref());
148
0
    }
149
150
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
151
0
        let len = u8::read(r)? as usize;
152
0
        if len > 32 {
153
0
            return Err(InvalidMessage::TrailingData("SessionID"));
154
0
        }
155
156
0
        let Some(bytes) = r.take(len) else {
157
0
            return Err(InvalidMessage::MissingData("SessionID"));
158
        };
159
160
0
        let mut out = [0u8; 32];
161
0
        out[..len].clone_from_slice(&bytes[..len]);
162
0
        Ok(Self { data: out, len })
163
0
    }
164
}
165
166
impl SessionId {
167
0
    pub fn random(secure_random: &dyn SecureRandom) -> Result<Self, rand::GetRandomFailed> {
168
0
        let mut data = [0u8; 32];
169
0
        secure_random.fill(&mut data)?;
170
0
        Ok(Self { data, len: 32 })
171
0
    }
172
173
0
    pub(crate) fn empty() -> Self {
174
0
        Self {
175
0
            data: [0u8; 32],
176
0
            len: 0,
177
0
        }
178
0
    }
179
180
    #[cfg(feature = "tls12")]
181
0
    pub(crate) fn is_empty(&self) -> bool {
182
0
        self.len == 0
183
0
    }
184
}
185
186
impl AsRef<[u8]> for SessionId {
187
0
    fn as_ref(&self) -> &[u8] {
188
0
        &self.data[..self.len]
189
0
    }
190
}
191
192
#[derive(Clone, Debug, PartialEq)]
193
pub struct UnknownExtension {
194
    pub(crate) typ: ExtensionType,
195
    pub(crate) payload: Payload<'static>,
196
}
197
198
impl UnknownExtension {
199
0
    fn encode(&self, bytes: &mut Vec<u8>) {
200
0
        self.payload.encode(bytes);
201
0
    }
202
203
0
    fn read(typ: ExtensionType, r: &mut Reader<'_>) -> Self {
204
0
        let payload = Payload::read(r).into_owned();
205
0
        Self { typ, payload }
206
0
    }
207
}
208
209
impl TlsListElement for ECPointFormat {
210
    const SIZE_LEN: ListLength = ListLength::U8;
211
}
212
213
impl TlsListElement for NamedGroup {
214
    const SIZE_LEN: ListLength = ListLength::U16;
215
}
216
217
impl TlsListElement for SignatureScheme {
218
    const SIZE_LEN: ListLength = ListLength::U16;
219
}
220
221
#[derive(Clone, Debug)]
222
pub(crate) enum ServerNamePayload {
223
    HostName(DnsName<'static>),
224
    IpAddress(PayloadU16),
225
    Unknown(Payload<'static>),
226
}
227
228
impl ServerNamePayload {
229
0
    pub(crate) fn new_hostname(hostname: DnsName<'static>) -> Self {
230
0
        Self::HostName(hostname)
231
0
    }
232
233
0
    fn read_hostname(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
234
        use pki_types::ServerName;
235
0
        let raw = PayloadU16::read(r)?;
236
237
0
        match ServerName::try_from(raw.0.as_slice()) {
238
0
            Ok(ServerName::DnsName(d)) => Ok(Self::HostName(d.to_owned())),
239
0
            Ok(ServerName::IpAddress(_)) => Ok(Self::IpAddress(raw)),
240
            Ok(_) | Err(_) => {
241
0
                warn!(
242
0
                    "Illegal SNI hostname received {:?}",
243
0
                    String::from_utf8_lossy(&raw.0)
244
0
                );
245
0
                Err(InvalidMessage::InvalidServerName)
246
            }
247
        }
248
0
    }
249
250
0
    fn encode(&self, bytes: &mut Vec<u8>) {
251
0
        match self {
252
0
            Self::HostName(name) => {
253
0
                (name.as_ref().len() as u16).encode(bytes);
254
0
                bytes.extend_from_slice(name.as_ref().as_bytes());
255
0
            }
256
0
            Self::IpAddress(r) => r.encode(bytes),
257
0
            Self::Unknown(r) => r.encode(bytes),
258
        }
259
0
    }
260
}
261
262
#[derive(Clone, Debug)]
263
pub struct ServerName {
264
    pub(crate) typ: ServerNameType,
265
    pub(crate) payload: ServerNamePayload,
266
}
267
268
impl Codec<'_> for ServerName {
269
0
    fn encode(&self, bytes: &mut Vec<u8>) {
270
0
        self.typ.encode(bytes);
271
0
        self.payload.encode(bytes);
272
0
    }
273
274
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
275
0
        let typ = ServerNameType::read(r)?;
276
277
0
        let payload = match typ {
278
0
            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
279
0
            _ => ServerNamePayload::Unknown(Payload::read(r).into_owned()),
280
        };
281
282
0
        Ok(Self { typ, payload })
283
0
    }
284
}
285
286
impl TlsListElement for ServerName {
287
    const SIZE_LEN: ListLength = ListLength::U16;
288
}
289
290
pub(crate) trait ConvertServerNameList {
291
    fn has_duplicate_names_for_type(&self) -> bool;
292
    fn single_hostname(&self) -> Option<DnsName<'_>>;
293
}
294
295
impl ConvertServerNameList for [ServerName] {
296
    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
297
0
    fn has_duplicate_names_for_type(&self) -> bool {
298
0
        has_duplicates::<_, _, u8>(self.iter().map(|name| name.typ))
299
0
    }
300
301
0
    fn single_hostname(&self) -> Option<DnsName<'_>> {
302
0
        fn only_dns_hostnames(name: &ServerName) -> Option<DnsName<'_>> {
303
0
            if let ServerNamePayload::HostName(dns) = &name.payload {
304
0
                Some(dns.borrow())
305
            } else {
306
0
                None
307
            }
308
0
        }
309
310
0
        self.iter()
311
0
            .filter_map(only_dns_hostnames)
312
0
            .next()
313
0
    }
314
}
315
316
wrapped_payload!(pub struct ProtocolName, PayloadU8,);
317
318
impl TlsListElement for ProtocolName {
319
    const SIZE_LEN: ListLength = ListLength::U16;
320
}
321
322
pub(crate) trait ConvertProtocolNameList {
323
    fn from_slices(names: &[&[u8]]) -> Self;
324
    fn to_slices(&self) -> Vec<&[u8]>;
325
    fn as_single_slice(&self) -> Option<&[u8]>;
326
}
327
328
impl ConvertProtocolNameList for Vec<ProtocolName> {
329
0
    fn from_slices(names: &[&[u8]]) -> Self {
330
0
        let mut ret = Self::new();
331
332
0
        for name in names {
333
0
            ret.push(ProtocolName::from(name.to_vec()));
334
0
        }
335
336
0
        ret
337
0
    }
338
339
0
    fn to_slices(&self) -> Vec<&[u8]> {
340
0
        self.iter()
341
0
            .map(|proto| proto.as_ref())
342
0
            .collect::<Vec<&[u8]>>()
343
0
    }
344
345
0
    fn as_single_slice(&self) -> Option<&[u8]> {
346
0
        if self.len() == 1 {
347
0
            Some(self[0].as_ref())
348
        } else {
349
0
            None
350
        }
351
0
    }
352
}
353
354
// --- TLS 1.3 Key shares ---
355
#[derive(Clone, Debug)]
356
pub struct KeyShareEntry {
357
    pub(crate) group: NamedGroup,
358
    pub(crate) payload: PayloadU16,
359
}
360
361
impl KeyShareEntry {
362
0
    pub fn new(group: NamedGroup, payload: impl Into<Vec<u8>>) -> Self {
363
0
        Self {
364
0
            group,
365
0
            payload: PayloadU16::new(payload.into()),
366
0
        }
367
0
    }
Unexecuted instantiation: <rustls::msgs::handshake::KeyShareEntry>::new::<alloc::vec::Vec<u8>>
Unexecuted instantiation: <rustls::msgs::handshake::KeyShareEntry>::new::<&[u8]>
368
369
0
    pub fn group(&self) -> NamedGroup {
370
0
        self.group
371
0
    }
372
}
373
374
impl Codec<'_> for KeyShareEntry {
375
0
    fn encode(&self, bytes: &mut Vec<u8>) {
376
0
        self.group.encode(bytes);
377
0
        self.payload.encode(bytes);
378
0
    }
379
380
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
381
0
        let group = NamedGroup::read(r)?;
382
0
        let payload = PayloadU16::read(r)?;
383
384
0
        Ok(Self { group, payload })
385
0
    }
386
}
387
388
// --- TLS 1.3 PresharedKey offers ---
389
#[derive(Clone, Debug)]
390
pub(crate) struct PresharedKeyIdentity {
391
    pub(crate) identity: PayloadU16,
392
    pub(crate) obfuscated_ticket_age: u32,
393
}
394
395
impl PresharedKeyIdentity {
396
0
    pub(crate) fn new(id: Vec<u8>, age: u32) -> Self {
397
0
        Self {
398
0
            identity: PayloadU16::new(id),
399
0
            obfuscated_ticket_age: age,
400
0
        }
401
0
    }
402
}
403
404
impl Codec<'_> for PresharedKeyIdentity {
405
0
    fn encode(&self, bytes: &mut Vec<u8>) {
406
0
        self.identity.encode(bytes);
407
0
        self.obfuscated_ticket_age.encode(bytes);
408
0
    }
409
410
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
411
0
        Ok(Self {
412
0
            identity: PayloadU16::read(r)?,
413
0
            obfuscated_ticket_age: u32::read(r)?,
414
        })
415
0
    }
416
}
417
418
impl TlsListElement for PresharedKeyIdentity {
419
    const SIZE_LEN: ListLength = ListLength::U16;
420
}
421
422
wrapped_payload!(pub(crate) struct PresharedKeyBinder, PayloadU8,);
423
424
impl TlsListElement for PresharedKeyBinder {
425
    const SIZE_LEN: ListLength = ListLength::U16;
426
}
427
428
#[derive(Clone, Debug)]
429
pub struct PresharedKeyOffer {
430
    pub(crate) identities: Vec<PresharedKeyIdentity>,
431
    pub(crate) binders: Vec<PresharedKeyBinder>,
432
}
433
434
impl PresharedKeyOffer {
435
    /// Make a new one with one entry.
436
0
    pub(crate) fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
437
0
        Self {
438
0
            identities: vec![id],
439
0
            binders: vec![PresharedKeyBinder::from(binder)],
440
0
        }
441
0
    }
442
}
443
444
impl Codec<'_> for PresharedKeyOffer {
445
0
    fn encode(&self, bytes: &mut Vec<u8>) {
446
0
        self.identities.encode(bytes);
447
0
        self.binders.encode(bytes);
448
0
    }
449
450
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
451
0
        Ok(Self {
452
0
            identities: Vec::read(r)?,
453
0
            binders: Vec::read(r)?,
454
        })
455
0
    }
456
}
457
458
// --- RFC6066 certificate status request ---
459
wrapped_payload!(pub(crate) struct ResponderId, PayloadU16,);
460
461
impl TlsListElement for ResponderId {
462
    const SIZE_LEN: ListLength = ListLength::U16;
463
}
464
465
#[derive(Clone, Debug)]
466
pub struct OcspCertificateStatusRequest {
467
    pub(crate) responder_ids: Vec<ResponderId>,
468
    pub(crate) extensions: PayloadU16,
469
}
470
471
impl Codec<'_> for OcspCertificateStatusRequest {
472
0
    fn encode(&self, bytes: &mut Vec<u8>) {
473
0
        CertificateStatusType::OCSP.encode(bytes);
474
0
        self.responder_ids.encode(bytes);
475
0
        self.extensions.encode(bytes);
476
0
    }
477
478
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
479
0
        Ok(Self {
480
0
            responder_ids: Vec::read(r)?,
481
0
            extensions: PayloadU16::read(r)?,
482
        })
483
0
    }
484
}
485
486
#[derive(Clone, Debug)]
487
pub enum CertificateStatusRequest {
488
    Ocsp(OcspCertificateStatusRequest),
489
    Unknown((CertificateStatusType, Payload<'static>)),
490
}
491
492
impl Codec<'_> for CertificateStatusRequest {
493
0
    fn encode(&self, bytes: &mut Vec<u8>) {
494
0
        match self {
495
0
            Self::Ocsp(r) => r.encode(bytes),
496
0
            Self::Unknown((typ, payload)) => {
497
0
                typ.encode(bytes);
498
0
                payload.encode(bytes);
499
0
            }
500
        }
501
0
    }
502
503
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
504
0
        let typ = CertificateStatusType::read(r)?;
505
506
0
        match typ {
507
            CertificateStatusType::OCSP => {
508
0
                let ocsp_req = OcspCertificateStatusRequest::read(r)?;
509
0
                Ok(Self::Ocsp(ocsp_req))
510
            }
511
            _ => {
512
0
                let data = Payload::read(r).into_owned();
513
0
                Ok(Self::Unknown((typ, data)))
514
            }
515
        }
516
0
    }
517
}
518
519
impl CertificateStatusRequest {
520
0
    pub(crate) fn build_ocsp() -> Self {
521
0
        let ocsp = OcspCertificateStatusRequest {
522
0
            responder_ids: Vec::new(),
523
0
            extensions: PayloadU16::empty(),
524
0
        };
525
0
        Self::Ocsp(ocsp)
526
0
    }
527
}
528
529
// ---
530
531
impl TlsListElement for PSKKeyExchangeMode {
532
    const SIZE_LEN: ListLength = ListLength::U8;
533
}
534
535
impl TlsListElement for KeyShareEntry {
536
    const SIZE_LEN: ListLength = ListLength::U16;
537
}
538
539
impl TlsListElement for ProtocolVersion {
540
    const SIZE_LEN: ListLength = ListLength::U8;
541
}
542
543
impl TlsListElement for CertificateType {
544
    const SIZE_LEN: ListLength = ListLength::U8;
545
}
546
547
impl TlsListElement for CertificateCompressionAlgorithm {
548
    const SIZE_LEN: ListLength = ListLength::U8;
549
}
550
551
#[derive(Clone, Debug)]
552
pub enum ClientExtension {
553
    EcPointFormats(Vec<ECPointFormat>),
554
    NamedGroups(Vec<NamedGroup>),
555
    SignatureAlgorithms(Vec<SignatureScheme>),
556
    ServerName(Vec<ServerName>),
557
    SessionTicket(ClientSessionTicket),
558
    Protocols(Vec<ProtocolName>),
559
    SupportedVersions(Vec<ProtocolVersion>),
560
    KeyShare(Vec<KeyShareEntry>),
561
    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
562
    PresharedKey(PresharedKeyOffer),
563
    Cookie(PayloadU16),
564
    ExtendedMasterSecretRequest,
565
    CertificateStatusRequest(CertificateStatusRequest),
566
    ServerCertTypes(Vec<CertificateType>),
567
    ClientCertTypes(Vec<CertificateType>),
568
    TransportParameters(Vec<u8>),
569
    TransportParametersDraft(Vec<u8>),
570
    EarlyData,
571
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
572
    EncryptedClientHello(EncryptedClientHello),
573
    EncryptedClientHelloOuterExtensions(Vec<ExtensionType>),
574
    AuthorityNames(Vec<DistinguishedName>),
575
    Unknown(UnknownExtension),
576
}
577
578
impl ClientExtension {
579
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
580
0
        match self {
581
0
            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
582
0
            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
583
0
            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
584
0
            Self::ServerName(_) => ExtensionType::ServerName,
585
0
            Self::SessionTicket(_) => ExtensionType::SessionTicket,
586
0
            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
587
0
            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
588
0
            Self::KeyShare(_) => ExtensionType::KeyShare,
589
0
            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
590
0
            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
591
0
            Self::Cookie(_) => ExtensionType::Cookie,
592
0
            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
593
0
            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
594
0
            Self::ClientCertTypes(_) => ExtensionType::ClientCertificateType,
595
0
            Self::ServerCertTypes(_) => ExtensionType::ServerCertificateType,
596
0
            Self::TransportParameters(_) => ExtensionType::TransportParameters,
597
0
            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
598
0
            Self::EarlyData => ExtensionType::EarlyData,
599
0
            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
600
0
            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
601
            Self::EncryptedClientHelloOuterExtensions(_) => {
602
0
                ExtensionType::EncryptedClientHelloOuterExtensions
603
            }
604
0
            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
605
0
            Self::Unknown(r) => r.typ,
606
        }
607
0
    }
608
}
609
610
impl Codec<'_> for ClientExtension {
611
0
    fn encode(&self, bytes: &mut Vec<u8>) {
612
0
        self.ext_type().encode(bytes);
613
0
614
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
615
0
        match self {
616
0
            Self::EcPointFormats(r) => r.encode(nested.buf),
617
0
            Self::NamedGroups(r) => r.encode(nested.buf),
618
0
            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
619
0
            Self::ServerName(r) => r.encode(nested.buf),
620
            Self::SessionTicket(ClientSessionTicket::Request)
621
            | Self::ExtendedMasterSecretRequest
622
0
            | Self::EarlyData => {}
623
0
            Self::SessionTicket(ClientSessionTicket::Offer(r)) => r.encode(nested.buf),
624
0
            Self::Protocols(r) => r.encode(nested.buf),
625
0
            Self::SupportedVersions(r) => r.encode(nested.buf),
626
0
            Self::KeyShare(r) => r.encode(nested.buf),
627
0
            Self::PresharedKeyModes(r) => r.encode(nested.buf),
628
0
            Self::PresharedKey(r) => r.encode(nested.buf),
629
0
            Self::Cookie(r) => r.encode(nested.buf),
630
0
            Self::CertificateStatusRequest(r) => r.encode(nested.buf),
631
0
            Self::ClientCertTypes(r) => r.encode(nested.buf),
632
0
            Self::ServerCertTypes(r) => r.encode(nested.buf),
633
0
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
634
0
                nested.buf.extend_from_slice(r);
635
0
            }
636
0
            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
637
0
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
638
0
            Self::EncryptedClientHelloOuterExtensions(r) => r.encode(nested.buf),
639
0
            Self::AuthorityNames(r) => r.encode(nested.buf),
640
0
            Self::Unknown(r) => r.encode(nested.buf),
641
        }
642
0
    }
643
644
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
645
0
        let typ = ExtensionType::read(r)?;
646
0
        let len = u16::read(r)? as usize;
647
0
        let mut sub = r.sub(len)?;
648
649
0
        let ext = match typ {
650
0
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
651
0
            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
652
0
            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
653
0
            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
654
            ExtensionType::SessionTicket => {
655
0
                if sub.any_left() {
656
0
                    let contents = Payload::read(&mut sub).into_owned();
657
0
                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
658
                } else {
659
0
                    Self::SessionTicket(ClientSessionTicket::Request)
660
                }
661
            }
662
0
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
663
0
            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
664
0
            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
665
0
            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
666
0
            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
667
0
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
668
0
            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
669
0
                Self::ExtendedMasterSecretRequest
670
            }
671
0
            ExtensionType::ClientCertificateType => Self::ClientCertTypes(Vec::read(&mut sub)?),
672
0
            ExtensionType::ServerCertificateType => Self::ServerCertTypes(Vec::read(&mut sub)?),
673
            ExtensionType::StatusRequest => {
674
0
                let csr = CertificateStatusRequest::read(&mut sub)?;
675
0
                Self::CertificateStatusRequest(csr)
676
            }
677
0
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
678
            ExtensionType::TransportParametersDraft => {
679
0
                Self::TransportParametersDraft(sub.rest().to_vec())
680
            }
681
0
            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
682
            ExtensionType::CompressCertificate => {
683
0
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
684
            }
685
            ExtensionType::EncryptedClientHelloOuterExtensions => {
686
0
                Self::EncryptedClientHelloOuterExtensions(Vec::read(&mut sub)?)
687
            }
688
0
            ExtensionType::CertificateAuthorities => Self::AuthorityNames(Vec::read(&mut sub)?),
689
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
690
        };
691
692
0
        sub.expect_empty("ClientExtension")
693
0
            .map(|_| ext)
694
0
    }
695
}
696
697
0
fn trim_hostname_trailing_dot_for_sni(dns_name: &DnsName<'_>) -> DnsName<'static> {
698
0
    let dns_name_str = dns_name.as_ref();
699
0
700
0
    // RFC6066: "The hostname is represented as a byte string using
701
0
    // ASCII encoding without a trailing dot"
702
0
    if dns_name_str.ends_with('.') {
703
0
        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
704
0
        DnsName::try_from(trimmed)
705
0
            .unwrap()
706
0
            .to_owned()
707
    } else {
708
0
        dns_name.to_owned()
709
    }
710
0
}
711
712
impl ClientExtension {
713
    /// Make a basic SNI ServerNameRequest quoting `hostname`.
714
0
    pub(crate) fn make_sni(dns_name: &DnsName<'_>) -> Self {
715
0
        let name = ServerName {
716
0
            typ: ServerNameType::HostName,
717
0
            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
718
0
        };
719
0
720
0
        Self::ServerName(vec![name])
721
0
    }
722
}
723
724
#[derive(Clone, Debug)]
725
pub enum ClientSessionTicket {
726
    Request,
727
    Offer(Payload<'static>),
728
}
729
730
#[derive(Clone, Debug)]
731
pub enum ServerExtension {
732
    EcPointFormats(Vec<ECPointFormat>),
733
    ServerNameAck,
734
    SessionTicketAck,
735
    RenegotiationInfo(PayloadU8),
736
    Protocols(Vec<ProtocolName>),
737
    KeyShare(KeyShareEntry),
738
    PresharedKey(u16),
739
    ExtendedMasterSecretAck,
740
    CertificateStatusAck,
741
    ServerCertType(CertificateType),
742
    ClientCertType(CertificateType),
743
    SupportedVersions(ProtocolVersion),
744
    TransportParameters(Vec<u8>),
745
    TransportParametersDraft(Vec<u8>),
746
    EarlyData,
747
    EncryptedClientHello(ServerEncryptedClientHello),
748
    Unknown(UnknownExtension),
749
}
750
751
impl ServerExtension {
752
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
753
0
        match self {
754
0
            Self::EcPointFormats(_) => ExtensionType::ECPointFormats,
755
0
            Self::ServerNameAck => ExtensionType::ServerName,
756
0
            Self::SessionTicketAck => ExtensionType::SessionTicket,
757
0
            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
758
0
            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
759
0
            Self::KeyShare(_) => ExtensionType::KeyShare,
760
0
            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
761
0
            Self::ClientCertType(_) => ExtensionType::ClientCertificateType,
762
0
            Self::ServerCertType(_) => ExtensionType::ServerCertificateType,
763
0
            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
764
0
            Self::CertificateStatusAck => ExtensionType::StatusRequest,
765
0
            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
766
0
            Self::TransportParameters(_) => ExtensionType::TransportParameters,
767
0
            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
768
0
            Self::EarlyData => ExtensionType::EarlyData,
769
0
            Self::EncryptedClientHello(_) => ExtensionType::EncryptedClientHello,
770
0
            Self::Unknown(r) => r.typ,
771
        }
772
0
    }
773
}
774
775
impl Codec<'_> for ServerExtension {
776
0
    fn encode(&self, bytes: &mut Vec<u8>) {
777
0
        self.ext_type().encode(bytes);
778
0
779
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
780
0
        match self {
781
0
            Self::EcPointFormats(r) => r.encode(nested.buf),
782
            Self::ServerNameAck
783
            | Self::SessionTicketAck
784
            | Self::ExtendedMasterSecretAck
785
            | Self::CertificateStatusAck
786
0
            | Self::EarlyData => {}
787
0
            Self::RenegotiationInfo(r) => r.encode(nested.buf),
788
0
            Self::Protocols(r) => r.encode(nested.buf),
789
0
            Self::KeyShare(r) => r.encode(nested.buf),
790
0
            Self::PresharedKey(r) => r.encode(nested.buf),
791
0
            Self::ClientCertType(r) => r.encode(nested.buf),
792
0
            Self::ServerCertType(r) => r.encode(nested.buf),
793
0
            Self::SupportedVersions(r) => r.encode(nested.buf),
794
0
            Self::TransportParameters(r) | Self::TransportParametersDraft(r) => {
795
0
                nested.buf.extend_from_slice(r);
796
0
            }
797
0
            Self::EncryptedClientHello(r) => r.encode(nested.buf),
798
0
            Self::Unknown(r) => r.encode(nested.buf),
799
        }
800
0
    }
801
802
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
803
0
        let typ = ExtensionType::read(r)?;
804
0
        let len = u16::read(r)? as usize;
805
0
        let mut sub = r.sub(len)?;
806
807
0
        let ext = match typ {
808
0
            ExtensionType::ECPointFormats => Self::EcPointFormats(Vec::read(&mut sub)?),
809
0
            ExtensionType::ServerName => Self::ServerNameAck,
810
0
            ExtensionType::SessionTicket => Self::SessionTicketAck,
811
0
            ExtensionType::StatusRequest => Self::CertificateStatusAck,
812
0
            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
813
0
            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
814
            ExtensionType::ClientCertificateType => {
815
0
                Self::ClientCertType(CertificateType::read(&mut sub)?)
816
            }
817
            ExtensionType::ServerCertificateType => {
818
0
                Self::ServerCertType(CertificateType::read(&mut sub)?)
819
            }
820
0
            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
821
0
            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
822
0
            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
823
            ExtensionType::SupportedVersions => {
824
0
                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
825
            }
826
0
            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
827
            ExtensionType::TransportParametersDraft => {
828
0
                Self::TransportParametersDraft(sub.rest().to_vec())
829
            }
830
0
            ExtensionType::EarlyData => Self::EarlyData,
831
            ExtensionType::EncryptedClientHello => {
832
0
                Self::EncryptedClientHello(ServerEncryptedClientHello::read(&mut sub)?)
833
            }
834
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
835
        };
836
837
0
        sub.expect_empty("ServerExtension")
838
0
            .map(|_| ext)
839
0
    }
840
}
841
842
impl ServerExtension {
843
0
    pub(crate) fn make_alpn(proto: &[&[u8]]) -> Self {
844
0
        Self::Protocols(Vec::from_slices(proto))
845
0
    }
846
847
    #[cfg(feature = "tls12")]
848
0
    pub(crate) fn make_empty_renegotiation_info() -> Self {
849
0
        let empty = Vec::new();
850
0
        Self::RenegotiationInfo(PayloadU8::new(empty))
851
0
    }
852
}
853
854
/// The body of a ClientHello handshake message.
///
/// The `Codec` implementation reads and writes the fields in the order they
/// are declared here.
#[derive(Clone, Debug)]
855
pub struct ClientHelloPayload {
856
    /// Protocol version field of the hello.
    pub client_version: ProtocolVersion,
857
    /// Client random value.
    pub random: Random,
858
    /// Session ID; encoded as an empty ID for `Encoding::EchInnerHello`.
    pub session_id: SessionId,
859
    /// Cipher suites listed by the client.
    pub cipher_suites: Vec<CipherSuite>,
860
    /// Compression methods listed by the client.
    pub compression_methods: Vec<Compression>,
861
    /// Extensions carried by the hello, in wire order.
    pub extensions: Vec<ClientExtension>,
862
}
863
864
impl Codec<'_> for ClientHelloPayload {
865
0
    fn encode(&self, bytes: &mut Vec<u8>) {
866
0
        self.payload_encode(bytes, Encoding::Standard)
867
0
    }
868
869
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
870
0
        let mut ret = Self {
871
0
            client_version: ProtocolVersion::read(r)?,
872
0
            random: Random::read(r)?,
873
0
            session_id: SessionId::read(r)?,
874
0
            cipher_suites: Vec::read(r)?,
875
0
            compression_methods: Vec::read(r)?,
876
0
            extensions: Vec::new(),
877
0
        };
878
0
879
0
        if r.any_left() {
880
0
            ret.extensions = Vec::read(r)?;
881
0
        }
882
883
0
        match (r.any_left(), ret.extensions.is_empty()) {
884
0
            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
885
0
            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
886
0
            _ => Ok(ret),
887
        }
888
0
    }
889
}
890
891
/// Lists of `CipherSuite` use `ListLength::U16` as their list-length kind
/// (presumably a two-byte length prefix; confirm against `ListLength`).
impl TlsListElement for CipherSuite {
892
    const SIZE_LEN: ListLength = ListLength::U16;
893
}
894
895
/// Lists of `Compression` use `ListLength::U8` as their list-length kind
/// (presumably a one-byte length prefix; confirm against `ListLength`).
impl TlsListElement for Compression {
896
    const SIZE_LEN: ListLength = ListLength::U8;
897
}
898
899
/// Lists of `ClientExtension` use `ListLength::U16` as their list-length kind
/// (presumably a two-byte length prefix; confirm against `ListLength`).
impl TlsListElement for ClientExtension {
900
    const SIZE_LEN: ListLength = ListLength::U16;
901
}
902
903
/// Lists of `ExtensionType` use `ListLength::U8` as their list-length kind
/// (presumably a one-byte length prefix; confirm against `ListLength`).
impl TlsListElement for ExtensionType {
904
    const SIZE_LEN: ListLength = ListLength::U8;
905
}
906
907
impl ClientHelloPayload {
908
0
    pub(crate) fn ech_inner_encoding(&self, to_compress: Vec<ExtensionType>) -> Vec<u8> {
909
0
        let mut bytes = Vec::new();
910
0
        self.payload_encode(&mut bytes, Encoding::EchInnerHello { to_compress });
911
0
        bytes
912
0
    }
913
914
0
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
915
0
        self.client_version.encode(bytes);
916
0
        self.random.encode(bytes);
917
0
918
0
        match purpose {
919
            // SessionID is required to be empty in the encoded inner client hello.
920
0
            Encoding::EchInnerHello { .. } => SessionId::empty().encode(bytes),
921
0
            _ => self.session_id.encode(bytes),
922
        }
923
924
0
        self.cipher_suites.encode(bytes);
925
0
        self.compression_methods.encode(bytes);
926
927
0
        let to_compress = match purpose {
928
            // Compressed extensions must be replaced in the encoded inner client hello.
929
0
            Encoding::EchInnerHello { to_compress } if !to_compress.is_empty() => to_compress,
930
            _ => {
931
0
                if !self.extensions.is_empty() {
932
0
                    self.extensions.encode(bytes);
933
0
                }
934
0
                return;
935
            }
936
        };
937
938
        // Safety: not empty check in match guard.
939
0
        let first_compressed_type = *to_compress.first().unwrap();
940
0
941
0
        // Compressed extensions are in a contiguous range and must be replaced
942
0
        // with a marker extension.
943
0
        let compressed_start_idx = self
944
0
            .extensions
945
0
            .iter()
946
0
            .position(|ext| ext.ext_type() == first_compressed_type);
947
0
        let compressed_end_idx = compressed_start_idx.map(|start| start + to_compress.len());
948
0
        let marker_ext = ClientExtension::EncryptedClientHelloOuterExtensions(to_compress);
949
0
950
0
        let exts = self
951
0
            .extensions
952
0
            .iter()
953
0
            .enumerate()
954
0
            .filter_map(|(i, ext)| {
955
0
                if Some(i) == compressed_start_idx {
956
0
                    Some(&marker_ext)
957
0
                } else if Some(i) > compressed_start_idx && Some(i) < compressed_end_idx {
958
0
                    None
959
                } else {
960
0
                    Some(ext)
961
                }
962
0
            });
963
0
964
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
965
0
        for ext in exts {
966
0
            ext.encode(nested.buf);
967
0
        }
968
0
    }
969
970
    /// Returns true if there is more than one extension of a given
971
    /// type.
972
0
    pub(crate) fn has_duplicate_extension(&self) -> bool {
973
0
        has_duplicates::<_, _, u16>(
974
0
            self.extensions
975
0
                .iter()
976
0
                .map(|ext| ext.ext_type()),
977
0
        )
978
0
    }
979
980
0
    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
981
0
        self.extensions
982
0
            .iter()
983
0
            .find(|x| x.ext_type() == ext)
984
0
    }
985
986
0
    pub(crate) fn sni_extension(&self) -> Option<&[ServerName]> {
987
0
        let ext = self.find_extension(ExtensionType::ServerName)?;
988
0
        match ext {
989
0
            // Does this comply with RFC6066?
990
0
            //
991
0
            // [RFC6066][] specifies that literal IP addresses are illegal in
992
0
            // `ServerName`s with a `name_type` of `host_name`.
993
0
            //
994
0
            // Some clients incorrectly send such extensions: we choose to
995
0
            // successfully parse these (into `ServerNamePayload::IpAddress`)
996
0
            // but then act like the client sent no `server_name` extension.
997
0
            //
998
0
            // [RFC6066]: https://datatracker.ietf.org/doc/html/rfc6066#section-3
999
0
            ClientExtension::ServerName(req)
1000
0
                if !req
1001
0
                    .iter()
1002
0
                    .any(|name| matches!(name.payload, ServerNamePayload::IpAddress(_))) =>
1003
            {
1004
0
                Some(req)
1005
            }
1006
0
            _ => None,
1007
        }
1008
0
    }
1009
1010
0
    pub fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1011
0
        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1012
0
        match ext {
1013
0
            ClientExtension::SignatureAlgorithms(req) => Some(req),
1014
0
            _ => None,
1015
        }
1016
0
    }
1017
1018
0
    pub(crate) fn namedgroups_extension(&self) -> Option<&[NamedGroup]> {
1019
0
        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
1020
0
        match ext {
1021
0
            ClientExtension::NamedGroups(req) => Some(req),
1022
0
            _ => None,
1023
        }
1024
0
    }
1025
1026
    #[cfg(feature = "tls12")]
1027
0
    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1028
0
        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1029
0
        match ext {
1030
0
            ClientExtension::EcPointFormats(req) => Some(req),
1031
0
            _ => None,
1032
        }
1033
0
    }
1034
1035
0
    pub(crate) fn server_certificate_extension(&self) -> Option<&[CertificateType]> {
1036
0
        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
1037
0
        match ext {
1038
0
            ClientExtension::ServerCertTypes(req) => Some(req),
1039
0
            _ => None,
1040
        }
1041
0
    }
1042
1043
0
    pub(crate) fn client_certificate_extension(&self) -> Option<&[CertificateType]> {
1044
0
        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
1045
0
        match ext {
1046
0
            ClientExtension::ClientCertTypes(req) => Some(req),
1047
0
            _ => None,
1048
        }
1049
0
    }
1050
1051
0
    pub(crate) fn alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
1052
0
        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1053
0
        match ext {
1054
0
            ClientExtension::Protocols(req) => Some(req),
1055
0
            _ => None,
1056
        }
1057
0
    }
1058
1059
0
    pub(crate) fn quic_params_extension(&self) -> Option<Vec<u8>> {
1060
0
        let ext = self
1061
0
            .find_extension(ExtensionType::TransportParameters)
1062
0
            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1063
0
        match ext {
1064
0
            ClientExtension::TransportParameters(bytes)
1065
0
            | ClientExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
1066
0
            _ => None,
1067
        }
1068
0
    }
1069
1070
    #[cfg(feature = "tls12")]
1071
0
    pub(crate) fn ticket_extension(&self) -> Option<&ClientExtension> {
1072
0
        self.find_extension(ExtensionType::SessionTicket)
1073
0
    }
1074
1075
0
    pub(crate) fn versions_extension(&self) -> Option<&[ProtocolVersion]> {
1076
0
        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1077
0
        match ext {
1078
0
            ClientExtension::SupportedVersions(vers) => Some(vers),
1079
0
            _ => None,
1080
        }
1081
0
    }
1082
1083
0
    pub fn keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
1084
0
        let ext = self.find_extension(ExtensionType::KeyShare)?;
1085
0
        match ext {
1086
0
            ClientExtension::KeyShare(shares) => Some(shares),
1087
0
            _ => None,
1088
        }
1089
0
    }
1090
1091
0
    pub(crate) fn has_keyshare_extension_with_duplicates(&self) -> bool {
1092
0
        self.keyshare_extension()
1093
0
            .map(|entries| {
1094
0
                has_duplicates::<_, _, u16>(
1095
0
                    entries
1096
0
                        .iter()
1097
0
                        .map(|kse| u16::from(kse.group)),
1098
0
                )
1099
0
            })
1100
0
            .unwrap_or_default()
1101
0
    }
1102
1103
0
    pub(crate) fn psk(&self) -> Option<&PresharedKeyOffer> {
1104
0
        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1105
0
        match ext {
1106
0
            ClientExtension::PresharedKey(psk) => Some(psk),
1107
0
            _ => None,
1108
        }
1109
0
    }
1110
1111
0
    pub(crate) fn check_psk_ext_is_last(&self) -> bool {
1112
0
        self.extensions
1113
0
            .last()
1114
0
            .is_some_and(|ext| ext.ext_type() == ExtensionType::PreSharedKey)
1115
0
    }
1116
1117
0
    pub(crate) fn psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
1118
0
        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1119
0
        match ext {
1120
0
            ClientExtension::PresharedKeyModes(psk_modes) => Some(psk_modes),
1121
0
            _ => None,
1122
        }
1123
0
    }
1124
1125
0
    pub(crate) fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
1126
0
        self.psk_modes()
1127
0
            .map(|modes| modes.contains(&mode))
1128
0
            .unwrap_or(false)
1129
0
    }
1130
1131
0
    pub(crate) fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1132
0
        let last_extension = self.extensions.last_mut();
1133
0
        if let Some(ClientExtension::PresharedKey(offer)) = last_extension {
1134
0
            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1135
0
        }
1136
0
    }
1137
1138
    #[cfg(feature = "tls12")]
1139
0
    pub(crate) fn ems_support_offered(&self) -> bool {
1140
0
        self.find_extension(ExtensionType::ExtendedMasterSecret)
1141
0
            .is_some()
1142
0
    }
1143
1144
0
    pub(crate) fn early_data_extension_offered(&self) -> bool {
1145
0
        self.find_extension(ExtensionType::EarlyData)
1146
0
            .is_some()
1147
0
    }
1148
1149
0
    pub(crate) fn certificate_compression_extension(
1150
0
        &self,
1151
0
    ) -> Option<&[CertificateCompressionAlgorithm]> {
1152
0
        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
1153
0
        match ext {
1154
0
            ClientExtension::CertificateCompressionAlgorithms(algs) => Some(algs),
1155
0
            _ => None,
1156
        }
1157
0
    }
1158
1159
0
    pub(crate) fn has_certificate_compression_extension_with_duplicates(&self) -> bool {
1160
0
        if let Some(algs) = self.certificate_compression_extension() {
1161
0
            has_duplicates::<_, _, u16>(algs.iter().cloned())
1162
        } else {
1163
0
            false
1164
        }
1165
0
    }
1166
1167
0
    pub(crate) fn certificate_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1168
0
        match self.find_extension(ExtensionType::CertificateAuthorities)? {
1169
0
            ClientExtension::AuthorityNames(ext) => Some(ext),
1170
0
            _ => unreachable!("extension type checked"),
1171
        }
1172
0
    }
1173
}
1174
1175
/// An extension carried in a `HelloRetryRequest`.
#[derive(Clone, Debug)]
1176
pub(crate) enum HelloRetryExtension {
1177
    /// `KeyShare` extension: a single named group.
    KeyShare(NamedGroup),
1178
    /// `Cookie` extension payload.
    Cookie(PayloadU16),
1179
    /// `SupportedVersions` extension: a single protocol version.
    SupportedVersions(ProtocolVersion),
1180
    /// `EncryptedClientHello` extension, kept as the raw body bytes.
    EchHelloRetryRequest(Vec<u8>),
1181
    /// Any extension type not listed above.
    Unknown(UnknownExtension),
1182
}
1183
1184
impl HelloRetryExtension {
1185
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
1186
0
        match self {
1187
0
            Self::KeyShare(_) => ExtensionType::KeyShare,
1188
0
            Self::Cookie(_) => ExtensionType::Cookie,
1189
0
            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1190
0
            Self::EchHelloRetryRequest(_) => ExtensionType::EncryptedClientHello,
1191
0
            Self::Unknown(r) => r.typ,
1192
        }
1193
0
    }
1194
}
1195
1196
impl Codec<'_> for HelloRetryExtension {
1197
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1198
0
        self.ext_type().encode(bytes);
1199
0
1200
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1201
0
        match self {
1202
0
            Self::KeyShare(r) => r.encode(nested.buf),
1203
0
            Self::Cookie(r) => r.encode(nested.buf),
1204
0
            Self::SupportedVersions(r) => r.encode(nested.buf),
1205
0
            Self::EchHelloRetryRequest(r) => {
1206
0
                nested.buf.extend_from_slice(r);
1207
0
            }
1208
0
            Self::Unknown(r) => r.encode(nested.buf),
1209
        }
1210
0
    }
1211
1212
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1213
0
        let typ = ExtensionType::read(r)?;
1214
0
        let len = u16::read(r)? as usize;
1215
0
        let mut sub = r.sub(len)?;
1216
1217
0
        let ext = match typ {
1218
0
            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1219
0
            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1220
            ExtensionType::SupportedVersions => {
1221
0
                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1222
            }
1223
0
            ExtensionType::EncryptedClientHello => Self::EchHelloRetryRequest(sub.rest().to_vec()),
1224
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1225
        };
1226
1227
0
        sub.expect_empty("HelloRetryExtension")
1228
0
            .map(|_| ext)
1229
0
    }
1230
}
1231
1232
/// Lists of `HelloRetryExtension` use `ListLength::U16` as their list-length
/// kind (presumably a two-byte length prefix; confirm against `ListLength`).
impl TlsListElement for HelloRetryExtension {
1233
    const SIZE_LEN: ListLength = ListLength::U16;
1234
}
1235
1236
/// The body of a HelloRetryRequest message.
///
/// `Codec::read` does not read the version: it sets `legacy_version` to
/// `ProtocolVersion::Unknown(0)` and starts at the session ID.
#[derive(Clone, Debug)]
1237
pub struct HelloRetryRequest {
1238
    /// Version field written first when encoding.
    pub(crate) legacy_version: ProtocolVersion,
1239
    /// Session ID carried by the message.
    pub session_id: SessionId,
1240
    /// Cipher suite carried by the message.
    pub(crate) cipher_suite: CipherSuite,
1241
    /// Extensions carried by the message, in wire order.
    pub(crate) extensions: Vec<HelloRetryExtension>,
1242
}
1243
1244
impl Codec<'_> for HelloRetryRequest {
1245
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1246
0
        self.payload_encode(bytes, Encoding::Standard)
1247
0
    }
1248
1249
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1250
0
        let session_id = SessionId::read(r)?;
1251
0
        let cipher_suite = CipherSuite::read(r)?;
1252
0
        let compression = Compression::read(r)?;
1253
1254
0
        if compression != Compression::Null {
1255
0
            return Err(InvalidMessage::UnsupportedCompression);
1256
0
        }
1257
0
1258
0
        Ok(Self {
1259
0
            legacy_version: ProtocolVersion::Unknown(0),
1260
0
            session_id,
1261
0
            cipher_suite,
1262
0
            extensions: Vec::read(r)?,
1263
        })
1264
0
    }
1265
}
1266
1267
impl HelloRetryRequest {
1268
    /// Returns true if there is more than one extension of a given
1269
    /// type.
1270
0
    pub(crate) fn has_duplicate_extension(&self) -> bool {
1271
0
        has_duplicates::<_, _, u16>(
1272
0
            self.extensions
1273
0
                .iter()
1274
0
                .map(|ext| ext.ext_type()),
1275
0
        )
1276
0
    }
1277
1278
0
    pub(crate) fn has_unknown_extension(&self) -> bool {
1279
0
        self.extensions.iter().any(|ext| {
1280
0
            ext.ext_type() != ExtensionType::KeyShare
1281
0
                && ext.ext_type() != ExtensionType::SupportedVersions
1282
0
                && ext.ext_type() != ExtensionType::Cookie
1283
0
                && ext.ext_type() != ExtensionType::EncryptedClientHello
1284
0
        })
1285
0
    }
1286
1287
0
    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1288
0
        self.extensions
1289
0
            .iter()
1290
0
            .find(|x| x.ext_type() == ext)
1291
0
    }
1292
1293
0
    pub fn requested_key_share_group(&self) -> Option<NamedGroup> {
1294
0
        let ext = self.find_extension(ExtensionType::KeyShare)?;
1295
0
        match ext {
1296
0
            HelloRetryExtension::KeyShare(grp) => Some(*grp),
1297
0
            _ => None,
1298
        }
1299
0
    }
1300
1301
0
    pub(crate) fn cookie(&self) -> Option<&PayloadU16> {
1302
0
        let ext = self.find_extension(ExtensionType::Cookie)?;
1303
0
        match ext {
1304
0
            HelloRetryExtension::Cookie(ck) => Some(ck),
1305
0
            _ => None,
1306
        }
1307
0
    }
1308
1309
0
    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1310
0
        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1311
0
        match ext {
1312
0
            HelloRetryExtension::SupportedVersions(ver) => Some(*ver),
1313
0
            _ => None,
1314
        }
1315
0
    }
1316
1317
0
    pub(crate) fn ech(&self) -> Option<&Vec<u8>> {
1318
0
        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
1319
0
        match ext {
1320
0
            HelloRetryExtension::EchHelloRetryRequest(ech) => Some(ech),
1321
0
            _ => None,
1322
        }
1323
0
    }
1324
1325
0
    fn payload_encode(&self, bytes: &mut Vec<u8>, purpose: Encoding) {
1326
0
        self.legacy_version.encode(bytes);
1327
0
        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1328
0
        self.session_id.encode(bytes);
1329
0
        self.cipher_suite.encode(bytes);
1330
0
        Compression::Null.encode(bytes);
1331
0
1332
0
        match purpose {
1333
            // For the purpose of ECH confirmation, the Encrypted Client Hello extension
1334
            // must have its payload replaced by 8 zero bytes.
1335
            //
1336
            // See draft-ietf-tls-esni-18 7.2.1:
1337
            // <https://datatracker.ietf.org/doc/html/draft-ietf-tls-esni-18#name-sending-helloretryrequest-2>
1338
            Encoding::EchConfirmation => {
1339
0
                let extensions = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1340
0
                for ext in &self.extensions {
1341
0
                    match ext.ext_type() {
1342
0
                        ExtensionType::EncryptedClientHello => {
1343
0
                            HelloRetryExtension::EchHelloRetryRequest(vec![0u8; 8])
1344
0
                                .encode(extensions.buf);
1345
0
                        }
1346
0
                        _ => {
1347
0
                            ext.encode(extensions.buf);
1348
0
                        }
1349
                    }
1350
                }
1351
            }
1352
0
            _ => {
1353
0
                self.extensions.encode(bytes);
1354
0
            }
1355
        }
1356
0
    }
1357
}
1358
1359
/// The body of a ServerHello handshake message.
///
/// `Codec::read` does not read the version or random: it sets
/// `legacy_version` to `ProtocolVersion::Unknown(0)` and `random` to
/// `ZERO_RANDOM`, and starts at the session ID.
#[derive(Clone, Debug)]
1360
pub struct ServerHelloPayload {
1361
    /// Extensions carried by the hello; empty when none were present.
    pub extensions: Vec<ServerExtension>,
1362
    /// Version field written first when encoding.
    pub(crate) legacy_version: ProtocolVersion,
1363
    /// Server random value.
    pub(crate) random: Random,
1364
    /// Session ID carried by the hello.
    pub(crate) session_id: SessionId,
1365
    /// Cipher suite carried by the hello.
    pub(crate) cipher_suite: CipherSuite,
1366
    /// Compression method carried by the hello.
    pub(crate) compression_method: Compression,
1367
}
1368
1369
impl Codec<'_> for ServerHelloPayload {
1370
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1371
0
        self.payload_encode(bytes, Encoding::Standard)
1372
0
    }
1373
1374
    // minus version and random, which have already been read.
1375
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1376
0
        let session_id = SessionId::read(r)?;
1377
0
        let suite = CipherSuite::read(r)?;
1378
0
        let compression = Compression::read(r)?;
1379
1380
        // RFC5246:
1381
        // "The presence of extensions can be detected by determining whether
1382
        //  there are bytes following the compression_method field at the end of
1383
        //  the ServerHello."
1384
0
        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1385
1386
0
        let ret = Self {
1387
0
            legacy_version: ProtocolVersion::Unknown(0),
1388
0
            random: ZERO_RANDOM,
1389
0
            session_id,
1390
0
            cipher_suite: suite,
1391
0
            compression_method: compression,
1392
0
            extensions,
1393
0
        };
1394
0
1395
0
        r.expect_empty("ServerHelloPayload")
1396
0
            .map(|_| ret)
1397
0
    }
1398
}
1399
1400
impl HasServerExtensions for ServerHelloPayload {
1401
0
    fn extensions(&self) -> &[ServerExtension] {
1402
0
        &self.extensions
1403
0
    }
1404
}
1405
1406
impl ServerHelloPayload {
1407
0
    pub(crate) fn key_share(&self) -> Option<&KeyShareEntry> {
1408
0
        let ext = self.find_extension(ExtensionType::KeyShare)?;
1409
0
        match ext {
1410
0
            ServerExtension::KeyShare(share) => Some(share),
1411
0
            _ => None,
1412
        }
1413
0
    }
1414
1415
0
    pub(crate) fn psk_index(&self) -> Option<u16> {
1416
0
        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1417
0
        match ext {
1418
0
            ServerExtension::PresharedKey(index) => Some(*index),
1419
0
            _ => None,
1420
        }
1421
0
    }
1422
1423
0
    pub(crate) fn ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1424
0
        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1425
0
        match ext {
1426
0
            ServerExtension::EcPointFormats(fmts) => Some(fmts),
1427
0
            _ => None,
1428
        }
1429
0
    }
1430
1431
    #[cfg(feature = "tls12")]
1432
0
    pub(crate) fn ems_support_acked(&self) -> bool {
1433
0
        self.find_extension(ExtensionType::ExtendedMasterSecret)
1434
0
            .is_some()
1435
0
    }
1436
1437
0
    pub(crate) fn supported_versions(&self) -> Option<ProtocolVersion> {
1438
0
        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1439
0
        match ext {
1440
0
            ServerExtension::SupportedVersions(vers) => Some(*vers),
1441
0
            _ => None,
1442
        }
1443
0
    }
1444
1445
0
    fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
1446
0
        self.legacy_version.encode(bytes);
1447
0
1448
0
        match encoding {
1449
            // When encoding a ServerHello for ECH confirmation, the random value
1450
            // has the last 8 bytes zeroed out.
1451
0
            Encoding::EchConfirmation => {
1452
0
                // Indexing safety: self.random is 32 bytes long by definition.
1453
0
                let rand_vec = self.random.get_encoding();
1454
0
                bytes.extend_from_slice(&rand_vec.as_slice()[..24]);
1455
0
                bytes.extend_from_slice(&[0u8; 8]);
1456
0
            }
1457
0
            _ => self.random.encode(bytes),
1458
        }
1459
1460
0
        self.session_id.encode(bytes);
1461
0
        self.cipher_suite.encode(bytes);
1462
0
        self.compression_method.encode(bytes);
1463
0
1464
0
        if !self.extensions.is_empty() {
1465
0
            self.extensions.encode(bytes);
1466
0
        }
1467
0
    }
1468
}
1469
1470
/// A list of certificates, wrapping a `Vec<CertificateDer>`.
///
/// Dereferences to a slice of the contained certificates.
#[derive(Clone, Default, Debug)]
1471
pub struct CertificateChain<'a>(pub Vec<CertificateDer<'a>>);
1472
1473
impl CertificateChain<'_> {
1474
0
    pub(crate) fn into_owned(self) -> CertificateChain<'static> {
1475
0
        CertificateChain(
1476
0
            self.0
1477
0
                .into_iter()
1478
0
                .map(|c| c.into_owned())
1479
0
                .collect(),
1480
0
        )
1481
0
    }
1482
}
1483
1484
impl<'a> Codec<'a> for CertificateChain<'a> {
1485
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1486
0
        Vec::encode(&self.0, bytes)
1487
0
    }
1488
1489
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1490
0
        Vec::read(r).map(Self)
1491
0
    }
1492
}
1493
1494
impl<'a> Deref for CertificateChain<'a> {
1495
    type Target = [CertificateDer<'a>];
1496
1497
0
    fn deref(&self) -> &[CertificateDer<'a>] {
1498
0
        &self.0
1499
0
    }
1500
}
1501
1502
/// Lists of `CertificateDer` use `ListLength::U24`, capped at
/// `CERTIFICATE_MAX_SIZE_LIMIT` and paired with the
/// `CertificatePayloadTooLarge` error (presumably reported when the cap is
/// exceeded; confirm against `ListLength`).
impl TlsListElement for CertificateDer<'_> {
1503
    const SIZE_LEN: ListLength = ListLength::U24 {
1504
        max: CERTIFICATE_MAX_SIZE_LIMIT,
1505
        error: InvalidMessage::CertificatePayloadTooLarge,
1506
    };
1507
}
1508
1509
/// TLS has a 16MB size limit on any handshake message,
1510
/// plus a 16MB limit on any given certificate.
1511
///
1512
/// We contract that to 64KB (`0x1_0000` bytes) to limit the amount of memory allocation
1513
/// that is directly controllable by the peer.
1514
pub(crate) const CERTIFICATE_MAX_SIZE_LIMIT: usize = 0x1_0000;
1515
1516
/// An extension attached to a single certificate entry.
#[derive(Debug)]
1517
pub(crate) enum CertificateExtension<'a> {
1518
    /// `StatusRequest` extension carrying a certificate status.
    CertificateStatus(CertificateStatus<'a>),
1519
    /// Any other extension type.
    Unknown(UnknownExtension),
1520
}
1521
1522
impl CertificateExtension<'_> {
1523
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
1524
0
        match self {
1525
0
            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1526
0
            Self::Unknown(r) => r.typ,
1527
        }
1528
0
    }
1529
1530
0
    pub(crate) fn cert_status(&self) -> Option<&[u8]> {
1531
0
        match self {
1532
0
            Self::CertificateStatus(cs) => Some(cs.ocsp_response.0.bytes()),
1533
0
            _ => None,
1534
        }
1535
0
    }
1536
1537
0
    pub(crate) fn into_owned(self) -> CertificateExtension<'static> {
1538
0
        match self {
1539
0
            Self::CertificateStatus(st) => CertificateExtension::CertificateStatus(st.into_owned()),
1540
0
            Self::Unknown(unk) => CertificateExtension::Unknown(unk),
1541
        }
1542
0
    }
1543
}
1544
1545
impl<'a> Codec<'a> for CertificateExtension<'a> {
1546
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1547
0
        self.ext_type().encode(bytes);
1548
0
1549
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
1550
0
        match self {
1551
0
            Self::CertificateStatus(r) => r.encode(nested.buf),
1552
0
            Self::Unknown(r) => r.encode(nested.buf),
1553
        }
1554
0
    }
1555
1556
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1557
0
        let typ = ExtensionType::read(r)?;
1558
0
        let len = u16::read(r)? as usize;
1559
0
        let mut sub = r.sub(len)?;
1560
1561
0
        let ext = match typ {
1562
            ExtensionType::StatusRequest => {
1563
0
                let st = CertificateStatus::read(&mut sub)?;
1564
0
                Self::CertificateStatus(st)
1565
            }
1566
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1567
        };
1568
1569
0
        sub.expect_empty("CertificateExtension")
1570
0
            .map(|_| ext)
1571
0
    }
1572
}
1573
1574
impl TlsListElement for CertificateExtension<'_> {
1575
    const SIZE_LEN: ListLength = ListLength::U16;
1576
}
1577
1578
#[derive(Debug)]
1579
pub(crate) struct CertificateEntry<'a> {
1580
    pub(crate) cert: CertificateDer<'a>,
1581
    pub(crate) exts: Vec<CertificateExtension<'a>>,
1582
}
1583
1584
impl<'a> Codec<'a> for CertificateEntry<'a> {
1585
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1586
0
        self.cert.encode(bytes);
1587
0
        self.exts.encode(bytes);
1588
0
    }
1589
1590
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1591
0
        Ok(Self {
1592
0
            cert: CertificateDer::read(r)?,
1593
0
            exts: Vec::read(r)?,
1594
        })
1595
0
    }
1596
}
1597
1598
impl<'a> CertificateEntry<'a> {
1599
0
    pub(crate) fn new(cert: CertificateDer<'a>) -> Self {
1600
0
        Self {
1601
0
            cert,
1602
0
            exts: Vec::new(),
1603
0
        }
1604
0
    }
1605
1606
0
    pub(crate) fn into_owned(self) -> CertificateEntry<'static> {
1607
0
        CertificateEntry {
1608
0
            cert: self.cert.into_owned(),
1609
0
            exts: self
1610
0
                .exts
1611
0
                .into_iter()
1612
0
                .map(CertificateExtension::into_owned)
1613
0
                .collect(),
1614
0
        }
1615
0
    }
1616
1617
0
    pub(crate) fn has_duplicate_extension(&self) -> bool {
1618
0
        has_duplicates::<_, _, u16>(
1619
0
            self.exts
1620
0
                .iter()
1621
0
                .map(|ext| ext.ext_type()),
1622
0
        )
1623
0
    }
1624
1625
0
    pub(crate) fn has_unknown_extension(&self) -> bool {
1626
0
        self.exts
1627
0
            .iter()
1628
0
            .any(|ext| ext.ext_type() != ExtensionType::StatusRequest)
1629
0
    }
1630
1631
0
    pub(crate) fn ocsp_response(&self) -> Option<&[u8]> {
1632
0
        self.exts
1633
0
            .iter()
1634
0
            .find(|ext| ext.ext_type() == ExtensionType::StatusRequest)
1635
0
            .and_then(CertificateExtension::cert_status)
1636
0
    }
1637
}
1638
1639
impl TlsListElement for CertificateEntry<'_> {
1640
    const SIZE_LEN: ListLength = ListLength::U24 {
1641
        max: CERTIFICATE_MAX_SIZE_LIMIT,
1642
        error: InvalidMessage::CertificatePayloadTooLarge,
1643
    };
1644
}
1645
1646
#[derive(Debug)]
1647
pub struct CertificatePayloadTls13<'a> {
1648
    pub(crate) context: PayloadU8,
1649
    pub(crate) entries: Vec<CertificateEntry<'a>>,
1650
}
1651
1652
impl<'a> Codec<'a> for CertificatePayloadTls13<'a> {
1653
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1654
0
        self.context.encode(bytes);
1655
0
        self.entries.encode(bytes);
1656
0
    }
1657
1658
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
1659
0
        Ok(Self {
1660
0
            context: PayloadU8::read(r)?,
1661
0
            entries: Vec::read(r)?,
1662
        })
1663
0
    }
1664
}
1665
1666
impl<'a> CertificatePayloadTls13<'a> {
1667
0
    pub(crate) fn new(
1668
0
        certs: impl Iterator<Item = &'a CertificateDer<'a>>,
1669
0
        ocsp_response: Option<&'a [u8]>,
1670
0
    ) -> Self {
1671
0
        Self {
1672
0
            context: PayloadU8::empty(),
1673
0
            entries: certs
1674
0
                // zip certificate iterator with `ocsp_response` followed by
1675
0
                // an infinite-length iterator of `None`.
1676
0
                .zip(
1677
0
                    ocsp_response
1678
0
                        .into_iter()
1679
0
                        .map(Some)
1680
0
                        .chain(iter::repeat(None)),
1681
0
                )
1682
0
                .map(|(cert, ocsp)| {
1683
0
                    let mut e = CertificateEntry::new(cert.clone());
1684
0
                    if let Some(ocsp) = ocsp {
1685
0
                        e.exts
1686
0
                            .push(CertificateExtension::CertificateStatus(
1687
0
                                CertificateStatus::new(ocsp),
1688
0
                            ));
1689
0
                    }
1690
0
                    e
1691
0
                })
1692
0
                .collect(),
1693
0
        }
1694
0
    }
1695
1696
0
    pub(crate) fn into_owned(self) -> CertificatePayloadTls13<'static> {
1697
0
        CertificatePayloadTls13 {
1698
0
            context: self.context,
1699
0
            entries: self
1700
0
                .entries
1701
0
                .into_iter()
1702
0
                .map(CertificateEntry::into_owned)
1703
0
                .collect(),
1704
0
        }
1705
0
    }
1706
1707
0
    pub(crate) fn any_entry_has_duplicate_extension(&self) -> bool {
1708
0
        for entry in &self.entries {
1709
0
            if entry.has_duplicate_extension() {
1710
0
                return true;
1711
0
            }
1712
        }
1713
1714
0
        false
1715
0
    }
1716
1717
0
    pub(crate) fn any_entry_has_unknown_extension(&self) -> bool {
1718
0
        for entry in &self.entries {
1719
0
            if entry.has_unknown_extension() {
1720
0
                return true;
1721
0
            }
1722
        }
1723
1724
0
        false
1725
0
    }
1726
1727
0
    pub(crate) fn any_entry_has_extension(&self) -> bool {
1728
0
        for entry in &self.entries {
1729
0
            if !entry.exts.is_empty() {
1730
0
                return true;
1731
0
            }
1732
        }
1733
1734
0
        false
1735
0
    }
1736
1737
0
    pub(crate) fn end_entity_ocsp(&self) -> Vec<u8> {
1738
0
        self.entries
1739
0
            .first()
1740
0
            .and_then(CertificateEntry::ocsp_response)
1741
0
            .map(|resp| resp.to_vec())
1742
0
            .unwrap_or_default()
1743
0
    }
1744
1745
0
    pub(crate) fn into_certificate_chain(self) -> CertificateChain<'a> {
1746
0
        CertificateChain(
1747
0
            self.entries
1748
0
                .into_iter()
1749
0
                .map(|e| e.cert)
1750
0
                .collect(),
1751
0
        )
1752
0
    }
1753
}
1754
1755
/// Describes supported key exchange mechanisms.
1756
#[derive(Clone, Copy, Debug, PartialEq)]
1757
#[non_exhaustive]
1758
pub enum KeyExchangeAlgorithm {
1759
    /// Diffie-Hellman key exchange (with only known parameters as defined in [RFC 7919]).
1760
    ///
1761
    /// [RFC 7919]: https://datatracker.ietf.org/doc/html/rfc7919
1762
    DHE,
1763
    /// Key exchange performed via elliptic curve Diffie-Hellman.
1764
    ECDHE,
1765
}
1766
1767
/// Every `KeyExchangeAlgorithm` variant, listed with `ECDHE` ahead of `DHE`.
pub(crate) static ALL_KEY_EXCHANGE_ALGORITHMS: &[KeyExchangeAlgorithm] =
1768
    &[KeyExchangeAlgorithm::ECDHE, KeyExchangeAlgorithm::DHE];
1769
1770
// We don't support arbitrary curves.  It's a terrible
1771
// idea and unnecessary attack surface.  Please,
1772
// get a grip.
1773
#[derive(Debug)]
1774
pub(crate) struct EcParameters {
1775
    pub(crate) curve_type: ECCurveType,
1776
    pub(crate) named_group: NamedGroup,
1777
}
1778
1779
impl Codec<'_> for EcParameters {
1780
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1781
0
        self.curve_type.encode(bytes);
1782
0
        self.named_group.encode(bytes);
1783
0
    }
1784
1785
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1786
0
        let ct = ECCurveType::read(r)?;
1787
0
        if ct != ECCurveType::NamedCurve {
1788
0
            return Err(InvalidMessage::UnsupportedCurveType);
1789
0
        }
1790
1791
0
        let grp = NamedGroup::read(r)?;
1792
1793
0
        Ok(Self {
1794
0
            curve_type: ct,
1795
0
            named_group: grp,
1796
0
        })
1797
0
    }
1798
}
1799
1800
/// Decoding for key exchange messages whose layout depends on the
/// `KeyExchangeAlgorithm` in use, which the caller therefore has to supply.
#[cfg(feature = "tls12")]
1801
pub(crate) trait KxDecode<'a>: fmt::Debug + Sized {
1802
    /// Decode a key exchange message from `r`, using `algo` to select the format.
1803
    fn decode(r: &mut Reader<'a>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage>;
1804
}
1805
1806
/// Client key exchange parameters, in either their ECDH or DH form.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) enum ClientKeyExchangeParams {
    Ecdh(ClientEcdhParams),
    Dh(ClientDhParams),
}

#[cfg(feature = "tls12")]
impl ClientKeyExchangeParams {
    /// Returns the public key bytes held by whichever variant this is.
    pub(crate) fn pub_key(&self) -> &[u8] {
        match self {
            Self::Dh(params) => &params.public.0,
            Self::Ecdh(params) => &params.public.0,
        }
    }

    /// Appends the encoding of the contained parameters to `buf`.
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
        match self {
            Self::Dh(params) => params.encode(buf),
            Self::Ecdh(params) => params.encode(buf),
        }
    }
}

#[cfg(feature = "tls12")]
impl KxDecode<'_> for ClientKeyExchangeParams {
    /// Reads `ClientEcdhParams` for `ECDHE` and `ClientDhParams` for `DHE`.
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
        match algo {
            KeyExchangeAlgorithm::DHE => ClientDhParams::read(r).map(Self::Dh),
            KeyExchangeAlgorithm::ECDHE => ClientEcdhParams::read(r).map(Self::Ecdh),
        }
    }
}
1840
1841
/// Client ECDH parameters: a public key stored as a `PayloadU8`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientEcdhParams {
    pub(crate) public: PayloadU8,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientEcdhParams {
    /// Encodes the public key payload.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public key payload.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        PayloadU8::read(r).map(|public| Self { public })
    }
}
1858
1859
/// Client DH parameters: a public key stored as a `PayloadU16`.
#[cfg(feature = "tls12")]
#[derive(Debug)]
pub(crate) struct ClientDhParams {
    pub(crate) public: PayloadU16,
}

#[cfg(feature = "tls12")]
impl Codec<'_> for ClientDhParams {
    /// Encodes the public key payload.
    fn encode(&self, bytes: &mut Vec<u8>) {
        let Self { public } = self;
        public.encode(bytes);
    }

    /// Reads the public key payload.
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
        let public = PayloadU16::read(r)?;
        Ok(Self { public })
    }
}
1877
1878
#[derive(Debug)]
1879
pub(crate) struct ServerEcdhParams {
1880
    pub(crate) curve_params: EcParameters,
1881
    pub(crate) public: PayloadU8,
1882
}
1883
1884
impl ServerEcdhParams {
1885
    #[cfg(feature = "tls12")]
1886
0
    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1887
0
        Self {
1888
0
            curve_params: EcParameters {
1889
0
                curve_type: ECCurveType::NamedCurve,
1890
0
                named_group: kx.group(),
1891
0
            },
1892
0
            public: PayloadU8::new(kx.pub_key().to_vec()),
1893
0
        }
1894
0
    }
1895
}
1896
1897
impl Codec<'_> for ServerEcdhParams {
1898
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1899
0
        self.curve_params.encode(bytes);
1900
0
        self.public.encode(bytes);
1901
0
    }
1902
1903
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1904
0
        let cp = EcParameters::read(r)?;
1905
0
        let pb = PayloadU8::read(r)?;
1906
1907
0
        Ok(Self {
1908
0
            curve_params: cp,
1909
0
            public: pb,
1910
0
        })
1911
0
    }
1912
}
1913
1914
#[derive(Debug)]
1915
#[allow(non_snake_case)]
1916
pub(crate) struct ServerDhParams {
1917
    pub(crate) dh_p: PayloadU16,
1918
    pub(crate) dh_g: PayloadU16,
1919
    pub(crate) dh_Ys: PayloadU16,
1920
}
1921
1922
impl ServerDhParams {
1923
    #[cfg(feature = "tls12")]
1924
0
    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1925
0
        let Some(params) = kx.ffdhe_group() else {
1926
0
            panic!("invalid NamedGroup for DHE key exchange: {:?}", kx.group());
1927
        };
1928
1929
0
        Self {
1930
0
            dh_p: PayloadU16::new(params.p.to_vec()),
1931
0
            dh_g: PayloadU16::new(params.g.to_vec()),
1932
0
            dh_Ys: PayloadU16::new(kx.pub_key().to_vec()),
1933
0
        }
1934
0
    }
1935
1936
    #[cfg(feature = "tls12")]
1937
0
    pub(crate) fn as_ffdhe_group(&self) -> FfdheGroup<'_> {
1938
0
        FfdheGroup::from_params_trimming_leading_zeros(&self.dh_p.0, &self.dh_g.0)
1939
0
    }
1940
}
1941
1942
impl Codec<'_> for ServerDhParams {
1943
0
    fn encode(&self, bytes: &mut Vec<u8>) {
1944
0
        self.dh_p.encode(bytes);
1945
0
        self.dh_g.encode(bytes);
1946
0
        self.dh_Ys.encode(bytes);
1947
0
    }
1948
1949
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
1950
0
        Ok(Self {
1951
0
            dh_p: PayloadU16::read(r)?,
1952
0
            dh_g: PayloadU16::read(r)?,
1953
0
            dh_Ys: PayloadU16::read(r)?,
1954
        })
1955
0
    }
1956
}
1957
1958
#[allow(dead_code)]
1959
#[derive(Debug)]
1960
pub(crate) enum ServerKeyExchangeParams {
1961
    Ecdh(ServerEcdhParams),
1962
    Dh(ServerDhParams),
1963
}
1964
1965
impl ServerKeyExchangeParams {
1966
    #[cfg(feature = "tls12")]
1967
0
    pub(crate) fn new(kx: &dyn ActiveKeyExchange) -> Self {
1968
0
        match kx.group().key_exchange_algorithm() {
1969
0
            KeyExchangeAlgorithm::DHE => Self::Dh(ServerDhParams::new(kx)),
1970
0
            KeyExchangeAlgorithm::ECDHE => Self::Ecdh(ServerEcdhParams::new(kx)),
1971
        }
1972
0
    }
1973
1974
    #[cfg(feature = "tls12")]
1975
0
    pub(crate) fn pub_key(&self) -> &[u8] {
1976
0
        match self {
1977
0
            Self::Ecdh(ecdh) => &ecdh.public.0,
1978
0
            Self::Dh(dh) => &dh.dh_Ys.0,
1979
        }
1980
0
    }
1981
1982
0
    pub(crate) fn encode(&self, buf: &mut Vec<u8>) {
1983
0
        match self {
1984
0
            Self::Ecdh(ecdh) => ecdh.encode(buf),
1985
0
            Self::Dh(dh) => dh.encode(buf),
1986
        }
1987
0
    }
1988
}
1989
1990
#[cfg(feature = "tls12")]
1991
impl KxDecode<'_> for ServerKeyExchangeParams {
1992
0
    fn decode(r: &mut Reader<'_>, algo: KeyExchangeAlgorithm) -> Result<Self, InvalidMessage> {
1993
        use KeyExchangeAlgorithm::*;
1994
0
        Ok(match algo {
1995
0
            ECDHE => Self::Ecdh(ServerEcdhParams::read(r)?),
1996
0
            DHE => Self::Dh(ServerDhParams::read(r)?),
1997
        })
1998
0
    }
1999
}
2000
2001
#[derive(Debug)]
2002
pub struct ServerKeyExchange {
2003
    pub(crate) params: ServerKeyExchangeParams,
2004
    pub(crate) dss: DigitallySignedStruct,
2005
}
2006
2007
impl ServerKeyExchange {
2008
0
    pub fn encode(&self, buf: &mut Vec<u8>) {
2009
0
        self.params.encode(buf);
2010
0
        self.dss.encode(buf);
2011
0
    }
2012
}
2013
2014
#[derive(Debug)]
2015
pub enum ServerKeyExchangePayload {
2016
    Known(ServerKeyExchange),
2017
    Unknown(Payload<'static>),
2018
}
2019
2020
impl From<ServerKeyExchange> for ServerKeyExchangePayload {
2021
0
    fn from(value: ServerKeyExchange) -> Self {
2022
0
        Self::Known(value)
2023
0
    }
2024
}
2025
2026
impl Codec<'_> for ServerKeyExchangePayload {
2027
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2028
0
        match self {
2029
0
            Self::Known(x) => x.encode(bytes),
2030
0
            Self::Unknown(x) => x.encode(bytes),
2031
        }
2032
0
    }
2033
2034
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2035
0
        // read as Unknown, fully parse when we know the
2036
0
        // KeyExchangeAlgorithm
2037
0
        Ok(Self::Unknown(Payload::read(r).into_owned()))
2038
0
    }
2039
}
2040
2041
impl ServerKeyExchangePayload {
2042
    #[cfg(feature = "tls12")]
2043
0
    pub(crate) fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ServerKeyExchange> {
2044
0
        if let Self::Unknown(unk) = self {
2045
0
            let mut rd = Reader::init(unk.bytes());
2046
2047
0
            let result = ServerKeyExchange {
2048
0
                params: ServerKeyExchangeParams::decode(&mut rd, kxa).ok()?,
2049
0
                dss: DigitallySignedStruct::read(&mut rd).ok()?,
2050
            };
2051
2052
0
            if !rd.any_left() {
2053
0
                return Some(result);
2054
0
            };
2055
0
        }
2056
2057
0
        None
2058
0
    }
2059
}
2060
2061
// -- EncryptedExtensions (TLS1.3 only) --
2062
2063
// Lists of `ServerExtension` use `ListLength::U16` as their `SIZE_LEN`.
impl TlsListElement for ServerExtension {
2064
    const SIZE_LEN: ListLength = ListLength::U16;
2065
}
2066
2067
pub(crate) trait HasServerExtensions {
2068
    fn extensions(&self) -> &[ServerExtension];
2069
2070
    /// Returns true if there is more than one extension of a given
2071
    /// type.
2072
0
    fn has_duplicate_extension(&self) -> bool {
2073
0
        has_duplicates::<_, _, u16>(
2074
0
            self.extensions()
2075
0
                .iter()
2076
0
                .map(|ext| ext.ext_type()),
Unexecuted instantiation: <alloc::vec::Vec<rustls::msgs::handshake::ServerExtension> as rustls::msgs::handshake::HasServerExtensions>::has_duplicate_extension::{closure#0}
Unexecuted instantiation: <rustls::msgs::handshake::ServerHelloPayload as rustls::msgs::handshake::HasServerExtensions>::has_duplicate_extension::{closure#0}
2077
0
        )
2078
0
    }
Unexecuted instantiation: <alloc::vec::Vec<rustls::msgs::handshake::ServerExtension> as rustls::msgs::handshake::HasServerExtensions>::has_duplicate_extension
Unexecuted instantiation: <rustls::msgs::handshake::ServerHelloPayload as rustls::msgs::handshake::HasServerExtensions>::has_duplicate_extension
2079
2080
0
    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
2081
0
        self.extensions()
2082
0
            .iter()
2083
0
            .find(|x| x.ext_type() == ext)
Unexecuted instantiation: <alloc::vec::Vec<rustls::msgs::handshake::ServerExtension> as rustls::msgs::handshake::HasServerExtensions>::find_extension::{closure#0}
Unexecuted instantiation: <rustls::msgs::handshake::ServerHelloPayload as rustls::msgs::handshake::HasServerExtensions>::find_extension::{closure#0}
2084
0
    }
Unexecuted instantiation: <alloc::vec::Vec<rustls::msgs::handshake::ServerExtension> as rustls::msgs::handshake::HasServerExtensions>::find_extension
Unexecuted instantiation: <rustls::msgs::handshake::ServerHelloPayload as rustls::msgs::handshake::HasServerExtensions>::find_extension
2085
2086
0
    fn alpn_protocol(&self) -> Option<&[u8]> {
2087
0
        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
2088
0
        match ext {
2089
0
            ServerExtension::Protocols(protos) => protos.as_single_slice(),
2090
0
            _ => None,
2091
        }
2092
0
    }
Unexecuted instantiation: <alloc::vec::Vec<rustls::msgs::handshake::ServerExtension> as rustls::msgs::handshake::HasServerExtensions>::alpn_protocol
Unexecuted instantiation: <rustls::msgs::handshake::ServerHelloPayload as rustls::msgs::handshake::HasServerExtensions>::alpn_protocol
2093
2094
0
    fn server_cert_type(&self) -> Option<&CertificateType> {
2095
0
        let ext = self.find_extension(ExtensionType::ServerCertificateType)?;
2096
0
        match ext {
2097
0
            ServerExtension::ServerCertType(req) => Some(req),
2098
0
            _ => None,
2099
        }
2100
0
    }
2101
2102
0
    fn client_cert_type(&self) -> Option<&CertificateType> {
2103
0
        let ext = self.find_extension(ExtensionType::ClientCertificateType)?;
2104
0
        match ext {
2105
0
            ServerExtension::ClientCertType(req) => Some(req),
2106
0
            _ => None,
2107
        }
2108
0
    }
2109
2110
0
    fn quic_params_extension(&self) -> Option<Vec<u8>> {
2111
0
        let ext = self
2112
0
            .find_extension(ExtensionType::TransportParameters)
2113
0
            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
2114
0
        match ext {
2115
0
            ServerExtension::TransportParameters(bytes)
2116
0
            | ServerExtension::TransportParametersDraft(bytes) => Some(bytes.to_vec()),
2117
0
            _ => None,
2118
        }
2119
0
    }
2120
2121
0
    fn server_ech_extension(&self) -> Option<ServerEncryptedClientHello> {
2122
0
        let ext = self.find_extension(ExtensionType::EncryptedClientHello)?;
2123
0
        match ext {
2124
0
            ServerExtension::EncryptedClientHello(ech) => Some(ech.clone()),
2125
0
            _ => None,
2126
        }
2127
0
    }
2128
2129
0
    fn early_data_extension_offered(&self) -> bool {
2130
0
        self.find_extension(ExtensionType::EarlyData)
2131
0
            .is_some()
2132
0
    }
2133
}
2134
2135
impl HasServerExtensions for Vec<ServerExtension> {
    /// A plain vector of extensions is searched as-is.
    fn extensions(&self) -> &[ServerExtension] {
        self.as_slice()
    }
}
2140
2141
// Lists of `ClientCertificateType` use `ListLength::U8` as their `SIZE_LEN`.
impl TlsListElement for ClientCertificateType {
2142
    const SIZE_LEN: ListLength = ListLength::U8;
2143
}
2144
2145
wrapped_payload!(
2146
    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
2147
    ///
2148
    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
2149
    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
2150
    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2151
    ///
2152
    /// ```ignore
2153
    /// for name in distinguished_names {
2154
    ///     use x509_parser::prelude::FromDer;
2155
    ///     println!("{}", x509_parser::x509::X509Name::from_der(&name.0)?.1);
2156
    /// }
2157
    /// ```
2158
    pub struct DistinguishedName,
2159
    PayloadU16,
2160
);
2161
2162
impl DistinguishedName {
2163
    /// Create a [`DistinguishedName`] after prepending its outer SEQUENCE encoding.
2164
    ///
2165
    /// This can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
2166
    ///
2167
    /// ```ignore
2168
    /// use x509_parser::prelude::FromDer;
2169
    /// println!("{}", x509_parser::x509::X509Name::from_der(dn.as_ref())?.1);
2170
    /// ```
2171
0
    pub fn in_sequence(bytes: &[u8]) -> Self {
2172
0
        Self(PayloadU16::new(wrap_in_sequence(bytes)))
2173
0
    }
2174
}
2175
2176
impl TlsListElement for DistinguishedName {
2177
    const SIZE_LEN: ListLength = ListLength::U16;
2178
}
2179
2180
#[derive(Debug)]
2181
pub struct CertificateRequestPayload {
2182
    pub(crate) certtypes: Vec<ClientCertificateType>,
2183
    pub(crate) sigschemes: Vec<SignatureScheme>,
2184
    pub(crate) canames: Vec<DistinguishedName>,
2185
}
2186
2187
impl Codec<'_> for CertificateRequestPayload {
2188
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2189
0
        self.certtypes.encode(bytes);
2190
0
        self.sigschemes.encode(bytes);
2191
0
        self.canames.encode(bytes);
2192
0
    }
2193
2194
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2195
0
        let certtypes = Vec::read(r)?;
2196
0
        let sigschemes = Vec::read(r)?;
2197
0
        let canames = Vec::read(r)?;
2198
2199
0
        if sigschemes.is_empty() {
2200
0
            warn!("meaningless CertificateRequest message");
2201
0
            Err(InvalidMessage::NoSignatureSchemes)
2202
        } else {
2203
0
            Ok(Self {
2204
0
                certtypes,
2205
0
                sigschemes,
2206
0
                canames,
2207
0
            })
2208
        }
2209
0
    }
2210
}
2211
2212
#[derive(Debug)]
2213
pub(crate) enum CertReqExtension {
2214
    SignatureAlgorithms(Vec<SignatureScheme>),
2215
    AuthorityNames(Vec<DistinguishedName>),
2216
    CertificateCompressionAlgorithms(Vec<CertificateCompressionAlgorithm>),
2217
    Unknown(UnknownExtension),
2218
}
2219
2220
impl CertReqExtension {
2221
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
2222
0
        match self {
2223
0
            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
2224
0
            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
2225
0
            Self::CertificateCompressionAlgorithms(_) => ExtensionType::CompressCertificate,
2226
0
            Self::Unknown(r) => r.typ,
2227
        }
2228
0
    }
2229
}
2230
2231
impl Codec<'_> for CertReqExtension {
2232
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2233
0
        self.ext_type().encode(bytes);
2234
0
2235
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2236
0
        match self {
2237
0
            Self::SignatureAlgorithms(r) => r.encode(nested.buf),
2238
0
            Self::AuthorityNames(r) => r.encode(nested.buf),
2239
0
            Self::CertificateCompressionAlgorithms(r) => r.encode(nested.buf),
2240
0
            Self::Unknown(r) => r.encode(nested.buf),
2241
        }
2242
0
    }
2243
2244
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2245
0
        let typ = ExtensionType::read(r)?;
2246
0
        let len = u16::read(r)? as usize;
2247
0
        let mut sub = r.sub(len)?;
2248
2249
0
        let ext = match typ {
2250
            ExtensionType::SignatureAlgorithms => {
2251
0
                let schemes = Vec::read(&mut sub)?;
2252
0
                if schemes.is_empty() {
2253
0
                    return Err(InvalidMessage::NoSignatureSchemes);
2254
0
                }
2255
0
                Self::SignatureAlgorithms(schemes)
2256
            }
2257
            ExtensionType::CertificateAuthorities => {
2258
0
                let cas = Vec::read(&mut sub)?;
2259
0
                Self::AuthorityNames(cas)
2260
            }
2261
            ExtensionType::CompressCertificate => {
2262
0
                Self::CertificateCompressionAlgorithms(Vec::read(&mut sub)?)
2263
            }
2264
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2265
        };
2266
2267
0
        sub.expect_empty("CertReqExtension")
2268
0
            .map(|_| ext)
2269
0
    }
2270
}
2271
2272
impl TlsListElement for CertReqExtension {
2273
    const SIZE_LEN: ListLength = ListLength::U16;
2274
}
2275
2276
#[derive(Debug)]
2277
pub struct CertificateRequestPayloadTls13 {
2278
    pub(crate) context: PayloadU8,
2279
    pub(crate) extensions: Vec<CertReqExtension>,
2280
}
2281
2282
impl Codec<'_> for CertificateRequestPayloadTls13 {
2283
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2284
0
        self.context.encode(bytes);
2285
0
        self.extensions.encode(bytes);
2286
0
    }
2287
2288
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2289
0
        let context = PayloadU8::read(r)?;
2290
0
        let extensions = Vec::read(r)?;
2291
2292
0
        Ok(Self {
2293
0
            context,
2294
0
            extensions,
2295
0
        })
2296
0
    }
2297
}
2298
2299
impl CertificateRequestPayloadTls13 {
2300
0
    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
2301
0
        self.extensions
2302
0
            .iter()
2303
0
            .find(|x| x.ext_type() == ext)
2304
0
    }
2305
2306
0
    pub(crate) fn sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
2307
0
        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
2308
0
        match ext {
2309
0
            CertReqExtension::SignatureAlgorithms(sa) => Some(sa),
2310
0
            _ => None,
2311
        }
2312
0
    }
2313
2314
0
    pub(crate) fn authorities_extension(&self) -> Option<&[DistinguishedName]> {
2315
0
        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
2316
0
        match ext {
2317
0
            CertReqExtension::AuthorityNames(an) => Some(an),
2318
0
            _ => None,
2319
        }
2320
0
    }
2321
2322
0
    pub(crate) fn certificate_compression_extension(
2323
0
        &self,
2324
0
    ) -> Option<&[CertificateCompressionAlgorithm]> {
2325
0
        let ext = self.find_extension(ExtensionType::CompressCertificate)?;
2326
0
        match ext {
2327
0
            CertReqExtension::CertificateCompressionAlgorithms(comps) => Some(comps),
2328
0
            _ => None,
2329
        }
2330
0
    }
2331
}
2332
2333
// -- NewSessionTicket --
2334
#[derive(Debug)]
2335
pub struct NewSessionTicketPayload {
2336
    pub(crate) lifetime_hint: u32,
2337
    // Tickets can be large (KB), so we deserialise this straight
2338
    // into an Arc, so it can be passed directly into the client's
2339
    // session object without copying.
2340
    pub(crate) ticket: Arc<PayloadU16>,
2341
}
2342
2343
impl NewSessionTicketPayload {
2344
    #[cfg(feature = "tls12")]
2345
0
    pub(crate) fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
2346
0
        Self {
2347
0
            lifetime_hint,
2348
0
            ticket: Arc::new(PayloadU16::new(ticket)),
2349
0
        }
2350
0
    }
2351
}
2352
2353
impl Codec<'_> for NewSessionTicketPayload {
2354
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2355
0
        self.lifetime_hint.encode(bytes);
2356
0
        self.ticket.encode(bytes);
2357
0
    }
2358
2359
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2360
0
        let lifetime = u32::read(r)?;
2361
0
        let ticket = Arc::new(PayloadU16::read(r)?);
2362
2363
0
        Ok(Self {
2364
0
            lifetime_hint: lifetime,
2365
0
            ticket,
2366
0
        })
2367
0
    }
2368
}
2369
2370
// -- NewSessionTicket extensions and the TLS1.3 NewSessionTicket payload --
2371
#[derive(Debug)]
2372
pub(crate) enum NewSessionTicketExtension {
2373
    EarlyData(u32),
2374
    Unknown(UnknownExtension),
2375
}
2376
2377
impl NewSessionTicketExtension {
2378
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
2379
0
        match self {
2380
0
            Self::EarlyData(_) => ExtensionType::EarlyData,
2381
0
            Self::Unknown(r) => r.typ,
2382
        }
2383
0
    }
2384
}
2385
2386
impl Codec<'_> for NewSessionTicketExtension {
2387
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2388
0
        self.ext_type().encode(bytes);
2389
0
2390
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2391
0
        match self {
2392
0
            Self::EarlyData(r) => r.encode(nested.buf),
2393
0
            Self::Unknown(r) => r.encode(nested.buf),
2394
        }
2395
0
    }
2396
2397
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2398
0
        let typ = ExtensionType::read(r)?;
2399
0
        let len = u16::read(r)? as usize;
2400
0
        let mut sub = r.sub(len)?;
2401
2402
0
        let ext = match typ {
2403
0
            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
2404
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
2405
        };
2406
2407
0
        sub.expect_empty("NewSessionTicketExtension")
2408
0
            .map(|_| ext)
2409
0
    }
2410
}
2411
2412
impl TlsListElement for NewSessionTicketExtension {
2413
    const SIZE_LEN: ListLength = ListLength::U16;
2414
}
2415
2416
#[derive(Debug)]
2417
pub struct NewSessionTicketPayloadTls13 {
2418
    pub(crate) lifetime: u32,
2419
    pub(crate) age_add: u32,
2420
    pub(crate) nonce: PayloadU8,
2421
    pub(crate) ticket: Arc<PayloadU16>,
2422
    pub(crate) exts: Vec<NewSessionTicketExtension>,
2423
}
2424
2425
impl NewSessionTicketPayloadTls13 {
2426
0
    pub(crate) fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2427
0
        Self {
2428
0
            lifetime,
2429
0
            age_add,
2430
0
            nonce: PayloadU8::new(nonce),
2431
0
            ticket: Arc::new(PayloadU16::new(ticket)),
2432
0
            exts: vec![],
2433
0
        }
2434
0
    }
2435
2436
0
    pub(crate) fn has_duplicate_extension(&self) -> bool {
2437
0
        has_duplicates::<_, _, u16>(
2438
0
            self.exts
2439
0
                .iter()
2440
0
                .map(|ext| ext.ext_type()),
2441
0
        )
2442
0
    }
2443
2444
0
    pub(crate) fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2445
0
        self.exts
2446
0
            .iter()
2447
0
            .find(|x| x.ext_type() == ext)
2448
0
    }
2449
2450
0
    pub(crate) fn max_early_data_size(&self) -> Option<u32> {
2451
0
        let ext = self.find_extension(ExtensionType::EarlyData)?;
2452
0
        match ext {
2453
0
            NewSessionTicketExtension::EarlyData(sz) => Some(*sz),
2454
0
            _ => None,
2455
        }
2456
0
    }
2457
}
2458
2459
impl Codec<'_> for NewSessionTicketPayloadTls13 {
2460
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2461
0
        self.lifetime.encode(bytes);
2462
0
        self.age_add.encode(bytes);
2463
0
        self.nonce.encode(bytes);
2464
0
        self.ticket.encode(bytes);
2465
0
        self.exts.encode(bytes);
2466
0
    }
2467
2468
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2469
0
        let lifetime = u32::read(r)?;
2470
0
        let age_add = u32::read(r)?;
2471
0
        let nonce = PayloadU8::read(r)?;
2472
0
        let ticket = Arc::new(PayloadU16::read(r)?);
2473
0
        let exts = Vec::read(r)?;
2474
2475
0
        Ok(Self {
2476
0
            lifetime,
2477
0
            age_add,
2478
0
            nonce,
2479
0
            ticket,
2480
0
            exts,
2481
0
        })
2482
0
    }
2483
}
2484
2485
// -- RFC6066 certificate status types
2486
2487
/// Only supports OCSP
2488
#[derive(Debug)]
2489
pub struct CertificateStatus<'a> {
2490
    pub(crate) ocsp_response: PayloadU24<'a>,
2491
}
2492
2493
impl<'a> Codec<'a> for CertificateStatus<'a> {
2494
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2495
0
        CertificateStatusType::OCSP.encode(bytes);
2496
0
        self.ocsp_response.encode(bytes);
2497
0
    }
2498
2499
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2500
0
        let typ = CertificateStatusType::read(r)?;
2501
2502
0
        match typ {
2503
            CertificateStatusType::OCSP => Ok(Self {
2504
0
                ocsp_response: PayloadU24::read(r)?,
2505
            }),
2506
0
            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2507
        }
2508
0
    }
2509
}
2510
2511
impl<'a> CertificateStatus<'a> {
2512
0
    pub(crate) fn new(ocsp: &'a [u8]) -> Self {
2513
0
        CertificateStatus {
2514
0
            ocsp_response: PayloadU24(Payload::Borrowed(ocsp)),
2515
0
        }
2516
0
    }
2517
2518
    #[cfg(feature = "tls12")]
2519
0
    pub(crate) fn into_inner(self) -> Vec<u8> {
2520
0
        self.ocsp_response.0.into_vec()
2521
0
    }
2522
2523
0
    pub(crate) fn into_owned(self) -> CertificateStatus<'static> {
2524
0
        CertificateStatus {
2525
0
            ocsp_response: self.ocsp_response.into_owned(),
2526
0
        }
2527
0
    }
2528
}
2529
2530
// -- RFC8879 compressed certificates
2531
2532
#[derive(Debug)]
2533
pub struct CompressedCertificatePayload<'a> {
2534
    pub(crate) alg: CertificateCompressionAlgorithm,
2535
    pub(crate) uncompressed_len: u32,
2536
    pub(crate) compressed: PayloadU24<'a>,
2537
}
2538
2539
impl<'a> Codec<'a> for CompressedCertificatePayload<'a> {
2540
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2541
0
        self.alg.encode(bytes);
2542
0
        codec::u24(self.uncompressed_len).encode(bytes);
2543
0
        self.compressed.encode(bytes);
2544
0
    }
2545
2546
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2547
0
        Ok(Self {
2548
0
            alg: CertificateCompressionAlgorithm::read(r)?,
2549
0
            uncompressed_len: codec::u24::read(r)?.0,
2550
0
            compressed: PayloadU24::read(r)?,
2551
        })
2552
0
    }
2553
}
2554
2555
impl CompressedCertificatePayload<'_> {
2556
0
    fn into_owned(self) -> CompressedCertificatePayload<'static> {
2557
0
        CompressedCertificatePayload {
2558
0
            compressed: self.compressed.into_owned(),
2559
0
            ..self
2560
0
        }
2561
0
    }
2562
2563
0
    pub(crate) fn as_borrowed(&self) -> CompressedCertificatePayload<'_> {
2564
0
        CompressedCertificatePayload {
2565
0
            alg: self.alg,
2566
0
            uncompressed_len: self.uncompressed_len,
2567
0
            compressed: PayloadU24(Payload::Borrowed(self.compressed.0.bytes())),
2568
0
        }
2569
0
    }
2570
}
2571
2572
#[derive(Debug)]
2573
pub enum HandshakePayload<'a> {
2574
    HelloRequest,
2575
    ClientHello(ClientHelloPayload),
2576
    ServerHello(ServerHelloPayload),
2577
    HelloRetryRequest(HelloRetryRequest),
2578
    Certificate(CertificateChain<'a>),
2579
    CertificateTls13(CertificatePayloadTls13<'a>),
2580
    CompressedCertificate(CompressedCertificatePayload<'a>),
2581
    ServerKeyExchange(ServerKeyExchangePayload),
2582
    CertificateRequest(CertificateRequestPayload),
2583
    CertificateRequestTls13(CertificateRequestPayloadTls13),
2584
    CertificateVerify(DigitallySignedStruct),
2585
    ServerHelloDone,
2586
    EndOfEarlyData,
2587
    ClientKeyExchange(Payload<'a>),
2588
    NewSessionTicket(NewSessionTicketPayload),
2589
    NewSessionTicketTls13(NewSessionTicketPayloadTls13),
2590
    EncryptedExtensions(Vec<ServerExtension>),
2591
    KeyUpdate(KeyUpdateRequest),
2592
    Finished(Payload<'a>),
2593
    CertificateStatus(CertificateStatus<'a>),
2594
    MessageHash(Payload<'a>),
2595
    Unknown(Payload<'a>),
2596
}
2597
2598
impl HandshakePayload<'_> {
2599
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2600
        use self::HandshakePayload::*;
2601
0
        match self {
2602
0
            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2603
0
            ClientHello(x) => x.encode(bytes),
2604
0
            ServerHello(x) => x.encode(bytes),
2605
0
            HelloRetryRequest(x) => x.encode(bytes),
2606
0
            Certificate(x) => x.encode(bytes),
2607
0
            CertificateTls13(x) => x.encode(bytes),
2608
0
            CompressedCertificate(x) => x.encode(bytes),
2609
0
            ServerKeyExchange(x) => x.encode(bytes),
2610
0
            ClientKeyExchange(x) => x.encode(bytes),
2611
0
            CertificateRequest(x) => x.encode(bytes),
2612
0
            CertificateRequestTls13(x) => x.encode(bytes),
2613
0
            CertificateVerify(x) => x.encode(bytes),
2614
0
            NewSessionTicket(x) => x.encode(bytes),
2615
0
            NewSessionTicketTls13(x) => x.encode(bytes),
2616
0
            EncryptedExtensions(x) => x.encode(bytes),
2617
0
            KeyUpdate(x) => x.encode(bytes),
2618
0
            Finished(x) => x.encode(bytes),
2619
0
            CertificateStatus(x) => x.encode(bytes),
2620
0
            MessageHash(x) => x.encode(bytes),
2621
0
            Unknown(x) => x.encode(bytes),
2622
        }
2623
0
    }
2624
2625
0
    fn into_owned(self) -> HandshakePayload<'static> {
2626
        use HandshakePayload::*;
2627
2628
0
        match self {
2629
0
            HelloRequest => HelloRequest,
2630
0
            ClientHello(x) => ClientHello(x),
2631
0
            ServerHello(x) => ServerHello(x),
2632
0
            HelloRetryRequest(x) => HelloRetryRequest(x),
2633
0
            Certificate(x) => Certificate(x.into_owned()),
2634
0
            CertificateTls13(x) => CertificateTls13(x.into_owned()),
2635
0
            CompressedCertificate(x) => CompressedCertificate(x.into_owned()),
2636
0
            ServerKeyExchange(x) => ServerKeyExchange(x),
2637
0
            CertificateRequest(x) => CertificateRequest(x),
2638
0
            CertificateRequestTls13(x) => CertificateRequestTls13(x),
2639
0
            CertificateVerify(x) => CertificateVerify(x),
2640
0
            ServerHelloDone => ServerHelloDone,
2641
0
            EndOfEarlyData => EndOfEarlyData,
2642
0
            ClientKeyExchange(x) => ClientKeyExchange(x.into_owned()),
2643
0
            NewSessionTicket(x) => NewSessionTicket(x),
2644
0
            NewSessionTicketTls13(x) => NewSessionTicketTls13(x),
2645
0
            EncryptedExtensions(x) => EncryptedExtensions(x),
2646
0
            KeyUpdate(x) => KeyUpdate(x),
2647
0
            Finished(x) => Finished(x.into_owned()),
2648
0
            CertificateStatus(x) => CertificateStatus(x.into_owned()),
2649
0
            MessageHash(x) => MessageHash(x.into_owned()),
2650
0
            Unknown(x) => Unknown(x.into_owned()),
2651
        }
2652
0
    }
2653
}
2654
2655
#[derive(Debug)]
2656
pub struct HandshakeMessagePayload<'a> {
2657
    pub typ: HandshakeType,
2658
    pub payload: HandshakePayload<'a>,
2659
}
2660
2661
impl<'a> Codec<'a> for HandshakeMessagePayload<'a> {
2662
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2663
0
        self.payload_encode(bytes, Encoding::Standard);
2664
0
    }
2665
2666
0
    fn read(r: &mut Reader<'a>) -> Result<Self, InvalidMessage> {
2667
0
        Self::read_version(r, ProtocolVersion::TLSv1_2)
2668
0
    }
2669
}
2670
2671
impl<'a> HandshakeMessagePayload<'a> {
2672
0
    pub(crate) fn read_version(
2673
0
        r: &mut Reader<'a>,
2674
0
        vers: ProtocolVersion,
2675
0
    ) -> Result<Self, InvalidMessage> {
2676
0
        let mut typ = HandshakeType::read(r)?;
2677
0
        let len = codec::u24::read(r)?.0 as usize;
2678
0
        let mut sub = r.sub(len)?;
2679
2680
0
        let payload = match typ {
2681
0
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
2682
            HandshakeType::ClientHello => {
2683
0
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
2684
            }
2685
            HandshakeType::ServerHello => {
2686
0
                let version = ProtocolVersion::read(&mut sub)?;
2687
0
                let random = Random::read(&mut sub)?;
2688
2689
0
                if random == HELLO_RETRY_REQUEST_RANDOM {
2690
0
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
2691
0
                    hrr.legacy_version = version;
2692
0
                    typ = HandshakeType::HelloRetryRequest;
2693
0
                    HandshakePayload::HelloRetryRequest(hrr)
2694
                } else {
2695
0
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
2696
0
                    shp.legacy_version = version;
2697
0
                    shp.random = random;
2698
0
                    HandshakePayload::ServerHello(shp)
2699
                }
2700
            }
2701
0
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
2702
0
                let p = CertificatePayloadTls13::read(&mut sub)?;
2703
0
                HandshakePayload::CertificateTls13(p)
2704
            }
2705
            HandshakeType::Certificate => {
2706
0
                HandshakePayload::Certificate(CertificateChain::read(&mut sub)?)
2707
            }
2708
            HandshakeType::ServerKeyExchange => {
2709
0
                let p = ServerKeyExchangePayload::read(&mut sub)?;
2710
0
                HandshakePayload::ServerKeyExchange(p)
2711
            }
2712
            HandshakeType::ServerHelloDone => {
2713
0
                sub.expect_empty("ServerHelloDone")?;
2714
0
                HandshakePayload::ServerHelloDone
2715
            }
2716
            HandshakeType::ClientKeyExchange => {
2717
0
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
2718
            }
2719
0
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
2720
0
                let p = CertificateRequestPayloadTls13::read(&mut sub)?;
2721
0
                HandshakePayload::CertificateRequestTls13(p)
2722
            }
2723
            HandshakeType::CertificateRequest => {
2724
0
                let p = CertificateRequestPayload::read(&mut sub)?;
2725
0
                HandshakePayload::CertificateRequest(p)
2726
            }
2727
            HandshakeType::CompressedCertificate => HandshakePayload::CompressedCertificate(
2728
0
                CompressedCertificatePayload::read(&mut sub)?,
2729
            ),
2730
            HandshakeType::CertificateVerify => {
2731
0
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
2732
            }
2733
0
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
2734
0
                let p = NewSessionTicketPayloadTls13::read(&mut sub)?;
2735
0
                HandshakePayload::NewSessionTicketTls13(p)
2736
            }
2737
            HandshakeType::NewSessionTicket => {
2738
0
                let p = NewSessionTicketPayload::read(&mut sub)?;
2739
0
                HandshakePayload::NewSessionTicket(p)
2740
            }
2741
            HandshakeType::EncryptedExtensions => {
2742
0
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
2743
            }
2744
            HandshakeType::KeyUpdate => {
2745
0
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
2746
            }
2747
            HandshakeType::EndOfEarlyData => {
2748
0
                sub.expect_empty("EndOfEarlyData")?;
2749
0
                HandshakePayload::EndOfEarlyData
2750
            }
2751
0
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
2752
            HandshakeType::CertificateStatus => {
2753
0
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
2754
            }
2755
            HandshakeType::MessageHash => {
2756
                // does not appear on the wire
2757
0
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
2758
            }
2759
            HandshakeType::HelloRetryRequest => {
2760
                // not legal on wire
2761
0
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
2762
            }
2763
0
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
2764
        };
2765
2766
0
        sub.expect_empty("HandshakeMessagePayload")
2767
0
            .map(|_| Self { typ, payload })
2768
0
    }
2769
2770
0
    pub(crate) fn encoding_for_binder_signing(&self) -> Vec<u8> {
2771
0
        let mut ret = self.get_encoding();
2772
0
        let ret_len = ret.len() - self.total_binder_length();
2773
0
        ret.truncate(ret_len);
2774
0
        ret
2775
0
    }
2776
2777
0
    pub(crate) fn total_binder_length(&self) -> usize {
2778
0
        match &self.payload {
2779
0
            HandshakePayload::ClientHello(ch) => match ch.extensions.last() {
2780
0
                Some(ClientExtension::PresharedKey(offer)) => {
2781
0
                    let mut binders_encoding = Vec::new();
2782
0
                    offer
2783
0
                        .binders
2784
0
                        .encode(&mut binders_encoding);
2785
0
                    binders_encoding.len()
2786
                }
2787
0
                _ => 0,
2788
            },
2789
0
            _ => 0,
2790
        }
2791
0
    }
2792
2793
0
    pub(crate) fn payload_encode(&self, bytes: &mut Vec<u8>, encoding: Encoding) {
2794
0
        // output type, length, and encoded payload
2795
0
        match self.typ {
2796
0
            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2797
0
            _ => self.typ,
2798
        }
2799
0
        .encode(bytes);
2800
0
2801
0
        let nested = LengthPrefixedBuffer::new(
2802
0
            ListLength::U24 {
2803
0
                max: usize::MAX,
2804
0
                error: InvalidMessage::MessageTooLarge,
2805
0
            },
2806
0
            bytes,
2807
0
        );
2808
0
2809
0
        match &self.payload {
2810
            // for Server Hello and HelloRetryRequest payloads we need to encode the payload
2811
            // differently based on the purpose of the encoding.
2812
0
            HandshakePayload::ServerHello(payload) => payload.payload_encode(nested.buf, encoding),
2813
0
            HandshakePayload::HelloRetryRequest(payload) => {
2814
0
                payload.payload_encode(nested.buf, encoding)
2815
            }
2816
2817
            // All other payload types are encoded the same regardless of purpose.
2818
0
            _ => self.payload.encode(nested.buf),
2819
        }
2820
0
    }
2821
2822
0
    pub(crate) fn build_handshake_hash(hash: &[u8]) -> Self {
2823
0
        Self {
2824
0
            typ: HandshakeType::MessageHash,
2825
0
            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
2826
0
        }
2827
0
    }
2828
2829
0
    pub(crate) fn into_owned(self) -> HandshakeMessagePayload<'static> {
2830
0
        let Self { typ, payload } = self;
2831
0
        HandshakeMessagePayload {
2832
0
            typ,
2833
0
            payload: payload.into_owned(),
2834
0
        }
2835
0
    }
2836
}
2837
2838
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
2839
pub struct HpkeSymmetricCipherSuite {
2840
    pub kdf_id: HpkeKdf,
2841
    pub aead_id: HpkeAead,
2842
}
2843
2844
impl Codec<'_> for HpkeSymmetricCipherSuite {
2845
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2846
0
        self.kdf_id.encode(bytes);
2847
0
        self.aead_id.encode(bytes);
2848
0
    }
2849
2850
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2851
0
        Ok(Self {
2852
0
            kdf_id: HpkeKdf::read(r)?,
2853
0
            aead_id: HpkeAead::read(r)?,
2854
        })
2855
0
    }
2856
}
2857
2858
impl TlsListElement for HpkeSymmetricCipherSuite {
2859
    const SIZE_LEN: ListLength = ListLength::U16;
2860
}
2861
2862
#[derive(Clone, Debug, PartialEq)]
2863
pub struct HpkeKeyConfig {
2864
    pub config_id: u8,
2865
    pub kem_id: HpkeKem,
2866
    pub public_key: PayloadU16,
2867
    pub symmetric_cipher_suites: Vec<HpkeSymmetricCipherSuite>,
2868
}
2869
2870
impl Codec<'_> for HpkeKeyConfig {
2871
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2872
0
        self.config_id.encode(bytes);
2873
0
        self.kem_id.encode(bytes);
2874
0
        self.public_key.encode(bytes);
2875
0
        self.symmetric_cipher_suites
2876
0
            .encode(bytes);
2877
0
    }
2878
2879
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2880
0
        Ok(Self {
2881
0
            config_id: u8::read(r)?,
2882
0
            kem_id: HpkeKem::read(r)?,
2883
0
            public_key: PayloadU16::read(r)?,
2884
0
            symmetric_cipher_suites: Vec::<HpkeSymmetricCipherSuite>::read(r)?,
2885
        })
2886
0
    }
2887
}
2888
2889
#[derive(Clone, Debug, PartialEq)]
2890
pub struct EchConfigContents {
2891
    pub key_config: HpkeKeyConfig,
2892
    pub maximum_name_length: u8,
2893
    pub public_name: DnsName<'static>,
2894
    pub extensions: Vec<EchConfigExtension>,
2895
}
2896
2897
impl EchConfigContents {
2898
    /// Returns true if there is more than one extension of a given
2899
    /// type.
2900
0
    pub(crate) fn has_duplicate_extension(&self) -> bool {
2901
0
        has_duplicates::<_, _, u16>(
2902
0
            self.extensions
2903
0
                .iter()
2904
0
                .map(|ext| ext.ext_type()),
2905
0
        )
2906
0
    }
2907
2908
    /// Returns true if there is at least one mandatory unsupported extension.
2909
0
    pub(crate) fn has_unknown_mandatory_extension(&self) -> bool {
2910
0
        self.extensions
2911
0
            .iter()
2912
0
            // An extension is considered mandatory if the high bit of its type is set.
2913
0
            .any(|ext| {
2914
0
                matches!(ext.ext_type(), ExtensionType::Unknown(_))
2915
0
                    && u16::from(ext.ext_type()) & 0x8000 != 0
2916
0
            })
2917
0
    }
2918
}
2919
2920
impl Codec<'_> for EchConfigContents {
2921
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2922
0
        self.key_config.encode(bytes);
2923
0
        self.maximum_name_length.encode(bytes);
2924
0
        let dns_name = &self.public_name.borrow();
2925
0
        PayloadU8::encode_slice(dns_name.as_ref().as_ref(), bytes);
2926
0
        self.extensions.encode(bytes);
2927
0
    }
2928
2929
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2930
0
        Ok(Self {
2931
0
            key_config: HpkeKeyConfig::read(r)?,
2932
0
            maximum_name_length: u8::read(r)?,
2933
            public_name: {
2934
0
                DnsName::try_from(PayloadU8::read(r)?.0.as_slice())
2935
0
                    .map_err(|_| InvalidMessage::InvalidServerName)?
2936
0
                    .to_owned()
2937
0
            },
2938
0
            extensions: Vec::read(r)?,
2939
        })
2940
0
    }
2941
}
2942
2943
/// An encrypted client hello (ECH) config.
2944
#[derive(Clone, Debug, PartialEq)]
2945
pub enum EchConfigPayload {
2946
    /// A recognized V18 ECH configuration.
2947
    V18(EchConfigContents),
2948
    /// An unknown version ECH configuration.
2949
    Unknown {
2950
        version: EchVersion,
2951
        contents: PayloadU16,
2952
    },
2953
}
2954
2955
impl TlsListElement for EchConfigPayload {
2956
    const SIZE_LEN: ListLength = ListLength::U16;
2957
}
2958
2959
impl Codec<'_> for EchConfigPayload {
2960
0
    fn encode(&self, bytes: &mut Vec<u8>) {
2961
0
        match self {
2962
0
            Self::V18(c) => {
2963
0
                // Write the version, the length, and the contents.
2964
0
                EchVersion::V18.encode(bytes);
2965
0
                let inner = LengthPrefixedBuffer::new(ListLength::U16, bytes);
2966
0
                c.encode(inner.buf);
2967
0
            }
2968
0
            Self::Unknown { version, contents } => {
2969
0
                // Unknown configuration versions are opaque.
2970
0
                version.encode(bytes);
2971
0
                contents.encode(bytes);
2972
0
            }
2973
        }
2974
0
    }
2975
2976
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
2977
0
        let version = EchVersion::read(r)?;
2978
0
        let length = u16::read(r)?;
2979
0
        let mut contents = r.sub(length as usize)?;
2980
2981
0
        Ok(match version {
2982
0
            EchVersion::V18 => Self::V18(EchConfigContents::read(&mut contents)?),
2983
            _ => {
2984
                // Note: we don't PayloadU16::read() here because we've already read the length prefix.
2985
0
                let data = PayloadU16::new(contents.rest().into());
2986
0
                Self::Unknown {
2987
0
                    version,
2988
0
                    contents: data,
2989
0
                }
2990
            }
2991
        })
2992
0
    }
2993
}
2994
2995
#[derive(Clone, Debug, PartialEq)]
2996
pub enum EchConfigExtension {
2997
    Unknown(UnknownExtension),
2998
}
2999
3000
impl EchConfigExtension {
3001
0
    pub(crate) fn ext_type(&self) -> ExtensionType {
3002
0
        match self {
3003
0
            Self::Unknown(r) => r.typ,
3004
0
        }
3005
0
    }
3006
}
3007
3008
impl Codec<'_> for EchConfigExtension {
3009
0
    fn encode(&self, bytes: &mut Vec<u8>) {
3010
0
        self.ext_type().encode(bytes);
3011
0
3012
0
        let nested = LengthPrefixedBuffer::new(ListLength::U16, bytes);
3013
0
        match self {
3014
0
            Self::Unknown(r) => r.encode(nested.buf),
3015
0
        }
3016
0
    }
3017
3018
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3019
0
        let typ = ExtensionType::read(r)?;
3020
0
        let len = u16::read(r)? as usize;
3021
0
        let mut sub = r.sub(len)?;
3022
3023
        #[allow(clippy::match_single_binding)] // Future-proofing.
3024
0
        let ext = match typ {
3025
0
            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
3026
0
        };
3027
0
3028
0
        sub.expect_empty("EchConfigExtension")
3029
0
            .map(|_| ext)
3030
0
    }
3031
}
3032
3033
impl TlsListElement for EchConfigExtension {
3034
    const SIZE_LEN: ListLength = ListLength::U16;
3035
}
3036
3037
/// Representation of the `ECHClientHello` client extension specified in
3038
/// [draft-ietf-tls-esni Section 5].
3039
///
3040
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3041
#[derive(Clone, Debug)]
3042
pub enum EncryptedClientHello {
3043
    /// A `ECHClientHello` with type [EchClientHelloType::ClientHelloOuter].
3044
    Outer(EncryptedClientHelloOuter),
3045
    /// An empty `ECHClientHello` with type [EchClientHelloType::ClientHelloInner].
3046
    ///
3047
    /// This variant has no payload.
3048
    Inner,
3049
}
3050
3051
impl Codec<'_> for EncryptedClientHello {
3052
0
    fn encode(&self, bytes: &mut Vec<u8>) {
3053
0
        match self {
3054
0
            Self::Outer(payload) => {
3055
0
                EchClientHelloType::ClientHelloOuter.encode(bytes);
3056
0
                payload.encode(bytes);
3057
0
            }
3058
0
            Self::Inner => {
3059
0
                EchClientHelloType::ClientHelloInner.encode(bytes);
3060
0
                // Empty payload.
3061
0
            }
3062
        }
3063
0
    }
3064
3065
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3066
0
        match EchClientHelloType::read(r)? {
3067
            EchClientHelloType::ClientHelloOuter => {
3068
0
                Ok(Self::Outer(EncryptedClientHelloOuter::read(r)?))
3069
            }
3070
0
            EchClientHelloType::ClientHelloInner => Ok(Self::Inner),
3071
0
            _ => Err(InvalidMessage::InvalidContentType),
3072
        }
3073
0
    }
3074
}
3075
3076
/// Representation of the ECHClientHello extension with type outer specified in
3077
/// [draft-ietf-tls-esni Section 5].
3078
///
3079
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3080
#[derive(Clone, Debug)]
3081
pub struct EncryptedClientHelloOuter {
3082
    /// The cipher suite used to encrypt ClientHelloInner. Must match a value from
3083
    /// ECHConfigContents.cipher_suites list.
3084
    pub cipher_suite: HpkeSymmetricCipherSuite,
3085
    /// The ECHConfigContents.key_config.config_id for the chosen ECHConfig.
3086
    pub config_id: u8,
3087
    /// The HPKE encapsulated key, used by servers to decrypt the corresponding payload field.
3088
    /// This field is empty in a ClientHelloOuter sent in response to a HelloRetryRequest.
3089
    pub enc: PayloadU16,
3090
    /// The serialized and encrypted ClientHelloInner structure, encrypted using HPKE.
3091
    pub payload: PayloadU16,
3092
}
3093
3094
impl Codec<'_> for EncryptedClientHelloOuter {
3095
0
    fn encode(&self, bytes: &mut Vec<u8>) {
3096
0
        self.cipher_suite.encode(bytes);
3097
0
        self.config_id.encode(bytes);
3098
0
        self.enc.encode(bytes);
3099
0
        self.payload.encode(bytes);
3100
0
    }
3101
3102
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3103
0
        Ok(Self {
3104
0
            cipher_suite: HpkeSymmetricCipherSuite::read(r)?,
3105
0
            config_id: u8::read(r)?,
3106
0
            enc: PayloadU16::read(r)?,
3107
0
            payload: PayloadU16::read(r)?,
3108
        })
3109
0
    }
3110
}
3111
3112
/// Representation of the ECHEncryptedExtensions extension specified in
3113
/// [draft-ietf-tls-esni Section 5].
3114
///
3115
/// [draft-ietf-tls-esni Section 5]: <https://www.ietf.org/archive/id/draft-ietf-tls-esni-18.html#section-5>
3116
#[derive(Clone, Debug)]
3117
pub struct ServerEncryptedClientHello {
3118
    pub(crate) retry_configs: Vec<EchConfigPayload>,
3119
}
3120
3121
impl Codec<'_> for ServerEncryptedClientHello {
3122
0
    fn encode(&self, bytes: &mut Vec<u8>) {
3123
0
        self.retry_configs.encode(bytes);
3124
0
    }
3125
3126
0
    fn read(r: &mut Reader<'_>) -> Result<Self, InvalidMessage> {
3127
0
        Ok(Self {
3128
0
            retry_configs: Vec::<EchConfigPayload>::read(r)?,
3129
        })
3130
0
    }
3131
}
3132
3133
/// The method of encoding to use for a handshake message.
3134
///
3135
/// In some cases a handshake message may be encoded differently depending on the purpose
3136
/// the encoded message is being used for. For example, a [ServerHelloPayload] may be encoded
3137
/// with the last 8 bytes of the random zeroed out when being encoded for ECH confirmation.
3138
pub(crate) enum Encoding {
3139
    /// Standard RFC 8446 encoding.
3140
    Standard,
3141
    /// Encoding for ECH confirmation.
3142
    EchConfirmation,
3143
    /// Encoding for ECH inner client hello.
3144
    EchInnerHello { to_compress: Vec<ExtensionType> },
3145
}
3146
3147
0
/// Returns `true` if `iter` yields two items that convert to equal `T` values.
///
/// Each item is converted with `Into<T>` and inserted into a `BTreeSet`. The function
/// returns as soon as an insertion reports that the value was already present, so the
/// remainder of the iterator is not consumed. An iterator with no repeated values
/// (including an empty one) yields `false`.
fn has_duplicates<I: IntoIterator<Item = E>, E: Into<T>, T: Eq + Ord>(iter: I) -> bool {
3148
0
    let mut seen = BTreeSet::new();
3149
3150
0
    for x in iter {
3151
0
        // `insert` returns `false` when the set already holds an equal value.
        if !seen.insert(x.into()) {
3152
0
            return true;
3153
0
        }
3154
    }
3155
3156
0
    false
3157
0
}
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::ServerName>, <[rustls::msgs::handshake::ServerName] as rustls::msgs::handshake::ConvertServerNameList>::has_duplicate_names_for_type::{closure#0}>, rustls::msgs::enums::ServerNameType, u8>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::KeyShareEntry>, <rustls::msgs::handshake::ClientHelloPayload>::has_keyshare_extension_with_duplicates::{closure#0}::{closure#0}>, u16, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::ClientExtension>, <rustls::msgs::handshake::ClientHelloPayload>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::ServerExtension>, <alloc::vec::Vec<rustls::msgs::handshake::ServerExtension> as rustls::msgs::handshake::HasServerExtensions>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::ServerExtension>, <rustls::msgs::handshake::ServerHelloPayload as rustls::msgs::handshake::HasServerExtensions>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::EchConfigExtension>, <rustls::msgs::handshake::EchConfigContents>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::HelloRetryExtension>, <rustls::msgs::handshake::HelloRetryRequest>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::CertificateExtension>, <rustls::msgs::handshake::CertificateEntry>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::map::Map<core::slice::iter::Iter<rustls::msgs::handshake::NewSessionTicketExtension>, <rustls::msgs::handshake::NewSessionTicketPayloadTls13>::has_duplicate_extension::{closure#0}>, rustls::msgs::enums::ExtensionType, u16>
Unexecuted instantiation: rustls::msgs::handshake::has_duplicates::<core::iter::adapters::cloned::Cloned<core::slice::iter::Iter<rustls::enums::CertificateCompressionAlgorithm>>, rustls::enums::CertificateCompressionAlgorithm, u16>
3158
3159
#[cfg(test)]
3160
mod tests {
3161
    use super::*;
3162
3163
    /// Pushing the same unknown extension twice must be reported as a duplicate,
    /// while the type (0x42, high bit clear) must not be reported as mandatory.
    #[test]
3164
    fn test_ech_config_dupe_exts() {
3165
        let unknown_ext = EchConfigExtension::Unknown(UnknownExtension {
3166
            typ: ExtensionType::Unknown(0x42),
3167
            payload: Payload::new(vec![0x42]),
3168
        });
3169
        let mut config = config_template();
3170
        // Add the identical extension twice so two entries share one type.
        config
3171
            .extensions
3172
            .push(unknown_ext.clone());
3173
        config.extensions.push(unknown_ext);
3174
3175
        assert!(config.has_duplicate_extension());
3176
        assert!(!config.has_unknown_mandatory_extension());
3177
    }
3178
3179
    /// A single unknown extension whose type has the high bit set must be reported
    /// as an unknown mandatory extension and must not be reported as a duplicate.
    #[test]
3180
    fn test_ech_config_mandatory_exts() {
3181
        let mandatory_unknown_ext = EchConfigExtension::Unknown(UnknownExtension {
3182
            typ: ExtensionType::Unknown(0x42 | 0x8000), // Note: high bit set.
3183
            payload: Payload::new(vec![0x42]),
3184
        });
3185
        let mut config = config_template();
3186
        config
3187
            .extensions
3188
            .push(mandatory_unknown_ext);
3189
3190
        assert!(!config.has_duplicate_extension());
3191
        assert!(config.has_unknown_mandatory_extension());
3192
    }
3193
3194
    /// Builds a baseline `EchConfigContents` with an empty extension list, which
    /// each test then extends with the extensions it needs.
    fn config_template() -> EchConfigContents {
3195
        EchConfigContents {
3196
            key_config: HpkeKeyConfig {
3197
                config_id: 0,
3198
                kem_id: HpkeKem::DHKEM_P256_HKDF_SHA256,
3199
                // Placeholder key bytes; the tests above only inspect `extensions`.
                public_key: PayloadU16(b"xxx".into()),
3200
                symmetric_cipher_suites: vec![HpkeSymmetricCipherSuite {
3201
                    kdf_id: HpkeKdf::HKDF_SHA256,
3202
                    aead_id: HpkeAead::AES_128_GCM,
3203
                }],
3204
            },
3205
            maximum_name_length: 0,
3206
            public_name: DnsName::try_from("example.com").unwrap(),
3207
            extensions: vec![],
3208
        }
3209
    }
3210
}