/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.13.0/src/codec/compression.rs
Line | Count | Source |
1 | | use crate::{metadata::MetadataValue, Status}; |
2 | | use bytes::{Buf, BufMut, BytesMut}; |
3 | | #[cfg(feature = "gzip")] |
4 | | use flate2::read::{GzDecoder, GzEncoder}; |
5 | | #[cfg(feature = "deflate")] |
6 | | use flate2::read::{ZlibDecoder, ZlibEncoder}; |
7 | | use std::fmt; |
8 | | #[cfg(feature = "zstd")] |
9 | | use zstd::stream::read::{Decoder, Encoder}; |
10 | | |
11 | | pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; |
12 | | pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; |
13 | | |
14 | | /// Struct used to configure which encodings are enabled on a server or channel. |
15 | | /// |
16 | | /// Represents an ordered list of compression encodings that are enabled. |
17 | | #[derive(Debug, Default, Clone, Copy)] |
18 | | pub struct EnabledCompressionEncodings { |
19 | | inner: [Option<CompressionEncoding>; 3], |
20 | | } |
21 | | |
22 | | impl EnabledCompressionEncodings { |
23 | | /// Enable a [`CompressionEncoding`]. |
24 | | /// |
25 | | /// Adds the new encoding to the end of the encoding list. |
26 | 0 | pub fn enable(&mut self, encoding: CompressionEncoding) { |
27 | 0 | for e in self.inner.iter_mut() { |
28 | 0 | match e { |
29 | 0 | Some(e) if *e == encoding => return, |
30 | | None => { |
31 | 0 | *e = Some(encoding); |
32 | 0 | return; |
33 | | } |
34 | | _ => continue, |
35 | | } |
36 | | } |
37 | 0 | } |
38 | | |
39 | | /// Remove the last [`CompressionEncoding`]. |
40 | 0 | pub fn pop(&mut self) -> Option<CompressionEncoding> { |
41 | 0 | self.inner |
42 | 0 | .iter_mut() |
43 | 0 | .rev() |
44 | 0 | .find(|entry| entry.is_some())? |
45 | 0 | .take() |
46 | 0 | } |
47 | | |
48 | 0 | pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> { |
49 | 0 | let mut value = BytesMut::new(); |
50 | 0 | for encoding in self.inner.into_iter().flatten() { |
51 | | value.put_slice(encoding.as_str().as_bytes()); |
52 | | value.put_u8(b','); |
53 | | } |
54 | | |
55 | 0 | if value.is_empty() { |
56 | 0 | return None; |
57 | 0 | } |
58 | | |
59 | 0 | value.put_slice(b"identity"); |
60 | 0 | Some(http::HeaderValue::from_maybe_shared(value).unwrap()) |
61 | 0 | } |
62 | | |
63 | | /// Check if a [`CompressionEncoding`] is enabled. |
64 | 0 | pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool { |
65 | 0 | self.inner.contains(&Some(encoding)) |
66 | 0 | } |
67 | | |
68 | | /// Check if any [`CompressionEncoding`]s are enabled. |
69 | 0 | pub fn is_empty(&self) -> bool { |
70 | 0 | self.inner.iter().all(|e| e.is_none()) |
71 | 0 | } |
72 | | } |
73 | | |
74 | | #[derive(Clone, Copy, Debug, PartialEq, Eq)] |
75 | | pub(crate) struct CompressionSettings { |
76 | | pub(crate) encoding: CompressionEncoding, |
77 | | /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste. |
78 | | /// The default buffer growth interval is 8 kilobytes. |
79 | | pub(crate) buffer_growth_interval: usize, |
80 | | } |
81 | | |
82 | | /// The compression encodings Tonic supports. |
83 | | #[derive(Clone, Copy, Debug, PartialEq, Eq)] |
84 | | #[non_exhaustive] |
85 | | pub enum CompressionEncoding { |
86 | | #[allow(missing_docs)] |
87 | | #[cfg(feature = "gzip")] |
88 | | Gzip, |
89 | | #[allow(missing_docs)] |
90 | | #[cfg(feature = "deflate")] |
91 | | Deflate, |
92 | | #[allow(missing_docs)] |
93 | | #[cfg(feature = "zstd")] |
94 | | Zstd, |
95 | | } |
96 | | |
97 | | impl CompressionEncoding { |
98 | | pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[ |
99 | | #[cfg(feature = "gzip")] |
100 | | CompressionEncoding::Gzip, |
101 | | #[cfg(feature = "deflate")] |
102 | | CompressionEncoding::Deflate, |
103 | | #[cfg(feature = "zstd")] |
104 | | CompressionEncoding::Zstd, |
105 | | ]; |
106 | | |
107 | | /// Based on the `grpc-accept-encoding` header, pick an encoding to use. |
108 | 0 | pub(crate) fn from_accept_encoding_header( |
109 | 0 | map: &http::HeaderMap, |
110 | 0 | enabled_encodings: EnabledCompressionEncodings, |
111 | 0 | ) -> Option<Self> { |
112 | 0 | if enabled_encodings.is_empty() { |
113 | 0 | return None; |
114 | 0 | } |
115 | | |
116 | 0 | let header_value = map.get(ACCEPT_ENCODING_HEADER)?; |
117 | 0 | let header_value_str = header_value.to_str().ok()?; |
118 | | |
119 | 0 | split_by_comma(header_value_str).find_map(|value| match value { |
120 | | #[cfg(feature = "gzip")] |
121 | | "gzip" => Some(CompressionEncoding::Gzip), |
122 | | #[cfg(feature = "deflate")] |
123 | | "deflate" => Some(CompressionEncoding::Deflate), |
124 | | #[cfg(feature = "zstd")] |
125 | | "zstd" => Some(CompressionEncoding::Zstd), |
126 | 0 | _ => None, |
127 | 0 | }) |
128 | 0 | } |
129 | | |
130 | | /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported. |
131 | 0 | pub(crate) fn from_encoding_header( |
132 | 0 | map: &http::HeaderMap, |
133 | 0 | enabled_encodings: EnabledCompressionEncodings, |
134 | 0 | ) -> Result<Option<Self>, Status> { |
135 | 0 | let Some(header_value) = map.get(ENCODING_HEADER) else { |
136 | 0 | return Ok(None); |
137 | | }; |
138 | | |
139 | 0 | match header_value.as_bytes() { |
140 | | #[cfg(feature = "gzip")] |
141 | | b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { |
142 | | Ok(Some(CompressionEncoding::Gzip)) |
143 | | } |
144 | | #[cfg(feature = "deflate")] |
145 | | b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => { |
146 | | Ok(Some(CompressionEncoding::Deflate)) |
147 | | } |
148 | | #[cfg(feature = "zstd")] |
149 | | b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { |
150 | | Ok(Some(CompressionEncoding::Zstd)) |
151 | | } |
152 | 0 | b"identity" => Ok(None), |
153 | 0 | other => { |
154 | | // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79. |
155 | | // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension |
156 | | let other_debug_string; |
157 | | |
158 | 0 | let mut status = Status::unimplemented(format!( |
159 | 0 | "Content is compressed with `{}` which isn't supported", |
160 | 0 | match std::str::from_utf8(other) { |
161 | 0 | Ok(s) => s, |
162 | | Err(_) => { |
163 | 0 | other_debug_string = format!("{other:?}"); |
164 | 0 | &other_debug_string |
165 | | } |
166 | | } |
167 | | )); |
168 | | |
169 | 0 | let header_value = enabled_encodings |
170 | 0 | .into_accept_encoding_header_value() |
171 | 0 | .map(MetadataValue::unchecked_from_header_value) |
172 | 0 | .unwrap_or_else(|| MetadataValue::from_static("identity")); |
173 | 0 | status |
174 | 0 | .metadata_mut() |
175 | 0 | .insert(ACCEPT_ENCODING_HEADER, header_value); |
176 | | |
177 | 0 | Err(status) |
178 | | } |
179 | | } |
180 | 0 | } |
181 | | |
182 | | pub(crate) fn as_str(self) -> &'static str { |
183 | | match self { |
184 | | #[cfg(feature = "gzip")] |
185 | | CompressionEncoding::Gzip => "gzip", |
186 | | #[cfg(feature = "deflate")] |
187 | | CompressionEncoding::Deflate => "deflate", |
188 | | #[cfg(feature = "zstd")] |
189 | | CompressionEncoding::Zstd => "zstd", |
190 | | } |
191 | | } |
192 | | |
193 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
194 | | pub(crate) fn into_header_value(self) -> http::HeaderValue { |
195 | | http::HeaderValue::from_static(self.as_str()) |
196 | | } |
197 | | } |
198 | | |
199 | | impl fmt::Display for CompressionEncoding { |
200 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
201 | 0 | f.write_str(self.as_str()) |
202 | 0 | } |
203 | | } |
204 | | |
205 | 0 | fn split_by_comma(s: &str) -> impl Iterator<Item = &str> { |
206 | 0 | s.split(',').map(|s| s.trim()) |
207 | 0 | } |
208 | | |
209 | | /// Compress `len` bytes from `decompressed_buf` into `out_buf`. |
210 | | /// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it. |
211 | | #[allow(unused_variables, unreachable_code)] |
212 | 0 | pub(crate) fn compress( |
213 | 0 | settings: CompressionSettings, |
214 | 0 | decompressed_buf: &mut BytesMut, |
215 | 0 | out_buf: &mut BytesMut, |
216 | 0 | len: usize, |
217 | 0 | ) -> Result<(), std::io::Error> { |
218 | 0 | let buffer_growth_interval = settings.buffer_growth_interval; |
219 | 0 | let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval; |
220 | 0 | out_buf.reserve(capacity); |
221 | | |
222 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
223 | | let mut out_writer = out_buf.writer(); |
224 | | |
225 | | match settings.encoding { |
226 | | #[cfg(feature = "gzip")] |
227 | | CompressionEncoding::Gzip => { |
228 | | let mut gzip_encoder = GzEncoder::new( |
229 | | &decompressed_buf[0..len], |
230 | | // FIXME: support customizing the compression level |
231 | | flate2::Compression::new(6), |
232 | | ); |
233 | | std::io::copy(&mut gzip_encoder, &mut out_writer)?; |
234 | | } |
235 | | #[cfg(feature = "deflate")] |
236 | | CompressionEncoding::Deflate => { |
237 | | let mut deflate_encoder = ZlibEncoder::new( |
238 | | &decompressed_buf[0..len], |
239 | | // FIXME: support customizing the compression level |
240 | | flate2::Compression::new(6), |
241 | | ); |
242 | | std::io::copy(&mut deflate_encoder, &mut out_writer)?; |
243 | | } |
244 | | #[cfg(feature = "zstd")] |
245 | | CompressionEncoding::Zstd => { |
246 | | let mut zstd_encoder = Encoder::new( |
247 | | &decompressed_buf[0..len], |
248 | | // FIXME: support customizing the compression level |
249 | | zstd::DEFAULT_COMPRESSION_LEVEL, |
250 | | )?; |
251 | | std::io::copy(&mut zstd_encoder, &mut out_writer)?; |
252 | | } |
253 | | } |
254 | | |
255 | | decompressed_buf.advance(len); |
256 | | |
257 | | Ok(()) |
258 | | } |
259 | | |
260 | | /// Decompress `len` bytes from `compressed_buf` into `out_buf`. |
261 | | #[allow(unused_variables, unreachable_code)] |
262 | 0 | pub(crate) fn decompress( |
263 | 0 | settings: CompressionSettings, |
264 | 0 | compressed_buf: &mut BytesMut, |
265 | 0 | out_buf: &mut BytesMut, |
266 | 0 | len: usize, |
267 | 0 | ) -> Result<(), std::io::Error> { |
268 | 0 | let buffer_growth_interval = settings.buffer_growth_interval; |
269 | 0 | let estimate_decompressed_len = len * 2; |
270 | 0 | let capacity = |
271 | 0 | ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval; |
272 | 0 | out_buf.reserve(capacity); |
273 | | |
274 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
275 | | let mut out_writer = out_buf.writer(); |
276 | | |
277 | | match settings.encoding { |
278 | | #[cfg(feature = "gzip")] |
279 | | CompressionEncoding::Gzip => { |
280 | | let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); |
281 | | std::io::copy(&mut gzip_decoder, &mut out_writer)?; |
282 | | } |
283 | | #[cfg(feature = "deflate")] |
284 | | CompressionEncoding::Deflate => { |
285 | | let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]); |
286 | | std::io::copy(&mut deflate_decoder, &mut out_writer)?; |
287 | | } |
288 | | #[cfg(feature = "zstd")] |
289 | | CompressionEncoding::Zstd => { |
290 | | let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; |
291 | | std::io::copy(&mut zstd_decoder, &mut out_writer)?; |
292 | | } |
293 | | } |
294 | | |
295 | | compressed_buf.advance(len); |
296 | | |
297 | | Ok(()) |
298 | | } |
299 | | |
300 | | #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] |
301 | | pub enum SingleMessageCompressionOverride { |
302 | | /// Inherit whatever compression is already configured. If the stream is compressed this |
303 | | /// message will also be configured. |
304 | | /// |
305 | | /// This is the default. |
306 | | #[default] |
307 | | Inherit, |
308 | | /// Don't compress this message, even if compression is enabled on the stream. |
309 | | Disable, |
310 | | } |
311 | | |
312 | | #[cfg(test)] |
313 | | mod tests { |
314 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
315 | | use http::HeaderValue; |
316 | | |
317 | | use super::*; |
318 | | |
319 | | #[test] |
320 | | fn convert_none_into_header_value() { |
321 | | let encodings = EnabledCompressionEncodings::default(); |
322 | | |
323 | | assert!(encodings.into_accept_encoding_header_value().is_none()); |
324 | | } |
325 | | |
326 | | #[test] |
327 | | #[cfg(feature = "gzip")] |
328 | | fn convert_gzip_into_header_value() { |
329 | | const GZIP: HeaderValue = HeaderValue::from_static("gzip,identity"); |
330 | | |
331 | | let encodings = EnabledCompressionEncodings { |
332 | | inner: [Some(CompressionEncoding::Gzip), None, None], |
333 | | }; |
334 | | |
335 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), GZIP); |
336 | | |
337 | | let encodings = EnabledCompressionEncodings { |
338 | | inner: [None, None, Some(CompressionEncoding::Gzip)], |
339 | | }; |
340 | | |
341 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), GZIP); |
342 | | } |
343 | | |
344 | | #[test] |
345 | | #[cfg(feature = "zstd")] |
346 | | fn convert_zstd_into_header_value() { |
347 | | const ZSTD: HeaderValue = HeaderValue::from_static("zstd,identity"); |
348 | | |
349 | | let encodings = EnabledCompressionEncodings { |
350 | | inner: [Some(CompressionEncoding::Zstd), None, None], |
351 | | }; |
352 | | |
353 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), ZSTD); |
354 | | |
355 | | let encodings = EnabledCompressionEncodings { |
356 | | inner: [None, None, Some(CompressionEncoding::Zstd)], |
357 | | }; |
358 | | |
359 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), ZSTD); |
360 | | } |
361 | | |
362 | | #[test] |
363 | | #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))] |
364 | | fn convert_compression_encodings_into_header_value() { |
365 | | let encodings = EnabledCompressionEncodings { |
366 | | inner: [ |
367 | | Some(CompressionEncoding::Gzip), |
368 | | Some(CompressionEncoding::Deflate), |
369 | | Some(CompressionEncoding::Zstd), |
370 | | ], |
371 | | }; |
372 | | |
373 | | assert_eq!( |
374 | | encodings.into_accept_encoding_header_value().unwrap(), |
375 | | HeaderValue::from_static("gzip,deflate,zstd,identity"), |
376 | | ); |
377 | | |
378 | | let encodings = EnabledCompressionEncodings { |
379 | | inner: [ |
380 | | Some(CompressionEncoding::Zstd), |
381 | | Some(CompressionEncoding::Deflate), |
382 | | Some(CompressionEncoding::Gzip), |
383 | | ], |
384 | | }; |
385 | | |
386 | | assert_eq!( |
387 | | encodings.into_accept_encoding_header_value().unwrap(), |
388 | | HeaderValue::from_static("zstd,deflate,gzip,identity"), |
389 | | ); |
390 | | } |
391 | | } |