Coverage Report

Created: 2025-11-28 06:44

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.12.3/src/codec/compression.rs
Line
Count
Source
1
use crate::{metadata::MetadataValue, Status};
2
use bytes::{Buf, BufMut, BytesMut};
3
#[cfg(feature = "gzip")]
4
use flate2::read::{GzDecoder, GzEncoder};
5
use std::fmt;
6
#[cfg(feature = "zstd")]
7
use zstd::stream::read::{Decoder, Encoder};
8
9
/// Name of the header that states which compression encoding a message body uses.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that lists the compression encodings a peer is willing to accept.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
11
12
/// Struct used to configure which encodings are enabled on a server or channel.
13
///
14
/// Represents an ordered list of compression encodings that are enabled.
15
#[derive(Debug, Default, Clone, Copy)]
16
pub struct EnabledCompressionEncodings {
17
    inner: [Option<CompressionEncoding>; 2],
18
}
19
20
impl EnabledCompressionEncodings {
21
    /// Enable a [`CompressionEncoding`].
22
    ///
23
    /// Adds the new encoding to the end of the encoding list.
24
0
    pub fn enable(&mut self, encoding: CompressionEncoding) {
25
0
        for e in self.inner.iter_mut() {
26
0
            match e {
27
0
                Some(e) if *e == encoding => return,
28
                None => {
29
0
                    *e = Some(encoding);
30
0
                    return;
31
                }
32
                _ => continue,
33
            }
34
        }
35
0
    }
36
37
    /// Remove the last [`CompressionEncoding`].
38
0
    pub fn pop(&mut self) -> Option<CompressionEncoding> {
39
0
        self.inner
40
0
            .iter_mut()
41
0
            .rev()
42
0
            .find(|entry| entry.is_some())?
43
0
            .take()
44
0
    }
45
46
0
    pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> {
47
0
        let mut value = BytesMut::new();
48
0
        for encoding in self.inner.into_iter().flatten() {
49
            value.put_slice(encoding.as_str().as_bytes());
50
            value.put_u8(b',');
51
        }
52
53
0
        if value.is_empty() {
54
0
            return None;
55
0
        }
56
57
0
        value.put_slice(b"identity");
58
0
        Some(http::HeaderValue::from_maybe_shared(value).unwrap())
59
0
    }
60
61
    /// Check if a [`CompressionEncoding`] is enabled.
62
0
    pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool {
63
0
        self.inner.contains(&Some(encoding))
64
0
    }
65
66
    /// Check if any [`CompressionEncoding`]s are enabled.
67
0
    pub fn is_empty(&self) -> bool {
68
0
        self.inner.iter().all(|e| e.is_none())
69
0
    }
70
}
71
72
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
73
pub(crate) struct CompressionSettings {
74
    pub(crate) encoding: CompressionEncoding,
75
    /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste.
76
    /// The default buffer growth interval is 8 kilobytes.
77
    pub(crate) buffer_growth_interval: usize,
78
}
79
80
/// The compression encodings Tonic supports.
81
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
82
#[non_exhaustive]
83
pub enum CompressionEncoding {
84
    #[allow(missing_docs)]
85
    #[cfg(feature = "gzip")]
86
    Gzip,
87
    #[allow(missing_docs)]
88
    #[cfg(feature = "zstd")]
89
    Zstd,
90
}
91
92
impl CompressionEncoding {
93
    pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[
94
        #[cfg(feature = "gzip")]
95
        CompressionEncoding::Gzip,
96
        #[cfg(feature = "zstd")]
97
        CompressionEncoding::Zstd,
98
    ];
99
100
    /// Based on the `grpc-accept-encoding` header, pick an encoding to use.
101
0
    pub(crate) fn from_accept_encoding_header(
102
0
        map: &http::HeaderMap,
103
0
        enabled_encodings: EnabledCompressionEncodings,
104
0
    ) -> Option<Self> {
105
0
        if enabled_encodings.is_empty() {
106
0
            return None;
107
0
        }
108
109
0
        let header_value = map.get(ACCEPT_ENCODING_HEADER)?;
110
0
        let header_value_str = header_value.to_str().ok()?;
111
112
0
        split_by_comma(header_value_str).find_map(|value| match value {
113
            #[cfg(feature = "gzip")]
114
            "gzip" => Some(CompressionEncoding::Gzip),
115
            #[cfg(feature = "zstd")]
116
            "zstd" => Some(CompressionEncoding::Zstd),
117
0
            _ => None,
118
0
        })
119
0
    }
120
121
    /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported.
122
0
    pub(crate) fn from_encoding_header(
123
0
        map: &http::HeaderMap,
124
0
        enabled_encodings: EnabledCompressionEncodings,
125
0
    ) -> Result<Option<Self>, Status> {
126
0
        let Some(header_value) = map.get(ENCODING_HEADER) else {
127
0
            return Ok(None);
128
        };
129
130
0
        match header_value.as_bytes() {
131
            #[cfg(feature = "gzip")]
132
            b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => {
133
                Ok(Some(CompressionEncoding::Gzip))
134
            }
135
            #[cfg(feature = "zstd")]
136
            b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => {
137
                Ok(Some(CompressionEncoding::Zstd))
138
            }
139
0
            b"identity" => Ok(None),
140
0
            other => {
141
                // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79.
142
                // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension
143
                let other_debug_string;
144
145
0
                let mut status = Status::unimplemented(format!(
146
0
                    "Content is compressed with `{}` which isn't supported",
147
0
                    match std::str::from_utf8(other) {
148
0
                        Ok(s) => s,
149
                        Err(_) => {
150
0
                            other_debug_string = format!("{other:?}");
151
0
                            &other_debug_string
152
                        }
153
                    }
154
                ));
155
156
0
                let header_value = enabled_encodings
157
0
                    .into_accept_encoding_header_value()
158
0
                    .map(MetadataValue::unchecked_from_header_value)
159
0
                    .unwrap_or_else(|| MetadataValue::from_static("identity"));
160
0
                status
161
0
                    .metadata_mut()
162
0
                    .insert(ACCEPT_ENCODING_HEADER, header_value);
163
164
0
                Err(status)
165
            }
166
        }
167
0
    }
168
169
    pub(crate) fn as_str(self) -> &'static str {
170
        match self {
171
            #[cfg(feature = "gzip")]
172
            CompressionEncoding::Gzip => "gzip",
173
            #[cfg(feature = "zstd")]
174
            CompressionEncoding::Zstd => "zstd",
175
        }
176
    }
177
178
    #[cfg(any(feature = "gzip", feature = "zstd"))]
179
    pub(crate) fn into_header_value(self) -> http::HeaderValue {
180
        http::HeaderValue::from_static(self.as_str())
181
    }
182
}
183
184
impl fmt::Display for CompressionEncoding {
185
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186
0
        f.write_str(self.as_str())
187
0
    }
188
}
189
190
0
/// Lazily splits a comma-separated list into its entries, trimming surrounding
/// whitespace from each. Empty entries are yielded as `""`.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
193
194
/// Compress `len` bytes from `decompressed_buf` into `out_buf`.
195
/// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it.
196
#[allow(unused_variables, unreachable_code)]
197
0
pub(crate) fn compress(
198
0
    settings: CompressionSettings,
199
0
    decompressed_buf: &mut BytesMut,
200
0
    out_buf: &mut BytesMut,
201
0
    len: usize,
202
0
) -> Result<(), std::io::Error> {
203
0
    let buffer_growth_interval = settings.buffer_growth_interval;
204
0
    let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval;
205
0
    out_buf.reserve(capacity);
206
207
    #[cfg(any(feature = "gzip", feature = "zstd"))]
208
    let mut out_writer = out_buf.writer();
209
210
    match settings.encoding {
211
        #[cfg(feature = "gzip")]
212
        CompressionEncoding::Gzip => {
213
            let mut gzip_encoder = GzEncoder::new(
214
                &decompressed_buf[0..len],
215
                // FIXME: support customizing the compression level
216
                flate2::Compression::new(6),
217
            );
218
            std::io::copy(&mut gzip_encoder, &mut out_writer)?;
219
        }
220
        #[cfg(feature = "zstd")]
221
        CompressionEncoding::Zstd => {
222
            let mut zstd_encoder = Encoder::new(
223
                &decompressed_buf[0..len],
224
                // FIXME: support customizing the compression level
225
                zstd::DEFAULT_COMPRESSION_LEVEL,
226
            )?;
227
            std::io::copy(&mut zstd_encoder, &mut out_writer)?;
228
        }
229
    }
230
231
    decompressed_buf.advance(len);
232
233
    Ok(())
234
}
235
236
/// Decompress `len` bytes from `compressed_buf` into `out_buf`.
237
#[allow(unused_variables, unreachable_code)]
238
0
pub(crate) fn decompress(
239
0
    settings: CompressionSettings,
240
0
    compressed_buf: &mut BytesMut,
241
0
    out_buf: &mut BytesMut,
242
0
    len: usize,
243
0
) -> Result<(), std::io::Error> {
244
0
    let buffer_growth_interval = settings.buffer_growth_interval;
245
0
    let estimate_decompressed_len = len * 2;
246
0
    let capacity =
247
0
        ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval;
248
0
    out_buf.reserve(capacity);
249
250
    #[cfg(any(feature = "gzip", feature = "zstd"))]
251
    let mut out_writer = out_buf.writer();
252
253
    match settings.encoding {
254
        #[cfg(feature = "gzip")]
255
        CompressionEncoding::Gzip => {
256
            let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]);
257
            std::io::copy(&mut gzip_decoder, &mut out_writer)?;
258
        }
259
        #[cfg(feature = "zstd")]
260
        CompressionEncoding::Zstd => {
261
            let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?;
262
            std::io::copy(&mut zstd_decoder, &mut out_writer)?;
263
        }
264
    }
265
266
    compressed_buf.advance(len);
267
268
    Ok(())
269
}
270
271
/// Per-message override of the compression configured on a stream.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
282
283
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::*;

    /// Builds an `EnabledCompressionEncodings` straight from its slots and renders its
    /// `grpc-accept-encoding` value.
    fn accept_header(inner: [Option<CompressionEncoding>; 2]) -> Option<HeaderValue> {
        EnabledCompressionEncodings { inner }.into_accept_encoding_header_value()
    }

    #[test]
    fn convert_none_into_header_value() {
        assert!(accept_header([None, None]).is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");

        // The position of the single enabled encoding must not affect the output.
        for slots in [
            [Some(CompressionEncoding::Gzip), None],
            [None, Some(CompressionEncoding::Gzip)],
        ] {
            assert_eq!(accept_header(slots).unwrap(), expected);
        }
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");

        // The position of the single enabled encoding must not affect the output.
        for slots in [
            [Some(CompressionEncoding::Zstd), None],
            [None, Some(CompressionEncoding::Zstd)],
        ] {
            assert_eq!(accept_header(slots).unwrap(), expected);
        }
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "zstd"))]
    fn convert_gzip_and_zstd_into_header_value() {
        use super::CompressionEncoding::{Gzip, Zstd};

        // With both enabled, the rendered order follows slot order.
        assert_eq!(
            accept_header([Some(Gzip), Some(Zstd)]).unwrap(),
            HeaderValue::from_static("gzip,zstd,identity"),
        );
        assert_eq!(
            accept_header([Some(Zstd), Some(Gzip)]).unwrap(),
            HeaderValue::from_static("zstd,gzip,identity"),
        );
    }
}