Coverage Report

Created: 2026-03-23 07:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.14.5/src/codec/compression.rs
Line
Count
Source
1
use crate::{metadata::MetadataValue, Status};
2
use bytes::{Buf, BufMut, BytesMut};
3
#[cfg(feature = "gzip")]
4
use flate2::read::{GzDecoder, GzEncoder};
5
#[cfg(feature = "deflate")]
6
use flate2::read::{ZlibDecoder, ZlibEncoder};
7
use std::{borrow::Cow, fmt};
8
#[cfg(feature = "zstd")]
9
use zstd::stream::read::{Decoder, Encoder};
10
11
/// Name of the header that states how a message payload is compressed.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that lists the compression encodings a peer will accept.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14
/// Struct used to configure which encodings are enabled on a server or channel.
15
///
16
/// Represents an ordered list of compression encodings that are enabled.
17
#[derive(Debug, Default, Clone, Copy)]
18
pub struct EnabledCompressionEncodings {
19
    inner: [Option<CompressionEncoding>; 3],
20
}
21
22
impl EnabledCompressionEncodings {
23
    /// Enable a [`CompressionEncoding`].
24
    ///
25
    /// Adds the new encoding to the end of the encoding list.
26
0
    pub fn enable(&mut self, encoding: CompressionEncoding) {
27
0
        for e in self.inner.iter_mut() {
28
0
            match e {
29
0
                Some(e) if *e == encoding => return,
30
                None => {
31
0
                    *e = Some(encoding);
32
0
                    return;
33
                }
34
                _ => continue,
35
            }
36
        }
37
0
    }
38
39
    /// Remove the last [`CompressionEncoding`].
40
0
    pub fn pop(&mut self) -> Option<CompressionEncoding> {
41
0
        self.inner
42
0
            .iter_mut()
43
0
            .rev()
44
0
            .find(|entry| entry.is_some())?
45
0
            .take()
46
0
    }
47
48
0
    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49
0
        let mut value = BytesMut::new();
50
0
        for encoding in self.inner.into_iter().flatten() {
51
            value.put_slice(encoding.as_str().as_bytes());
52
            value.put_u8(b',');
53
        }
54
55
0
        if value.is_empty() {
56
0
            return None;
57
0
        }
58
59
0
        value.put_slice(b"identity");
60
0
        Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61
0
    }
62
63
    /// Check if a [`CompressionEncoding`] is enabled.
64
0
    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65
0
        self.inner.contains(&Some(encoding))
66
0
    }
67
68
    /// Check if any [`CompressionEncoding`]s are enabled.
69
0
    pub fn is_empty(&self) -> bool {
70
0
        self.inner.iter().all(|e| e.is_none())
71
0
    }
72
}
73
74
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
75
pub(crate) struct CompressionSettings {
76
    pub(crate) encoding: CompressionEncoding,
77
    /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
78
    /// The default buffer growth interval is 8 kilobytes.
79
    pub(crate) buffer_growth_interval: usize,
80
}
81
82
/// The compression encodings Tonic supports.
///
/// Each variant exists only when the cargo feature of the same name is enabled.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum CompressionEncoding {
    /// gzip compression, advertised as `gzip`. Requires the `gzip` feature.
    #[cfg(feature = "gzip")]
    Gzip,
    /// zlib-wrapped deflate compression, advertised as `deflate`. Requires the `deflate`
    /// feature.
    #[cfg(feature = "deflate")]
    Deflate,
    /// Zstandard compression, advertised as `zstd`. Requires the `zstd` feature.
    #[cfg(feature = "zstd")]
    Zstd,
}
96
97
impl CompressionEncoding {
98
    pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99
        #[cfg(feature = "gzip")]
100
        CompressionEncoding::Gzip,
101
        #[cfg(feature = "deflate")]
102
        CompressionEncoding::Deflate,
103
        #[cfg(feature = "zstd")]
104
        CompressionEncoding::Zstd,
105
    ];
106
107
    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
108
0
    pub(crate) fn from_accept_encoding_header(
109
0
        map: &http::HeaderMap,
110
0
        enabled_encodings: EnabledCompressionEncodings,
111
0
    ) -> Option<Self> {
112
0
        if enabled_encodings.is_empty() {
113
0
            return None;
114
0
        }
115
116
0
        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117
0
        let header_value_str = header_value.to_str().ok()?;
118
119
0
        split_by_comma(header_value_str).find_map(|value| match value {
120
            #[cfg(feature = "gzip")]
121
            "gzip" => Some(CompressionEncoding::Gzip),
122
            #[cfg(feature = "deflate")]
123
            "deflate" => Some(CompressionEncoding::Deflate),
124
            #[cfg(feature = "zstd")]
125
            "zstd" => Some(CompressionEncoding::Zstd),
126
0
            _ => None,
127
0
        })
128
0
    }
129
130
    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
131
0
    pub(crate) fn from_encoding_header(
132
0
        map: &http::HeaderMap,
133
0
        enabled_encodings: EnabledCompressionEncodings,
134
0
    ) -> Result<Option<Self>, Status> {
135
0
        let Some(header_value) = map.get(ENCODING_HEADER) else {
136
0
            return Ok(None);
137
        };
138
139
0
        match header_value.as_bytes() {
140
            #[cfg(feature = "gzip")]
141
            b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142
                Ok(Some(CompressionEncoding::Gzip))
143
            }
144
            #[cfg(feature = "deflate")]
145
            b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146
                Ok(Some(CompressionEncoding::Deflate))
147
            }
148
            #[cfg(feature = "zstd")]
149
            b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150
                Ok(Some(CompressionEncoding::Zstd))
151
            }
152
0
            b"identity" => Ok(None),
153
0
            other => {
154
0
                let other = match std::str::from_utf8(other) {
155
0
                    Ok(s) => Cow::Borrowed(s),
156
0
                    Err(_) => Cow::Owned(format!("{other:?}")),
157
                };
158
159
0
                let mut status = Status::unimplemented(format!(
160
                    "Content is compressed with `{other}` which isn't supported"
161
                ));
162
163
0
                let header_value = enabled_encodings
164
0
                    .into_accept_encoding_header_value()
165
0
                    .map(MetadataValue::unchecked_from_header_value)
166
0
                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
167
0
                status
168
0
                    .metadata_mut()
169
0
                    .insert(ACCEPT_ENCODING_HEADER, header_value);
170
171
0
                Err(status)
172
            }
173
        }
174
0
    }
175
176
    pub(crate) fn as_str(self) -> &'static str {
177
        match self {
178
            #[cfg(feature = "gzip")]
179
            CompressionEncoding::Gzip => "gzip",
180
            #[cfg(feature = "deflate")]
181
            CompressionEncoding::Deflate => "deflate",
182
            #[cfg(feature = "zstd")]
183
            CompressionEncoding::Zstd => "zstd",
184
        }
185
    }
186
187
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
188
    pub(crate) fn into_header_value(self) -> http::HeaderValue {
189
        http::HeaderValue::from_static(self.as_str())
190
    }
191
}
192
193
impl fmt::Display for CompressionEncoding {
194
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195
0
        f.write_str(self.as_str())
196
0
    }
197
}
198
199
0
/// Split a comma-separated header value into its entries, trimming surrounding whitespace.
///
/// Empty entries (e.g. from `a,,b`) are yielded as empty strings.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    let pieces = s.split(',');
    pieces.map(str::trim)
}
202
203
/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
204
/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
205
#[allow(unused_variables, unreachable_code)]
206
0
pub(crate) fn compress(
207
0
    settings: CompressionSettings,
208
0
    decompressed_buf: &mut BytesMut,
209
0
    out_buf: &mut BytesMut,
210
0
    len: usize,
211
0
) -> Result<(), std::io::Error> {
212
0
    let buffer_growth_interval = settings.buffer_growth_interval;
213
0
    let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
214
0
    out_buf.reserve(capacity);
215
216
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
217
    let mut out_writer = out_buf.writer();
218
219
    match settings.encoding {
220
        #[cfg(feature = "gzip")]
221
        CompressionEncoding::Gzip => {
222
            let mut gzip_encoder = GzEncoder::new(
223
                &decompressed_buf[0..len],
224
                // FIXME: support customizing the compression level
225
                flate2::Compression::new(6),
226
            );
227
            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
228
        }
229
        #[cfg(feature = "deflate")]
230
        CompressionEncoding::Deflate => {
231
            let mut deflate_encoder = ZlibEncoder::new(
232
                &decompressed_buf[0..len],
233
                // FIXME: support customizing the compression level
234
                flate2::Compression::new(6),
235
            );
236
            std::io::copy(&mut deflate_encoder, &mut out_writer)?;
237
        }
238
        #[cfg(feature = "zstd")]
239
        CompressionEncoding::Zstd => {
240
            let mut zstd_encoder = Encoder::new(
241
                &decompressed_buf[0..len],
242
                // FIXME: support customizing the compression level
243
                zstd::DEFAULT_COMPRESSION_LEVEL,
244
            )?;
245
            std::io::copy(&mut zstd_encoder, &mut out_writer)?;
246
        }
247
    }
248
249
    decompressed_buf.advance(len);
250
251
    Ok(())
252
}
253
254
/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
255
#[allow(unused_variables, unreachable_code)]
256
0
pub(crate) fn decompress(
257
0
    settings: CompressionSettings,
258
0
    compressed_buf: &mut BytesMut,
259
0
    mut out_buf: bytes::buf::Limit<&mut BytesMut>,
260
0
    len: usize,
261
0
) -> Result<(), std::io::Error> {
262
0
    let buffer_growth_interval = settings.buffer_growth_interval;
263
0
    let estimate_decompressed_len = len * 2;
264
0
    let capacity = std::cmp::min(
265
0
        bytes::buf::Limit::limit(&out_buf),
266
0
        ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval,
267
    );
268
269
0
    out_buf.get_mut().reserve(capacity);
270
271
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
272
    let mut out_writer = out_buf.writer();
273
274
    match settings.encoding {
275
        #[cfg(feature = "gzip")]
276
        CompressionEncoding::Gzip => {
277
            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
278
            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
279
        }
280
        #[cfg(feature = "deflate")]
281
        CompressionEncoding::Deflate => {
282
            let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
283
            std::io::copy(&mut deflate_decoder, &mut out_writer)?;
284
        }
285
        #[cfg(feature = "zstd")]
286
        CompressionEncoding::Zstd => {
287
            let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
288
            std::io::copy(&mut zstd_decoder, &mut out_writer)?;
289
        }
290
    }
291
292
    compressed_buf.advance(len);
293
294
    Ok(())
295
}
296
297
/// Controls compression behavior for individual messages within a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
309
310
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    #[test]
    fn convert_none_into_header_value() {
        let nothing_enabled = EnabledCompressionEncodings::default();

        assert_eq!(nothing_enabled.into_accept_encoding_header_value(), None);
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        const EXPECTED: HeaderValue = HeaderValue::from_static("gzip,identity");

        // Position within the backing array must not affect the rendered header.
        for inner in [
            [Some(CompressionEncoding::Gzip), None, None],
            [None, None, Some(CompressionEncoding::Gzip)],
        ] {
            let rendered = EnabledCompressionEncodings { inner }
                .into_accept_encoding_header_value()
                .unwrap();
            assert_eq!(rendered, EXPECTED);
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        const EXPECTED: HeaderValue = HeaderValue::from_static("zstd,identity");

        // Position within the backing array must not affect the rendered header.
        for inner in [
            [Some(CompressionEncoding::Zstd), None, None],
            [None, None, Some(CompressionEncoding::Zstd)],
        ] {
            let rendered = EnabledCompressionEncodings { inner }
                .into_accept_encoding_header_value()
                .unwrap();
            assert_eq!(rendered, EXPECTED);
        }
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        use super::CompressionEncoding::{Deflate, Gzip, Zstd};

        // The header must list encodings in exactly the configured order.
        let cases = [
            (
                [Some(Gzip), Some(Deflate), Some(Zstd)],
                "gzip,deflate,zstd,identity",
            ),
            (
                [Some(Zstd), Some(Deflate), Some(Gzip)],
                "zstd,deflate,gzip,identity",
            ),
        ];

        for (inner, expected) in cases {
            let encodings = EnabledCompressionEncodings { inner };
            assert_eq!(
                encodings.into_accept_encoding_header_value().unwrap(),
                HeaderValue::from_static(expected),
            );
        }
    }
}