Coverage Report

Created: 2025-11-16 06:37

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.13.0/src/codec/compression.rs
Line
Count
Source
1
use crate::{metadata::MetadataValue, Status};
2
use bytes::{Buf, BufMut, BytesMut};
3
#[cfg(feature = "gzip")]
4
use flate2::read::{GzDecoder, GzEncoder};
5
#[cfg(feature = "deflate")]
6
use flate2::read::{ZlibDecoder, ZlibEncoder};
7
use std::fmt;
8
#[cfg(feature = "zstd")]
9
use zstd::stream::read::{Decoder, Encoder};
10
11
/// Name of the header that carries the encoding a message was compressed with.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that lists the encodings a peer accepts.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
13
14
/// Struct used to configure which encodings are enabled on a server or channel.
15
///
16
/// Represents an ordered list of compression encodings that are enabled.
17
#[derive(Debug, Default, Clone, Copy)]
18
pub struct EnabledCompressionEncodings {
19
    inner: [Option<CompressionEncoding>; 3],
20
}
21
22
impl EnabledCompressionEncodings {
23
    /// Enable a [`CompressionEncoding`].
24
    ///
25
    /// Adds the new encoding to the end of the encoding list.
26
0
    pub fn enable(&mut self, encoding: CompressionEncoding) {
27
0
        for e in self.inner.iter_mut() {
28
0
            match e {
29
0
                Some(e) if *e == encoding => return,
30
                None => {
31
0
                    *e = Some(encoding);
32
0
                    return;
33
                }
34
                _ => continue,
35
            }
36
        }
37
0
    }
38
39
    /// Remove the last [`CompressionEncoding`].
40
0
    pub fn pop(&mut self) -> Option<CompressionEncoding> {
41
0
        self.inner
42
0
            .iter_mut()
43
0
            .rev()
44
0
            .find(|entry| entry.is_some())?
45
0
            .take()
46
0
    }
47
48
0
    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
49
0
        let mut value = BytesMut::new();
50
0
        for encoding in self.inner.into_iter().flatten() {
51
            value.put_slice(encoding.as_str().as_bytes());
52
            value.put_u8(b',');
53
        }
54
55
0
        if value.is_empty() {
56
0
            return None;
57
0
        }
58
59
0
        value.put_slice(b"identity");
60
0
        Some(http::HeaderValue::from_maybe_shared(value).unwrap())
61
0
    }
62
63
    /// Check if a [`CompressionEncoding`] is enabled.
64
0
    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
65
0
        self.inner.contains(&Some(encoding))
66
0
    }
67
68
    /// Check if any [`CompressionEncoding`]s are enabled.
69
0
    pub fn is_empty(&self) -> bool {
70
0
        self.inner.iter().all(|e| e.is_none())
71
0
    }
72
}
73
74
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
75
pub(crate) struct CompressionSettings {
76
    pub(crate) encoding: CompressionEncoding,
77
    /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
78
    /// The default buffer growth interval is 8 kilobytes.
79
    pub(crate) buffer_growth_interval: usize,
80
}
81
82
/// The compression encodings Tonic supports.
83
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
84
#[non_exhaustive]
85
pub enum CompressionEncoding {
86
    #[allow(missing_docs)]
87
    #[cfg(feature = "gzip")]
88
    Gzip,
89
    #[allow(missing_docs)]
90
    #[cfg(feature = "deflate")]
91
    Deflate,
92
    #[allow(missing_docs)]
93
    #[cfg(feature = "zstd")]
94
    Zstd,
95
}
96
97
impl CompressionEncoding {
98
    pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
99
        #[cfg(feature = "gzip")]
100
        CompressionEncoding::Gzip,
101
        #[cfg(feature = "deflate")]
102
        CompressionEncoding::Deflate,
103
        #[cfg(feature = "zstd")]
104
        CompressionEncoding::Zstd,
105
    ];
106
107
    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
108
0
    pub(crate) fn from_accept_encoding_header(
109
0
        map: &http::HeaderMap,
110
0
        enabled_encodings: EnabledCompressionEncodings,
111
0
    ) -> Option<Self> {
112
0
        if enabled_encodings.is_empty() {
113
0
            return None;
114
0
        }
115
116
0
        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
117
0
        let header_value_str = header_value.to_str().ok()?;
118
119
0
        split_by_comma(header_value_str).find_map(|value| match value {
120
            #[cfg(feature = "gzip")]
121
            "gzip" => Some(CompressionEncoding::Gzip),
122
            #[cfg(feature = "deflate")]
123
            "deflate" => Some(CompressionEncoding::Deflate),
124
            #[cfg(feature = "zstd")]
125
            "zstd" => Some(CompressionEncoding::Zstd),
126
0
            _ => None,
127
0
        })
128
0
    }
129
130
    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
131
0
    pub(crate) fn from_encoding_header(
132
0
        map: &http::HeaderMap,
133
0
        enabled_encodings: EnabledCompressionEncodings,
134
0
    ) -> Result<Option<Self>, Status> {
135
0
        let Some(header_value) = map.get(ENCODING_HEADER) else {
136
0
            return Ok(None);
137
        };
138
139
0
        match header_value.as_bytes() {
140
            #[cfg(feature = "gzip")]
141
            b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
142
                Ok(Some(CompressionEncoding::Gzip))
143
            }
144
            #[cfg(feature = "deflate")]
145
            b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => {
146
                Ok(Some(CompressionEncoding::Deflate))
147
            }
148
            #[cfg(feature = "zstd")]
149
            b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
150
                Ok(Some(CompressionEncoding::Zstd))
151
            }
152
0
            b"identity" => Ok(None),
153
0
            other => {
154
                // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79.
155
                // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension
156
                let other_debug_string;
157
158
0
                let mut status = Status::unimplemented(format!(
159
0
                    "Content is compressed with `{}` which isn't supported",
160
0
                    match std::str::from_utf8(other) {
161
0
                        Ok(s) => s,
162
                        Err(_) => {
163
0
                            other_debug_string = format!("{other:?}");
164
0
                            &other_debug_string
165
                        }
166
                    }
167
                ));
168
169
0
                let header_value = enabled_encodings
170
0
                    .into_accept_encoding_header_value()
171
0
                    .map(MetadataValue::unchecked_from_header_value)
172
0
                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
173
0
                status
174
0
                    .metadata_mut()
175
0
                    .insert(ACCEPT_ENCODING_HEADER, header_value);
176
177
0
                Err(status)
178
            }
179
        }
180
0
    }
181
182
    pub(crate) fn as_str(self) -> &'static str {
183
        match self {
184
            #[cfg(feature = "gzip")]
185
            CompressionEncoding::Gzip => "gzip",
186
            #[cfg(feature = "deflate")]
187
            CompressionEncoding::Deflate => "deflate",
188
            #[cfg(feature = "zstd")]
189
            CompressionEncoding::Zstd => "zstd",
190
        }
191
    }
192
193
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
194
    pub(crate) fn into_header_value(self) -> http::HeaderValue {
195
        http::HeaderValue::from_static(self.as_str())
196
    }
197
}
198
199
impl fmt::Display for CompressionEncoding {
200
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
201
0
        f.write_str(self.as_str())
202
0
    }
203
}
204
205
0
/// Splits `s` on `,` and yields each piece with surrounding whitespace trimmed.
///
/// Empty pieces are kept, so an empty input yields a single empty string.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
208
209
/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
210
/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
211
#[allow(unused_variables, unreachable_code)]
212
0
pub(crate) fn compress(
213
0
    settings: CompressionSettings,
214
0
    decompressed_buf: &mut BytesMut,
215
0
    out_buf: &mut BytesMut,
216
0
    len: usize,
217
0
) -> Result<(), std::io::Error> {
218
0
    let buffer_growth_interval = settings.buffer_growth_interval;
219
0
    let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
220
0
    out_buf.reserve(capacity);
221
222
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
223
    let mut out_writer = out_buf.writer();
224
225
    match settings.encoding {
226
        #[cfg(feature = "gzip")]
227
        CompressionEncoding::Gzip => {
228
            let mut gzip_encoder = GzEncoder::new(
229
                &decompressed_buf[0..len],
230
                // FIXME: support customizing the compression level
231
                flate2::Compression::new(6),
232
            );
233
            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
234
        }
235
        #[cfg(feature = "deflate")]
236
        CompressionEncoding::Deflate => {
237
            let mut deflate_encoder = ZlibEncoder::new(
238
                &decompressed_buf[0..len],
239
                // FIXME: support customizing the compression level
240
                flate2::Compression::new(6),
241
            );
242
            std::io::copy(&mut deflate_encoder, &mut out_writer)?;
243
        }
244
        #[cfg(feature = "zstd")]
245
        CompressionEncoding::Zstd => {
246
            let mut zstd_encoder = Encoder::new(
247
                &decompressed_buf[0..len],
248
                // FIXME: support customizing the compression level
249
                zstd::DEFAULT_COMPRESSION_LEVEL,
250
            )?;
251
            std::io::copy(&mut zstd_encoder, &mut out_writer)?;
252
        }
253
    }
254
255
    decompressed_buf.advance(len);
256
257
    Ok(())
258
}
259
260
/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
261
#[allow(unused_variables, unreachable_code)]
262
0
pub(crate) fn decompress(
263
0
    settings: CompressionSettings,
264
0
    compressed_buf: &mut BytesMut,
265
0
    out_buf: &mut BytesMut,
266
0
    len: usize,
267
0
) -> Result<(), std::io::Error> {
268
0
    let buffer_growth_interval = settings.buffer_growth_interval;
269
0
    let estimate_decompressed_len = len * 2;
270
0
    let capacity =
271
0
        ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
272
0
    out_buf.reserve(capacity);
273
274
    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
275
    let mut out_writer = out_buf.writer();
276
277
    match settings.encoding {
278
        #[cfg(feature = "gzip")]
279
        CompressionEncoding::Gzip => {
280
            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
281
            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
282
        }
283
        #[cfg(feature = "deflate")]
284
        CompressionEncoding::Deflate => {
285
            let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]);
286
            std::io::copy(&mut deflate_decoder, &mut out_writer)?;
287
        }
288
        #[cfg(feature = "zstd")]
289
        CompressionEncoding::Zstd => {
290
            let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
291
            std::io::copy(&mut zstd_decoder, &mut out_writer)?;
292
        }
293
    }
294
295
    compressed_buf.advance(len);
296
297
    Ok(())
298
}
299
300
/// Per-message override of the compression configured for a stream.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
311
312
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
    use http::HeaderValue;

    /// Builds the accept-encoding header value for an explicit slot layout,
    /// panicking if no value is produced.
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    fn accept_header(inner: [Option<CompressionEncoding>; 3]) -> HeaderValue {
        EnabledCompressionEncodings { inner }
            .into_accept_encoding_header_value()
            .unwrap()
    }

    // With nothing enabled there is no header value at all.
    #[test]
    fn convert_none_into_header_value() {
        let encodings = EnabledCompressionEncodings::default();

        assert!(encodings.into_accept_encoding_header_value().is_none());
    }

    // A single gzip entry gives the same value whichever slot holds it.
    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");

        assert_eq!(
            accept_header([Some(CompressionEncoding::Gzip), None, None]),
            expected
        );
        assert_eq!(
            accept_header([None, None, Some(CompressionEncoding::Gzip)]),
            expected
        );
    }

    // A single zstd entry gives the same value whichever slot holds it.
    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");

        assert_eq!(
            accept_header([Some(CompressionEncoding::Zstd), None, None]),
            expected
        );
        assert_eq!(
            accept_header([None, None, Some(CompressionEncoding::Zstd)]),
            expected
        );
    }

    // With every slot filled, the header lists the encodings in slot order.
    #[test]
    #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))]
    fn convert_compression_encodings_into_header_value() {
        let forward = [
            Some(CompressionEncoding::Gzip),
            Some(CompressionEncoding::Deflate),
            Some(CompressionEncoding::Zstd),
        ];
        assert_eq!(
            accept_header(forward),
            HeaderValue::from_static("gzip,deflate,zstd,identity"),
        );

        let reversed = [
            Some(CompressionEncoding::Zstd),
            Some(CompressionEncoding::Deflate),
            Some(CompressionEncoding::Gzip),
        ];
        assert_eq!(
            accept_header(reversed),
            HeaderValue::from_static("zstd,deflate,gzip,identity"),
        );
    }
}