Coverage Report

Created: 2026-01-10 06:41

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/h2/src/frame/headers.rs
Line
Count
Source
1
use super::{util, StreamDependency, StreamId};
2
use crate::ext::Protocol;
3
use crate::frame::{Error, Frame, Head, Kind};
4
use crate::hpack::{self, BytesStr};
5
6
use http::header::{self, HeaderName, HeaderValue};
7
use http::{uri, HeaderMap, Method, Request, StatusCode, Uri};
8
9
use bytes::{Buf, BufMut, Bytes, BytesMut};
10
11
use std::fmt;
12
use std::io::Cursor;
13
14
type EncodeBuf<'a> = bytes::buf::Limit<&'a mut BytesMut>;
15
16
/// Header frame
17
///
18
/// This could be either a request or a response.
19
#[derive(Eq, PartialEq)]
20
pub struct Headers {
21
    /// The ID of the stream with which this frame is associated.
22
    stream_id: StreamId,
23
24
    /// The stream dependency information, if any.
25
    stream_dep: Option<StreamDependency>,
26
27
    /// The header block fragment
28
    header_block: HeaderBlock,
29
30
    /// The associated flags
31
    flags: HeadersFlag,
32
}
33
34
#[derive(Copy, Clone, Eq, PartialEq)]
35
pub struct HeadersFlag(u8);
36
37
#[derive(Eq, PartialEq)]
38
pub struct PushPromise {
39
    /// The ID of the stream with which this frame is associated.
40
    stream_id: StreamId,
41
42
    /// The ID of the stream being reserved by this PushPromise.
43
    promised_id: StreamId,
44
45
    /// The header block fragment
46
    header_block: HeaderBlock,
47
48
    /// The associated flags
49
    flags: PushPromiseFlag,
50
}
51
52
#[derive(Copy, Clone, Eq, PartialEq)]
53
pub struct PushPromiseFlag(u8);
54
55
#[derive(Debug)]
56
pub struct Continuation {
57
    /// Stream ID of continuation frame
58
    stream_id: StreamId,
59
60
    header_block: EncodingHeaderBlock,
61
}
62
63
// TODO: These fields shouldn't be `pub`
64
#[derive(Debug, Default, Eq, PartialEq)]
65
pub struct Pseudo {
66
    // Request
67
    pub method: Option<Method>,
68
    pub scheme: Option<BytesStr>,
69
    pub authority: Option<BytesStr>,
70
    pub path: Option<BytesStr>,
71
    pub protocol: Option<Protocol>,
72
73
    // Response
74
    pub status: Option<StatusCode>,
75
}
76
77
#[derive(Debug)]
78
pub struct Iter {
79
    /// Pseudo headers
80
    pseudo: Option<Pseudo>,
81
82
    /// Header fields
83
    fields: header::IntoIter<HeaderValue>,
84
}
85
86
#[derive(Debug, PartialEq, Eq)]
87
struct HeaderBlock {
88
    /// The decoded header fields
89
    fields: HeaderMap,
90
91
    /// Precomputed size of all of our header fields, for perf reasons
92
    field_size: usize,
93
94
    /// Set to true if decoding went over the max header list size.
95
    is_over_size: bool,
96
97
    /// Pseudo headers, these are broken out as they must be sent as part of the
98
    /// headers frame.
99
    pseudo: Pseudo,
100
}
101
102
#[derive(Debug)]
103
struct EncodingHeaderBlock {
104
    hpack: Bytes,
105
}
106
107
const END_STREAM: u8 = 0x1;
108
const END_HEADERS: u8 = 0x4;
109
const PADDED: u8 = 0x8;
110
const PRIORITY: u8 = 0x20;
111
const ALL: u8 = END_STREAM | END_HEADERS | PADDED | PRIORITY;
112
113
// ===== impl Headers =====
114
115
impl Headers {
116
    /// Create a new HEADERS frame
117
440k
    pub fn new(stream_id: StreamId, pseudo: Pseudo, fields: HeaderMap) -> Self {
118
440k
        Headers {
119
440k
            stream_id,
120
440k
            stream_dep: None,
121
440k
            header_block: HeaderBlock {
122
440k
                field_size: calculate_headermap_size(&fields),
123
440k
                fields,
124
440k
                is_over_size: false,
125
440k
                pseudo,
126
440k
            },
127
440k
            flags: HeadersFlag::default(),
128
440k
        }
129
440k
    }
130
131
0
    pub fn trailers(stream_id: StreamId, fields: HeaderMap) -> Self {
132
0
        let mut flags = HeadersFlag::default();
133
0
        flags.set_end_stream();
134
135
0
        Headers {
136
0
            stream_id,
137
0
            stream_dep: None,
138
0
            header_block: HeaderBlock {
139
0
                field_size: calculate_headermap_size(&fields),
140
0
                fields,
141
0
                is_over_size: false,
142
0
                pseudo: Pseudo::default(),
143
0
            },
144
0
            flags,
145
0
        }
146
0
    }
147
148
    /// Loads the header frame but doesn't actually do HPACK decoding.
149
    ///
150
    /// HPACK decoding is done in the `load_hpack` step.
151
45.8k
    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
152
45.8k
        let flags = HeadersFlag(head.flag());
153
45.8k
        let mut pad = 0;
154
155
45.8k
        tracing::trace!("loading headers; flags={:?}", flags);
156
157
45.8k
        if head.stream_id().is_zero() {
158
3
            return Err(Error::InvalidStreamId);
159
45.8k
        }
160
161
        // Read the padding length
162
45.8k
        if flags.is_padded() {
163
3.20k
            if src.is_empty() {
164
1
                return Err(Error::MalformedMessage);
165
3.20k
            }
166
3.20k
            pad = src[0] as usize;
167
168
            // Drop the padding
169
3.20k
            src.advance(1);
170
42.6k
        }
171
172
        // Read the stream dependency
173
45.8k
        let stream_dep = if flags.is_priority() {
174
5.30k
            if src.len() < 5 {
175
3
                return Err(Error::MalformedMessage);
176
5.29k
            }
177
5.29k
            let stream_dep = StreamDependency::load(&src[..5])?;
178
179
5.29k
            if stream_dep.dependency_id() == head.stream_id() {
180
72
                return Err(Error::InvalidDependencyId);
181
5.22k
            }
182
183
            // Drop the next 5 bytes
184
5.22k
            src.advance(5);
185
186
5.22k
            Some(stream_dep)
187
        } else {
188
40.5k
            None
189
        };
190
191
45.8k
        if pad > 0 {
192
3.10k
            if pad > src.len() {
193
14
                return Err(Error::TooMuchPadding);
194
3.09k
            }
195
196
3.09k
            let len = src.len() - pad;
197
3.09k
            src.truncate(len);
198
42.7k
        }
199
200
45.8k
        let headers = Headers {
201
45.8k
            stream_id: head.stream_id(),
202
45.8k
            stream_dep,
203
45.8k
            header_block: HeaderBlock {
204
45.8k
                fields: HeaderMap::new(),
205
45.8k
                field_size: 0,
206
45.8k
                is_over_size: false,
207
45.8k
                pseudo: Pseudo::default(),
208
45.8k
            },
209
45.8k
            flags,
210
45.8k
        };
211
212
45.8k
        Ok((headers, src))
213
45.8k
    }
214
215
57.8k
    pub fn load_hpack(
216
57.8k
        &mut self,
217
57.8k
        src: &mut BytesMut,
218
57.8k
        max_header_list_size: usize,
219
57.8k
        decoder: &mut hpack::Decoder,
220
57.8k
    ) -> Result<(), Error> {
221
57.8k
        self.header_block.load(src, max_header_list_size, decoder)
222
57.8k
    }
223
224
462k
    pub fn stream_id(&self) -> StreamId {
225
462k
        self.stream_id
226
462k
    }
227
228
45.8k
    pub fn is_end_headers(&self) -> bool {
229
45.8k
        self.flags.is_end_headers()
230
45.8k
    }
231
232
148
    pub fn set_end_headers(&mut self) {
233
148
        self.flags.set_end_headers();
234
148
    }
235
236
443k
    pub fn is_end_stream(&self) -> bool {
237
443k
        self.flags.is_end_stream()
238
443k
    }
239
240
708
    pub fn set_end_stream(&mut self) {
241
708
        self.flags.set_end_stream()
242
708
    }
243
244
11.9k
    pub fn is_over_size(&self) -> bool {
245
11.9k
        self.header_block.is_over_size
246
11.9k
    }
247
248
2.02k
    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
249
2.02k
        (self.header_block.pseudo, self.header_block.fields)
250
2.02k
    }
251
252
    #[cfg(feature = "unstable")]
253
0
    pub fn pseudo_mut(&mut self) -> &mut Pseudo {
254
0
        &mut self.header_block.pseudo
255
0
    }
256
257
13
    pub(crate) fn pseudo(&self) -> &Pseudo {
258
13
        &self.header_block.pseudo
259
13
    }
260
261
    /// Whether it has status 1xx
262
1.83k
    pub(crate) fn is_informational(&self) -> bool {
263
1.83k
        self.header_block.pseudo.is_informational()
264
1.83k
    }
265
266
442k
    pub fn fields(&self) -> &HeaderMap {
267
442k
        &self.header_block.fields
268
442k
    }
269
270
37
    pub fn into_fields(self) -> HeaderMap {
271
37
        self.header_block.fields
272
37
    }
273
274
184k
    pub fn encode(
275
184k
        self,
276
184k
        encoder: &mut hpack::Encoder,
277
184k
        dst: &mut EncodeBuf<'_>,
278
184k
    ) -> Option<Continuation> {
279
        // At this point, the `is_end_headers` flag should always be set
280
184k
        debug_assert!(self.flags.is_end_headers());
281
282
        // Get the HEADERS frame head
283
184k
        let head = self.head();
284
285
184k
        self.header_block
286
184k
            .into_encoding(encoder)
287
184k
            .encode(&head, dst, |_| {})
288
184k
    }
289
290
184k
    fn head(&self) -> Head {
291
184k
        Head::new(Kind::Headers, self.flags.into(), self.stream_id)
292
184k
    }
293
}
294
295
impl<T> From<Headers> for Frame<T> {
296
633k
    fn from(src: Headers) -> Self {
297
633k
        Frame::Headers(src)
298
633k
    }
<h2::frame::Frame as core::convert::From<h2::frame::headers::Headers>>::from
Line
Count
Source
296
448k
    fn from(src: Headers) -> Self {
297
448k
        Frame::Headers(src)
298
448k
    }
<h2::frame::Frame<h2::proto::streams::prioritize::Prioritized<bytes::bytes::Bytes>> as core::convert::From<h2::frame::headers::Headers>>::from
Line
Count
Source
296
184k
    fn from(src: Headers) -> Self {
297
184k
        Frame::Headers(src)
298
184k
    }
299
}
300
301
impl fmt::Debug for Headers {
302
0
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
303
0
        let mut builder = f.debug_struct("Headers");
304
0
        builder
305
0
            .field("stream_id", &self.stream_id)
306
0
            .field("flags", &self.flags);
307
308
0
        if let Some(ref protocol) = self.header_block.pseudo.protocol {
309
0
            builder.field("protocol", protocol);
310
0
        }
311
312
0
        if let Some(ref dep) = self.stream_dep {
313
0
            builder.field("stream_dep", dep);
314
0
        }
315
316
        // `fields` and `pseudo` purposefully not included
317
0
        builder.finish()
318
0
    }
319
}
320
321
// ===== util =====
322
323
#[derive(Debug, PartialEq, Eq)]
324
pub struct ParseU64Error;
325
326
212
pub fn parse_u64(src: &[u8]) -> Result<u64, ParseU64Error> {
327
212
    if src.len() > 19 {
328
        // At danger for overflow...
329
4
        return Err(ParseU64Error);
330
208
    }
331
332
208
    let mut ret = 0;
333
334
1.09k
    for &d in src {
335
930
        if d < b'0' || d > b'9' {
336
39
            return Err(ParseU64Error);
337
891
        }
338
339
891
        ret *= 10;
340
891
        ret += (d - b'0') as u64;
341
    }
342
343
169
    Ok(ret)
344
212
}
345
346
// ===== impl PushPromise =====
347
348
#[derive(Debug)]
349
pub enum PushPromiseHeaderError {
350
    InvalidContentLength(Result<u64, ParseU64Error>),
351
    NotSafeAndCacheable,
352
}
353
354
impl PushPromise {
355
0
    pub fn new(
356
0
        stream_id: StreamId,
357
0
        promised_id: StreamId,
358
0
        pseudo: Pseudo,
359
0
        fields: HeaderMap,
360
0
    ) -> Self {
361
0
        PushPromise {
362
0
            flags: PushPromiseFlag::default(),
363
0
            header_block: HeaderBlock {
364
0
                field_size: calculate_headermap_size(&fields),
365
0
                fields,
366
0
                is_over_size: false,
367
0
                pseudo,
368
0
            },
369
0
            promised_id,
370
0
            stream_id,
371
0
        }
372
0
    }
373
374
347
    pub fn validate_request(req: &Request<()>) -> Result<(), PushPromiseHeaderError> {
375
        use PushPromiseHeaderError::*;
376
        // The spec has some requirements for promised request headers
377
        // [https://httpwg.org/specs/rfc7540.html#PushRequests]
378
379
        // A promised request "that indicates the presence of a request body
380
        // MUST reset the promised stream with a stream error"
381
347
        if let Some(content_length) = req.headers().get(header::CONTENT_LENGTH) {
382
60
            let parsed_length = parse_u64(content_length.as_bytes());
383
60
            if parsed_length != Ok(0) {
384
23
                return Err(InvalidContentLength(parsed_length));
385
37
            }
386
287
        }
387
        // "The server MUST include a method in the :method pseudo-header field
388
        // that is safe and cacheable"
389
324
        if !Self::safe_and_cacheable(req.method()) {
390
142
            return Err(NotSafeAndCacheable);
391
182
        }
392
393
182
        Ok(())
394
347
    }
395
396
324
    fn safe_and_cacheable(method: &Method) -> bool {
397
        // Cacheable: https://httpwg.org/specs/rfc7231.html#cacheable.methods
398
        // Safe: https://httpwg.org/specs/rfc7231.html#safe.methods
399
324
        method == Method::GET || method == Method::HEAD
400
324
    }
401
402
0
    pub fn fields(&self) -> &HeaderMap {
403
0
        &self.header_block.fields
404
0
    }
405
406
    #[cfg(feature = "unstable")]
407
0
    pub fn into_fields(self) -> HeaderMap {
408
0
        self.header_block.fields
409
0
    }
410
411
    /// Loads the push promise frame but doesn't actually do HPACK decoding.
412
    ///
413
    /// HPACK decoding is done in the `load_hpack` step.
414
2.51k
    pub fn load(head: Head, mut src: BytesMut) -> Result<(Self, BytesMut), Error> {
415
2.51k
        let flags = PushPromiseFlag(head.flag());
416
2.51k
        let mut pad = 0;
417
418
2.51k
        if head.stream_id().is_zero() {
419
2
            return Err(Error::InvalidStreamId);
420
2.51k
        }
421
422
        // Read the padding length
423
2.51k
        if flags.is_padded() {
424
135
            if src.is_empty() {
425
4
                return Err(Error::MalformedMessage);
426
131
            }
427
428
            // TODO: Ensure payload is sized correctly
429
131
            pad = src[0] as usize;
430
431
            // Drop the padding
432
131
            src.advance(1);
433
2.38k
        }
434
435
2.51k
        if src.len() < 5 {
436
17
            return Err(Error::MalformedMessage);
437
2.49k
        }
438
439
2.49k
        let (promised_id, _) = StreamId::parse(&src[..4]);
440
        // Drop promised_id bytes
441
2.49k
        src.advance(4);
442
443
2.49k
        if pad > 0 {
444
109
            if pad > src.len() {
445
9
                return Err(Error::TooMuchPadding);
446
100
            }
447
448
100
            let len = src.len() - pad;
449
100
            src.truncate(len);
450
2.38k
        }
451
452
2.48k
        let frame = PushPromise {
453
2.48k
            flags,
454
2.48k
            header_block: HeaderBlock {
455
2.48k
                fields: HeaderMap::new(),
456
2.48k
                field_size: 0,
457
2.48k
                is_over_size: false,
458
2.48k
                pseudo: Pseudo::default(),
459
2.48k
            },
460
2.48k
            promised_id,
461
2.48k
            stream_id: head.stream_id(),
462
2.48k
        };
463
2.48k
        Ok((frame, src))
464
2.51k
    }
465
466
2.48k
    pub fn load_hpack(
467
2.48k
        &mut self,
468
2.48k
        src: &mut BytesMut,
469
2.48k
        max_header_list_size: usize,
470
2.48k
        decoder: &mut hpack::Decoder,
471
2.48k
    ) -> Result<(), Error> {
472
2.48k
        self.header_block.load(src, max_header_list_size, decoder)
473
2.48k
    }
474
475
1.27k
    pub fn stream_id(&self) -> StreamId {
476
1.27k
        self.stream_id
477
1.27k
    }
478
479
2.40k
    pub fn promised_id(&self) -> StreamId {
480
2.40k
        self.promised_id
481
2.40k
    }
482
483
2.48k
    pub fn is_end_headers(&self) -> bool {
484
2.48k
        self.flags.is_end_headers()
485
2.48k
    }
486
487
0
    pub fn set_end_headers(&mut self) {
488
0
        self.flags.set_end_headers();
489
0
    }
490
491
1.13k
    pub fn is_over_size(&self) -> bool {
492
1.13k
        self.header_block.is_over_size
493
1.13k
    }
494
495
0
    pub fn encode(
496
0
        self,
497
0
        encoder: &mut hpack::Encoder,
498
0
        dst: &mut EncodeBuf<'_>,
499
0
    ) -> Option<Continuation> {
500
        // At this point, the `is_end_headers` flag should always be set
501
0
        debug_assert!(self.flags.is_end_headers());
502
503
0
        let head = self.head();
504
0
        let promised_id = self.promised_id;
505
506
0
        self.header_block
507
0
            .into_encoding(encoder)
508
0
            .encode(&head, dst, |dst| {
509
0
                dst.put_u32(promised_id.into());
510
0
            })
511
0
    }
512
513
0
    fn head(&self) -> Head {
514
0
        Head::new(Kind::PushPromise, self.flags.into(), self.stream_id)
515
0
    }
516
517
    /// Consume `self`, returning the parts of the frame
518
1.13k
    pub fn into_parts(self) -> (Pseudo, HeaderMap) {
519
1.13k
        (self.header_block.pseudo, self.header_block.fields)
520
1.13k
    }
521
}
522
523
impl<T> From<PushPromise> for Frame<T> {
524
1.27k
    fn from(src: PushPromise) -> Self {
525
1.27k
        Frame::PushPromise(src)
526
1.27k
    }
<h2::frame::Frame as core::convert::From<h2::frame::headers::PushPromise>>::from
Line
Count
Source
524
1.27k
    fn from(src: PushPromise) -> Self {
525
1.27k
        Frame::PushPromise(src)
526
1.27k
    }
Unexecuted instantiation: <h2::frame::Frame<h2::proto::streams::prioritize::Prioritized<bytes::bytes::Bytes>> as core::convert::From<h2::frame::headers::PushPromise>>::from
527
}
528
529
impl fmt::Debug for PushPromise {
530
0
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
531
0
        f.debug_struct("PushPromise")
532
0
            .field("stream_id", &self.stream_id)
533
0
            .field("promised_id", &self.promised_id)
534
0
            .field("flags", &self.flags)
535
            // `fields` and `pseudo` purposefully not included
536
0
            .finish()
537
0
    }
538
}
539
540
// ===== impl Continuation =====
541
542
impl Continuation {
543
0
    fn head(&self) -> Head {
544
0
        Head::new(Kind::Continuation, END_HEADERS, self.stream_id)
545
0
    }
546
547
0
    pub fn encode(self, dst: &mut EncodeBuf<'_>) -> Option<Continuation> {
548
        // Get the CONTINUATION frame head
549
0
        let head = self.head();
550
551
0
        self.header_block.encode(&head, dst, |_| {})
552
0
    }
553
}
554
555
// ===== impl Pseudo =====
556
557
impl Pseudo {
558
440k
    pub fn request(method: Method, uri: Uri, protocol: Option<Protocol>) -> Self {
559
440k
        let parts = uri::Parts::from(uri);
560
561
440k
        let (scheme, path) = if method == Method::CONNECT && protocol.is_none() {
562
0
            (None, None)
563
        } else {
564
440k
            let path = parts
565
440k
                .path_and_query
566
440k
                .map(|v| BytesStr::from(v.as_str()))
567
440k
                .unwrap_or(BytesStr::from_static(""));
568
569
440k
            let path = if !path.is_empty() {
570
439k
                path
571
556
            } else if method == Method::OPTIONS {
572
0
                BytesStr::from_static("*")
573
            } else {
574
556
                BytesStr::from_static("/")
575
            };
576
577
440k
            (parts.scheme, Some(path))
578
        };
579
580
440k
        let mut pseudo = Pseudo {
581
440k
            method: Some(method),
582
440k
            scheme: None,
583
440k
            authority: None,
584
440k
            path,
585
440k
            protocol,
586
440k
            status: None,
587
440k
        };
588
589
        // If the URI includes a scheme component, add it to the pseudo headers
590
440k
        if let Some(scheme) = scheme {
591
439k
            pseudo.set_scheme(scheme);
592
439k
        }
593
594
        // If the URI includes an authority component, add it to the pseudo
595
        // headers
596
440k
        if let Some(authority) = parts.authority {
597
440k
            pseudo.set_authority(BytesStr::from(authority.as_str()));
598
440k
        }
599
600
440k
        pseudo
601
440k
    }
602
603
0
    pub fn response(status: StatusCode) -> Self {
604
0
        Pseudo {
605
0
            method: None,
606
0
            scheme: None,
607
0
            authority: None,
608
0
            path: None,
609
0
            protocol: None,
610
0
            status: Some(status),
611
0
        }
612
0
    }
613
614
    #[cfg(feature = "unstable")]
615
0
    pub fn set_status(&mut self, value: StatusCode) {
616
0
        self.status = Some(value);
617
0
    }
618
619
439k
    pub fn set_scheme(&mut self, scheme: uri::Scheme) {
620
439k
        let bytes_str = match scheme.as_str() {
621
439k
            "http" => BytesStr::from_static("http"),
622
439k
            "https" => BytesStr::from_static("https"),
623
34
            s => BytesStr::from(s),
624
        };
625
439k
        self.scheme = Some(bytes_str);
626
439k
    }
627
628
    #[cfg(feature = "unstable")]
629
0
    pub fn set_protocol(&mut self, protocol: Protocol) {
630
0
        self.protocol = Some(protocol);
631
0
    }
632
633
440k
    pub fn set_authority(&mut self, authority: BytesStr) {
634
440k
        self.authority = Some(authority);
635
440k
    }
636
637
    /// Whether it has status 1xx
638
3.86k
    pub(crate) fn is_informational(&self) -> bool {
639
3.86k
        self.status
640
3.86k
            .map_or(false, |status| status.is_informational())
641
3.86k
    }
642
}
643
644
// ===== impl EncodingHeaderBlock =====
645
646
impl EncodingHeaderBlock {
647
184k
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
648
184k
    where
649
184k
        F: FnOnce(&mut EncodeBuf<'_>),
650
    {
651
184k
        let head_pos = dst.get_ref().len();
652
653
        // At this point, we don't know how big the h2 frame will be.
654
        // So, we write the head with length 0, then write the body, and
655
        // finally write the length once we know the size.
656
184k
        head.encode(0, dst);
657
658
184k
        let payload_pos = dst.get_ref().len();
659
660
184k
        f(dst);
661
662
        // Now, encode the header payload
663
184k
        let continuation = if self.hpack.len() > dst.remaining_mut() {
664
0
            dst.put((&mut self.hpack).take(dst.remaining_mut()));
665
666
0
            Some(Continuation {
667
0
                stream_id: head.stream_id(),
668
0
                header_block: self,
669
0
            })
670
        } else {
671
184k
            dst.put_slice(&self.hpack);
672
673
184k
            None
674
        };
675
676
        // Compute the header block length
677
184k
        let payload_len = (dst.get_ref().len() - payload_pos) as u64;
678
679
        // Write the frame length
680
184k
        let payload_len_be = payload_len.to_be_bytes();
681
923k
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
<h2::frame::headers::EncodingHeaderBlock>::encode::<<h2::frame::headers::Headers>::encode::{closure#0}>::{closure#0}
Line
Count
Source
681
923k
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
Unexecuted instantiation: <h2::frame::headers::EncodingHeaderBlock>::encode::<<h2::frame::headers::PushPromise>::encode::{closure#0}>::{closure#0}
Unexecuted instantiation: <h2::frame::headers::EncodingHeaderBlock>::encode::<<h2::frame::headers::Continuation>::encode::{closure#0}>::{closure#0}
682
184k
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
683
684
184k
        if continuation.is_some() {
685
            // There will be continuation frames, so the `is_end_headers` flag
686
            // must be unset
687
0
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
688
689
0
            dst.get_mut()[head_pos + 4] -= END_HEADERS;
690
184k
        }
691
692
184k
        continuation
693
184k
    }
<h2::frame::headers::EncodingHeaderBlock>::encode::<<h2::frame::headers::Headers>::encode::{closure#0}>
Line
Count
Source
647
184k
    fn encode<F>(mut self, head: &Head, dst: &mut EncodeBuf<'_>, f: F) -> Option<Continuation>
648
184k
    where
649
184k
        F: FnOnce(&mut EncodeBuf<'_>),
650
    {
651
184k
        let head_pos = dst.get_ref().len();
652
653
        // At this point, we don't know how big the h2 frame will be.
654
        // So, we write the head with length 0, then write the body, and
655
        // finally write the length once we know the size.
656
184k
        head.encode(0, dst);
657
658
184k
        let payload_pos = dst.get_ref().len();
659
660
184k
        f(dst);
661
662
        // Now, encode the header payload
663
184k
        let continuation = if self.hpack.len() > dst.remaining_mut() {
664
0
            dst.put((&mut self.hpack).take(dst.remaining_mut()));
665
666
0
            Some(Continuation {
667
0
                stream_id: head.stream_id(),
668
0
                header_block: self,
669
0
            })
670
        } else {
671
184k
            dst.put_slice(&self.hpack);
672
673
184k
            None
674
        };
675
676
        // Compute the header block length
677
184k
        let payload_len = (dst.get_ref().len() - payload_pos) as u64;
678
679
        // Write the frame length
680
184k
        let payload_len_be = payload_len.to_be_bytes();
681
184k
        assert!(payload_len_be[0..5].iter().all(|b| *b == 0));
682
184k
        (dst.get_mut()[head_pos..head_pos + 3]).copy_from_slice(&payload_len_be[5..]);
683
684
184k
        if continuation.is_some() {
685
            // There will be continuation frames, so the `is_end_headers` flag
686
            // must be unset
687
0
            debug_assert!(dst.get_ref()[head_pos + 4] & END_HEADERS == END_HEADERS);
688
689
0
            dst.get_mut()[head_pos + 4] -= END_HEADERS;
690
184k
        }
691
692
184k
        continuation
693
184k
    }
Unexecuted instantiation: <h2::frame::headers::EncodingHeaderBlock>::encode::<<h2::frame::headers::PushPromise>::encode::{closure#0}>
Unexecuted instantiation: <h2::frame::headers::EncodingHeaderBlock>::encode::<<h2::frame::headers::Continuation>::encode::{closure#0}>
694
}
695
696
// ===== impl Iter =====
697
698
impl Iterator for Iter {
699
    type Item = hpack::Header<Option<HeaderName>>;
700
701
923k
    fn next(&mut self) -> Option<Self::Item> {
702
        use crate::hpack::Header::*;
703
704
923k
        if let Some(ref mut pseudo) = self.pseudo {
705
923k
            if let Some(method) = pseudo.method.take() {
706
184k
                return Some(Method(method));
707
739k
            }
708
709
739k
            if let Some(scheme) = pseudo.scheme.take() {
710
184k
                return Some(Scheme(scheme));
711
554k
            }
712
713
554k
            if let Some(authority) = pseudo.authority.take() {
714
184k
                return Some(Authority(authority));
715
369k
            }
716
717
369k
            if let Some(path) = pseudo.path.take() {
718
184k
                return Some(Path(path));
719
184k
            }
720
721
184k
            if let Some(protocol) = pseudo.protocol.take() {
722
0
                return Some(Protocol(protocol));
723
184k
            }
724
725
184k
            if let Some(status) = pseudo.status.take() {
726
0
                return Some(Status(status));
727
184k
            }
728
0
        }
729
730
184k
        self.pseudo = None;
731
732
184k
        self.fields
733
184k
            .next()
734
184k
            .map(|(name, value)| Field { name, value })
735
923k
    }
736
}
737
738
// ===== impl HeadersFlag =====
739
740
impl HeadersFlag {
741
0
    pub fn empty() -> HeadersFlag {
742
0
        HeadersFlag(0)
743
0
    }
744
745
0
    pub fn load(bits: u8) -> HeadersFlag {
746
0
        HeadersFlag(bits & ALL)
747
0
    }
748
749
443k
    pub fn is_end_stream(&self) -> bool {
750
443k
        self.0 & END_STREAM == END_STREAM
751
443k
    }
752
753
708
    pub fn set_end_stream(&mut self) {
754
708
        self.0 |= END_STREAM;
755
708
    }
756
757
45.8k
    pub fn is_end_headers(&self) -> bool {
758
45.8k
        self.0 & END_HEADERS == END_HEADERS
759
45.8k
    }
760
761
148
    pub fn set_end_headers(&mut self) {
762
148
        self.0 |= END_HEADERS;
763
148
    }
764
765
45.8k
    pub fn is_padded(&self) -> bool {
766
45.8k
        self.0 & PADDED == PADDED
767
45.8k
    }
768
769
45.8k
    pub fn is_priority(&self) -> bool {
770
45.8k
        self.0 & PRIORITY == PRIORITY
771
45.8k
    }
772
}
773
774
impl Default for HeadersFlag {
775
    /// Returns a `HeadersFlag` value with `END_HEADERS` set.
776
440k
    fn default() -> Self {
777
440k
        HeadersFlag(END_HEADERS)
778
440k
    }
779
}
780
781
impl From<HeadersFlag> for u8 {
782
184k
    fn from(src: HeadersFlag) -> u8 {
783
184k
        src.0
784
184k
    }
785
}
786
787
impl fmt::Debug for HeadersFlag {
788
0
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
789
0
        util::debug_flags(fmt, self.0)
790
0
            .flag_if(self.is_end_headers(), "END_HEADERS")
791
0
            .flag_if(self.is_end_stream(), "END_STREAM")
792
0
            .flag_if(self.is_padded(), "PADDED")
793
0
            .flag_if(self.is_priority(), "PRIORITY")
794
0
            .finish()
795
0
    }
796
}
797
798
// ===== impl PushPromiseFlag =====
799
800
impl PushPromiseFlag {
801
0
    pub fn empty() -> PushPromiseFlag {
802
0
        PushPromiseFlag(0)
803
0
    }
804
805
0
    pub fn load(bits: u8) -> PushPromiseFlag {
806
0
        PushPromiseFlag(bits & ALL)
807
0
    }
808
809
2.48k
    pub fn is_end_headers(&self) -> bool {
810
2.48k
        self.0 & END_HEADERS == END_HEADERS
811
2.48k
    }
812
813
0
    pub fn set_end_headers(&mut self) {
814
0
        self.0 |= END_HEADERS;
815
0
    }
816
817
2.51k
    pub fn is_padded(&self) -> bool {
818
2.51k
        self.0 & PADDED == PADDED
819
2.51k
    }
820
}
821
822
impl Default for PushPromiseFlag {
823
    /// Returns a `PushPromiseFlag` value with `END_HEADERS` set.
824
0
    fn default() -> Self {
825
0
        PushPromiseFlag(END_HEADERS)
826
0
    }
827
}
828
829
impl From<PushPromiseFlag> for u8 {
830
0
    fn from(src: PushPromiseFlag) -> u8 {
831
0
        src.0
832
0
    }
833
}
834
835
impl fmt::Debug for PushPromiseFlag {
836
0
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
837
0
        util::debug_flags(fmt, self.0)
838
0
            .flag_if(self.is_end_headers(), "END_HEADERS")
839
0
            .flag_if(self.is_padded(), "PADDED")
840
0
            .finish()
841
0
    }
842
}
843
844
// ===== HeaderBlock =====
845
846
impl HeaderBlock {
847
60.3k
    fn load(
848
60.3k
        &mut self,
849
60.3k
        src: &mut BytesMut,
850
60.3k
        max_header_list_size: usize,
851
60.3k
        decoder: &mut hpack::Decoder,
852
60.3k
    ) -> Result<(), Error> {
853
60.3k
        let mut reg = !self.fields.is_empty();
854
60.3k
        let mut malformed = false;
855
60.3k
        let mut headers_size = self.calculate_header_list_size();
856
857
        macro_rules! set_pseudo {
858
            ($field:ident, $val:expr) => {{
859
                if reg {
860
                    tracing::trace!("load_hpack; header malformed -- pseudo not at head of block");
861
                    malformed = true;
862
                } else if self.pseudo.$field.is_some() {
863
                    tracing::trace!("load_hpack; header malformed -- repeated pseudo");
864
                    malformed = true;
865
                } else {
866
                    let __val = $val;
867
                    headers_size +=
868
                        decoded_header_size(stringify!($field).len() + 1, __val.as_str().len());
869
                    if headers_size < max_header_list_size {
870
                        self.pseudo.$field = Some(__val);
871
                    } else if !self.is_over_size {
872
                        tracing::trace!("load_hpack; header list size over max");
873
                        self.is_over_size = true;
874
                    }
875
                }
876
            }};
877
        }
878
879
60.3k
        let mut cursor = Cursor::new(src);
880
881
        // If the header frame is malformed, we still have to continue decoding
882
        // the headers. A malformed header frame is a stream level error, but
883
        // the hpack state is connection level. In order to maintain correct
884
        // state for other streams, the hpack decoding process must complete.
885
692k
        let res = decoder.decode(&mut cursor, |header| {
886
            use crate::hpack::Header::*;
887
888
692k
            match header {
889
475k
                Field { name, value } => {
890
                    // Connection level header fields are not supported and must
891
                    // result in a protocol error.
892
893
475k
                    if name == header::CONNECTION
894
474k
                        || name == header::TRANSFER_ENCODING
895
469k
                        || name == header::UPGRADE
896
469k
                        || name == "keep-alive"
897
469k
                        || name == "proxy-connection"
898
                    {
899
5.88k
                        tracing::trace!("load_hpack; connection level header");
900
5.88k
                        malformed = true;
901
469k
                    } else if name == header::TE && value != "trailers" {
902
491
                        tracing::trace!(
903
0
                            "load_hpack; TE header not set to trailers; val={:?}",
904
                            value
905
                        );
906
491
                        malformed = true;
907
                    } else {
908
469k
                        reg = true;
909
910
469k
                        headers_size += decoded_header_size(name.as_str().len(), value.len());
911
469k
                        if headers_size < max_header_list_size {
912
469k
                            self.field_size +=
913
469k
                                decoded_header_size(name.as_str().len(), value.len());
914
469k
                            self.fields.append(name, value);
915
469k
                        } else if !self.is_over_size {
916
0
                            tracing::trace!("load_hpack; header list size over max");
917
0
                            self.is_over_size = true;
918
0
                        }
919
                    }
920
                }
921
20.0k
                Authority(v) => set_pseudo!(authority, v),
922
58.7k
                Method(v) => set_pseudo!(method, v),
923
25.6k
                Scheme(v) => set_pseudo!(scheme, v),
924
21.5k
                Path(v) => set_pseudo!(path, v),
925
2.18k
                Protocol(v) => set_pseudo!(protocol, v),
926
89.2k
                Status(v) => set_pseudo!(status, v),
927
            }
928
692k
        });
929
930
60.3k
        if let Err(e) = res {
931
16.8k
            tracing::trace!("hpack decoding error; err={:?}", e);
932
16.8k
            return Err(e.into());
933
43.4k
        }
934
935
43.4k
        if malformed {
936
31.6k
            tracing::trace!("malformed message");
937
31.6k
            return Err(Error::MalformedMessage);
938
11.8k
        }
939
940
11.8k
        Ok(())
941
60.3k
    }
942
943
184k
    fn into_encoding(self, encoder: &mut hpack::Encoder) -> EncodingHeaderBlock {
944
184k
        let mut hpack = BytesMut::new();
945
184k
        let headers = Iter {
946
184k
            pseudo: Some(self.pseudo),
947
184k
            fields: self.fields.into_iter(),
948
184k
        };
949
950
184k
        encoder.encode(headers, &mut hpack);
951
952
184k
        EncodingHeaderBlock {
953
184k
            hpack: hpack.freeze(),
954
184k
        }
955
184k
    }
956
957
    /// Calculates the size of the currently decoded header list.
958
    ///
959
    /// According to http://httpwg.org/specs/rfc7540.html#SETTINGS_MAX_HEADER_LIST_SIZE
960
    ///
961
    /// > The value is based on the uncompressed size of header fields,
962
    /// > including the length of the name and value in octets plus an
963
    /// > overhead of 32 octets for each header field.
964
60.3k
    fn calculate_header_list_size(&self) -> usize {
965
        macro_rules! pseudo_size {
966
            ($name:ident) => {{
967
                self.pseudo
968
                    .$name
969
                    .as_ref()
970
9.06k
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
<h2::frame::headers::HeaderBlock>::calculate_header_list_size::{closure#0}
Line
Count
Source
970
3.58k
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
<h2::frame::headers::HeaderBlock>::calculate_header_list_size::{closure#2}
Line
Count
Source
970
552
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
<h2::frame::headers::HeaderBlock>::calculate_header_list_size::{closure#3}
Line
Count
Source
970
1.67k
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
<h2::frame::headers::HeaderBlock>::calculate_header_list_size::{closure#4}
Line
Count
Source
970
2.61k
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
<h2::frame::headers::HeaderBlock>::calculate_header_list_size::{closure#1}
Line
Count
Source
970
641
                    .map(|m| decoded_header_size(stringify!($name).len() + 1, m.as_str().len()))
971
                    .unwrap_or(0)
972
            }};
973
        }
974
975
60.3k
        pseudo_size!(method)
976
60.3k
            + pseudo_size!(scheme)
977
60.3k
            + pseudo_size!(status)
978
60.3k
            + pseudo_size!(authority)
979
60.3k
            + pseudo_size!(path)
980
60.3k
            + self.field_size
981
60.3k
    }
982
}
983
984
440k
fn calculate_headermap_size(map: &HeaderMap) -> usize {
985
440k
    map.iter()
986
440k
        .map(|(name, value)| decoded_header_size(name.as_str().len(), value.len()))
987
440k
        .sum::<usize>()
988
440k
}
989
990
973k
/// Returns the size a header field contributes to a header list: the name
/// length plus the value length plus a fixed 32-octet per-field overhead.
fn decoded_header_size(name: usize, value: usize) -> usize {
    // Per-field overhead defined for HPACK entry sizes and for
    // SETTINGS_MAX_HEADER_LIST_SIZE accounting.
    const FIELD_OVERHEAD: usize = 32;

    name + value + FIELD_OVERHEAD
}
993
994
#[cfg(test)]
mod test {
    use super::*;
    use crate::frame;
    use crate::hpack::{huffman, Encoder};

    /// Encodes three `hello` fields into a buffer too small for one frame.
    ///
    /// Checks that the header block is split into a first frame and a
    /// continuation frame, with the `world` value split mid-string, and that
    /// encoding resumes correctly in the continuation.
    #[test]
    fn test_nameless_header_at_resume() {
        let mut encoder = Encoder::default();
        let mut dst = BytesMut::new();

        let headers = Headers::new(
            StreamId::ZERO,
            Default::default(),
            HeaderMap::from_iter(vec![
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("world"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("zomg"),
                ),
                (
                    HeaderName::from_static("hello"),
                    HeaderValue::from_static("sup"),
                ),
            ]),
        );

        // Limit the first write to one frame header plus 8 payload bytes.
        let continuation = headers
            .encode(&mut encoder, &mut (&mut dst).limit(frame::HEADER_LEN + 8))
            .unwrap();

        // Frame header bytes: length 8, type 1 (HEADERS), flags 0, stream 0.
        assert_eq!(17, dst.len());
        assert_eq!([0, 0, 8, 1, 0, 0, 0, 0, 0], &dst[0..9]);
        // 0x40: literal with incremental indexing and a new name; the name is a
        // Huffman-coded string (high bit) of length 4.
        assert_eq!(&[0x40, 0x80 | 4], &dst[9..11]);
        assert_eq!("hello", huff_decode(&dst[11..15]));
        // Huffman-coded value of length 4; only its first byte fits in this frame.
        assert_eq!(0x80 | 4, dst[15]);

        // Keep the partial `world` bytes to join with the continuation's data.
        let mut world = dst[16..17].to_owned();

        dst.clear();

        // Everything left fits, so no further continuation is returned.
        assert!(continuation
            .encode(&mut (&mut dst).limit(frame::HEADER_LEN + 16))
            .is_none());

        // The first 3 payload bytes complete the `world` value.
        world.extend_from_slice(&dst[9..12]);
        assert_eq!("world", huff_decode(&world));

        // Frame header bytes: length 15, type 9 (CONTINUATION),
        // flags 0x4 (END_HEADERS), stream 0.
        assert_eq!(24, dst.len());
        assert_eq!([0, 0, 15, 9, 4, 0, 0, 0, 0], &dst[0..9]);

        // The next fields are literals without indexing (0000 prefix). Each
        // name refers to table index 62: the 4-bit prefix saturates at 15 and
        // the continuation byte adds 47. Each is followed by a Huffman-coded
        // value of length 3.
        assert_eq!(&[15, 47, 0x80 | 3], &dst[12..15]);
        assert_eq!("zomg", huff_decode(&dst[15..18]));
        assert_eq!(&[15, 47, 0x80 | 3], &dst[18..21]);
        assert_eq!("sup", huff_decode(&dst[21..]));
    }

    /// Huffman-decodes `src` into a fresh buffer, panicking on invalid input.
    fn huff_decode(src: &[u8]) -> BytesMut {
        let mut buf = BytesMut::new();
        huffman::decode(src, &mut buf).unwrap()
    }

    #[test]
    fn test_connect_request_pseudo_headers_omits_path_and_scheme() {
        // CONNECT requests MUST NOT include :scheme & :path pseudo-header fields
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.5

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );

        // A path in the URI is still dropped for plain CONNECT.
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com/test"),
                None
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                ..Default::default()
            }
        );

        // Authority-form URI (no scheme).
        assert_eq!(
            Pseudo::request(Method::CONNECT, Uri::from_static("example.com:8443"), None),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_extended_connect_request_pseudo_headers_includes_path_and_scheme() {
        // On requests that contain the :protocol pseudo-header field, the
        // :scheme and :path pseudo-header fields of the target URI (see
        // Section 5) MUST also be included.
        // See: https://datatracker.ietf.org/doc/html/rfc8441#section-4

        // An empty path is expected to become "/".
        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("https://example.com:8443/test"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com:8443").into(),
                scheme: BytesStr::from_static("https").into(),
                path: BytesStr::from_static("/test").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );

        assert_eq!(
            Pseudo::request(
                Method::CONNECT,
                Uri::from_static("http://example.com/a/b/c"),
                Protocol::from_static("the-bread-protocol").into()
            ),
            Pseudo {
                method: Method::CONNECT.into(),
                authority: BytesStr::from_static("example.com").into(),
                scheme: BytesStr::from_static("http").into(),
                path: BytesStr::from_static("/a/b/c").into(),
                protocol: Protocol::from_static("the-bread-protocol").into(),
                ..Default::default()
            }
        );
    }

    #[test]
    fn test_options_request_with_empty_path_has_asterisk_as_pseudo_path() {
        // an OPTIONS request for an "http" or "https" URI that does not include a path component;
        // these MUST include a ":path" pseudo-header field with a value of '*' (see Section 7.1 of [HTTP]).
        // See: https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1
        assert_eq!(
            Pseudo::request(Method::OPTIONS, Uri::from_static("example.com:8080"), None,),
            Pseudo {
                method: Method::OPTIONS.into(),
                authority: BytesStr::from_static("example.com:8080").into(),
                path: BytesStr::from_static("*").into(),
                ..Default::default()
            }
        );
    }

    /// Ordinary methods always carry :scheme and :path. An empty path is
    /// expected to become "/".
    #[test]
    fn test_non_option_and_non_connect_requests_include_path_and_scheme() {
        let methods = [
            Method::GET,
            Method::POST,
            Method::PUT,
            Method::DELETE,
            Method::HEAD,
            Method::PATCH,
            Method::TRACE,
        ];

        for method in methods {
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("http://example.com:8080"),
                    None,
                ),
                Pseudo {
                    method: method.clone().into(),
                    authority: BytesStr::from_static("example.com:8080").into(),
                    scheme: BytesStr::from_static("http").into(),
                    path: BytesStr::from_static("/").into(),
                    ..Default::default()
                }
            );
            assert_eq!(
                Pseudo::request(
                    method.clone(),
                    Uri::from_static("https://example.com/a/b/c"),
                    None,
                ),
                Pseudo {
                    method: method.into(),
                    authority: BytesStr::from_static("example.com").into(),
                    scheme: BytesStr::from_static("https").into(),
                    path: BytesStr::from_static("/a/b/c").into(),
                    ..Default::default()
                }
            );
        }
    }
}