/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.14.5/src/codec/compression.rs
Line | Count | Source |
1 | | use crate::{metadata::MetadataValue, Status}; |
2 | | use bytes::{Buf, BufMut, BytesMut}; |
3 | | #[cfg(feature = "gzip")] |
4 | | use flate2::read::{GzDecoder, GzEncoder}; |
5 | | #[cfg(feature = "deflate")] |
6 | | use flate2::read::{ZlibDecoder, ZlibEncoder}; |
7 | | use std::{borrow::Cow, fmt}; |
8 | | #[cfg(feature = "zstd")] |
9 | | use zstd::stream::read::{Decoder, Encoder}; |
10 | | |
11 | | pub(crate) const ENCODING_HEADER: &str = "grpc-encoding"; |
12 | | pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding"; |
13 | | |
14 | | /// Struct used to configure which encodings are enabled on a server or channel. |
15 | | /// |
16 | | /// Represents an ordered list of compression encodings that are enabled. |
17 | | #[derive(Debug, Default, Clone, Copy)] |
18 | | pub struct EnabledCompressionEncodings { |
19 | | inner: [Option<CompressionEncoding>; 3], |
20 | | } |
21 | | |
22 | | impl EnabledCompressionEncodings { |
23 | | /// Enable a [`CompressionEncoding`]. |
24 | | /// |
25 | | /// Adds the new encoding to the end of the encoding list. |
26 | 0 | pub fn enable(&mut self, encoding: CompressionEncoding) { |
27 | 0 | for e in self.inner.iter_mut() { |
28 | 0 | match e { |
29 | 0 | Some(e) if *e == encoding => return, |
30 | | None => { |
31 | 0 | *e = Some(encoding); |
32 | 0 | return; |
33 | | } |
34 | | _ => continue, |
35 | | } |
36 | | } |
37 | 0 | } |
38 | | |
39 | | /// Remove the last [`CompressionEncoding`]. |
40 | 0 | pub fn pop(&mut self) -> Option<CompressionEncoding> { |
41 | 0 | self.inner |
42 | 0 | .iter_mut() |
43 | 0 | .rev() |
44 | 0 | .find(|entry| entry.is_some())? |
45 | 0 | .take() |
46 | 0 | } |
47 | | |
48 | 0 | pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> { |
49 | 0 | let mut value = BytesMut::new(); |
50 | 0 | for encoding in self.inner.into_iter().flatten() { |
51 | | value.put_slice(encoding.as_str().as_bytes()); |
52 | | value.put_u8(b','); |
53 | | } |
54 | | |
55 | 0 | if value.is_empty() { |
56 | 0 | return None; |
57 | 0 | } |
58 | | |
59 | 0 | value.put_slice(b"identity"); |
60 | 0 | Some(http::HeaderValue::from_maybe_shared(value).unwrap()) |
61 | 0 | } |
62 | | |
63 | | /// Check if a [`CompressionEncoding`] is enabled. |
64 | 0 | pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool { |
65 | 0 | self.inner.contains(&Some(encoding)) |
66 | 0 | } |
67 | | |
68 | | /// Check if any [`CompressionEncoding`]s are enabled. |
69 | 0 | pub fn is_empty(&self) -> bool { |
70 | 0 | self.inner.iter().all(|e| e.is_none()) |
71 | 0 | } |
72 | | } |
73 | | |
74 | | #[derive(Clone, Copy, Debug, PartialEq, Eq)] |
75 | | pub(crate) struct CompressionSettings { |
76 | | pub(crate) encoding: CompressionEncoding, |
77 | | /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste. |
78 | | /// The default buffer growth interval is 8 kilobytes. |
79 | | pub(crate) buffer_growth_interval: usize, |
80 | | } |
81 | | |
82 | | /// The compression encodings Tonic supports. |
83 | | #[derive(Clone, Copy, Debug, PartialEq, Eq)] |
84 | | #[non_exhaustive] |
85 | | pub enum CompressionEncoding { |
86 | | #[allow(missing_docs)] |
87 | | #[cfg(feature = "gzip")] |
88 | | Gzip, |
89 | | #[allow(missing_docs)] |
90 | | #[cfg(feature = "deflate")] |
91 | | Deflate, |
92 | | #[allow(missing_docs)] |
93 | | #[cfg(feature = "zstd")] |
94 | | Zstd, |
95 | | } |
96 | | |
97 | | impl CompressionEncoding { |
98 | | pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[ |
99 | | #[cfg(feature = "gzip")] |
100 | | CompressionEncoding::Gzip, |
101 | | #[cfg(feature = "deflate")] |
102 | | CompressionEncoding::Deflate, |
103 | | #[cfg(feature = "zstd")] |
104 | | CompressionEncoding::Zstd, |
105 | | ]; |
106 | | |
107 | | /// Based on the `grpc-accept-encoding` header, pick an encoding to use. |
108 | 0 | pub(crate) fn from_accept_encoding_header( |
109 | 0 | map: &http::HeaderMap, |
110 | 0 | enabled_encodings: EnabledCompressionEncodings, |
111 | 0 | ) -> Option<Self> { |
112 | 0 | if enabled_encodings.is_empty() { |
113 | 0 | return None; |
114 | 0 | } |
115 | | |
116 | 0 | let header_value = map.get(ACCEPT_ENCODING_HEADER)?; |
117 | 0 | let header_value_str = header_value.to_str().ok()?; |
118 | | |
119 | 0 | split_by_comma(header_value_str).find_map(|value| match value { |
120 | | #[cfg(feature = "gzip")] |
121 | | "gzip" => Some(CompressionEncoding::Gzip), |
122 | | #[cfg(feature = "deflate")] |
123 | | "deflate" => Some(CompressionEncoding::Deflate), |
124 | | #[cfg(feature = "zstd")] |
125 | | "zstd" => Some(CompressionEncoding::Zstd), |
126 | 0 | _ => None, |
127 | 0 | }) |
128 | 0 | } |
129 | | |
130 | | /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported. |
131 | 0 | pub(crate) fn from_encoding_header( |
132 | 0 | map: &http::HeaderMap, |
133 | 0 | enabled_encodings: EnabledCompressionEncodings, |
134 | 0 | ) -> Result<Option<Self>, Status> { |
135 | 0 | let Some(header_value) = map.get(ENCODING_HEADER) else { |
136 | 0 | return Ok(None); |
137 | | }; |
138 | | |
139 | 0 | match header_value.as_bytes() { |
140 | | #[cfg(feature = "gzip")] |
141 | | b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { |
142 | | Ok(Some(CompressionEncoding::Gzip)) |
143 | | } |
144 | | #[cfg(feature = "deflate")] |
145 | | b"deflate" if enabled_encodings.is_enabled(CompressionEncoding::Deflate) => { |
146 | | Ok(Some(CompressionEncoding::Deflate)) |
147 | | } |
148 | | #[cfg(feature = "zstd")] |
149 | | b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { |
150 | | Ok(Some(CompressionEncoding::Zstd)) |
151 | | } |
152 | 0 | b"identity" => Ok(None), |
153 | 0 | other => { |
154 | 0 | let other = match std::str::from_utf8(other) { |
155 | 0 | Ok(s) => Cow::Borrowed(s), |
156 | 0 | Err(_) => Cow::Owned(format!("{other:?}")), |
157 | | }; |
158 | | |
159 | 0 | let mut status = Status::unimplemented(format!( |
160 | | "Content is compressed with `{other}` which isn't supported" |
161 | | )); |
162 | | |
163 | 0 | let header_value = enabled_encodings |
164 | 0 | .into_accept_encoding_header_value() |
165 | 0 | .map(MetadataValue::unchecked_from_header_value) |
166 | 0 | .unwrap_or_else(|| MetadataValue::from_static("identity")); |
167 | 0 | status |
168 | 0 | .metadata_mut() |
169 | 0 | .insert(ACCEPT_ENCODING_HEADER, header_value); |
170 | | |
171 | 0 | Err(status) |
172 | | } |
173 | | } |
174 | 0 | } |
175 | | |
176 | | pub(crate) fn as_str(self) -> &'static str { |
177 | | match self { |
178 | | #[cfg(feature = "gzip")] |
179 | | CompressionEncoding::Gzip => "gzip", |
180 | | #[cfg(feature = "deflate")] |
181 | | CompressionEncoding::Deflate => "deflate", |
182 | | #[cfg(feature = "zstd")] |
183 | | CompressionEncoding::Zstd => "zstd", |
184 | | } |
185 | | } |
186 | | |
187 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
188 | | pub(crate) fn into_header_value(self) -> http::HeaderValue { |
189 | | http::HeaderValue::from_static(self.as_str()) |
190 | | } |
191 | | } |
192 | | |
193 | | impl fmt::Display for CompressionEncoding { |
194 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
195 | 0 | f.write_str(self.as_str()) |
196 | 0 | } |
197 | | } |
198 | | |
199 | 0 | fn split_by_comma(s: &str) -> impl Iterator<Item = &str> { |
200 | 0 | s.split(',').map(|s| s.trim()) |
201 | 0 | } |
202 | | |
203 | | /// Compress `len` bytes from `decompressed_buf` into `out_buf`. |
204 | | /// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it. |
205 | | #[allow(unused_variables, unreachable_code)] |
206 | 0 | pub(crate) fn compress( |
207 | 0 | settings: CompressionSettings, |
208 | 0 | decompressed_buf: &mut BytesMut, |
209 | 0 | out_buf: &mut BytesMut, |
210 | 0 | len: usize, |
211 | 0 | ) -> Result<(), std::io::Error> { |
212 | 0 | let buffer_growth_interval = settings.buffer_growth_interval; |
213 | 0 | let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval; |
214 | 0 | out_buf.reserve(capacity); |
215 | | |
216 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
217 | | let mut out_writer = out_buf.writer(); |
218 | | |
219 | | match settings.encoding { |
220 | | #[cfg(feature = "gzip")] |
221 | | CompressionEncoding::Gzip => { |
222 | | let mut gzip_encoder = GzEncoder::new( |
223 | | &decompressed_buf[0..len], |
224 | | // FIXME: support customizing the compression level |
225 | | flate2::Compression::new(6), |
226 | | ); |
227 | | std::io::copy(&mut gzip_encoder, &mut out_writer)?; |
228 | | } |
229 | | #[cfg(feature = "deflate")] |
230 | | CompressionEncoding::Deflate => { |
231 | | let mut deflate_encoder = ZlibEncoder::new( |
232 | | &decompressed_buf[0..len], |
233 | | // FIXME: support customizing the compression level |
234 | | flate2::Compression::new(6), |
235 | | ); |
236 | | std::io::copy(&mut deflate_encoder, &mut out_writer)?; |
237 | | } |
238 | | #[cfg(feature = "zstd")] |
239 | | CompressionEncoding::Zstd => { |
240 | | let mut zstd_encoder = Encoder::new( |
241 | | &decompressed_buf[0..len], |
242 | | // FIXME: support customizing the compression level |
243 | | zstd::DEFAULT_COMPRESSION_LEVEL, |
244 | | )?; |
245 | | std::io::copy(&mut zstd_encoder, &mut out_writer)?; |
246 | | } |
247 | | } |
248 | | |
249 | | decompressed_buf.advance(len); |
250 | | |
251 | | Ok(()) |
252 | | } |
253 | | |
254 | | /// Decompress `len` bytes from `compressed_buf` into `out_buf`. |
255 | | #[allow(unused_variables, unreachable_code)] |
256 | 0 | pub(crate) fn decompress( |
257 | 0 | settings: CompressionSettings, |
258 | 0 | compressed_buf: &mut BytesMut, |
259 | 0 | mut out_buf: bytes::buf::Limit<&mut BytesMut>, |
260 | 0 | len: usize, |
261 | 0 | ) -> Result<(), std::io::Error> { |
262 | 0 | let buffer_growth_interval = settings.buffer_growth_interval; |
263 | 0 | let estimate_decompressed_len = len * 2; |
264 | 0 | let capacity = std::cmp::min( |
265 | 0 | bytes::buf::Limit::limit(&out_buf), |
266 | 0 | ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval, |
267 | | ); |
268 | | |
269 | 0 | out_buf.get_mut().reserve(capacity); |
270 | | |
271 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
272 | | let mut out_writer = out_buf.writer(); |
273 | | |
274 | | match settings.encoding { |
275 | | #[cfg(feature = "gzip")] |
276 | | CompressionEncoding::Gzip => { |
277 | | let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); |
278 | | std::io::copy(&mut gzip_decoder, &mut out_writer)?; |
279 | | } |
280 | | #[cfg(feature = "deflate")] |
281 | | CompressionEncoding::Deflate => { |
282 | | let mut deflate_decoder = ZlibDecoder::new(&compressed_buf[0..len]); |
283 | | std::io::copy(&mut deflate_decoder, &mut out_writer)?; |
284 | | } |
285 | | #[cfg(feature = "zstd")] |
286 | | CompressionEncoding::Zstd => { |
287 | | let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; |
288 | | std::io::copy(&mut zstd_decoder, &mut out_writer)?; |
289 | | } |
290 | | } |
291 | | |
292 | | compressed_buf.advance(len); |
293 | | |
294 | | Ok(()) |
295 | | } |
296 | | |
297 | | /// Controls compression behavior for individual messages within a stream. |
298 | | #[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] |
299 | | pub enum SingleMessageCompressionOverride { |
300 | | /// Inherit whatever compression is already configured. If the stream is compressed this |
301 | | /// message will also be configured. |
302 | | /// |
303 | | /// This is the default. |
304 | | #[default] |
305 | | Inherit, |
306 | | /// Don't compress this message, even if compression is enabled on the stream. |
307 | | Disable, |
308 | | } |
309 | | |
310 | | #[cfg(test)] |
311 | | mod tests { |
312 | | #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))] |
313 | | use http::HeaderValue; |
314 | | |
315 | | use super::*; |
316 | | |
317 | | #[test] |
318 | | fn convert_none_into_header_value() { |
319 | | let encodings = EnabledCompressionEncodings::default(); |
320 | | |
321 | | assert!(encodings.into_accept_encoding_header_value().is_none()); |
322 | | } |
323 | | |
324 | | #[test] |
325 | | #[cfg(feature = "gzip")] |
326 | | fn convert_gzip_into_header_value() { |
327 | | const GZIP: HeaderValue = HeaderValue::from_static("gzip,identity"); |
328 | | |
329 | | let encodings = EnabledCompressionEncodings { |
330 | | inner: [Some(CompressionEncoding::Gzip), None, None], |
331 | | }; |
332 | | |
333 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), GZIP); |
334 | | |
335 | | let encodings = EnabledCompressionEncodings { |
336 | | inner: [None, None, Some(CompressionEncoding::Gzip)], |
337 | | }; |
338 | | |
339 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), GZIP); |
340 | | } |
341 | | |
342 | | #[test] |
343 | | #[cfg(feature = "zstd")] |
344 | | fn convert_zstd_into_header_value() { |
345 | | const ZSTD: HeaderValue = HeaderValue::from_static("zstd,identity"); |
346 | | |
347 | | let encodings = EnabledCompressionEncodings { |
348 | | inner: [Some(CompressionEncoding::Zstd), None, None], |
349 | | }; |
350 | | |
351 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), ZSTD); |
352 | | |
353 | | let encodings = EnabledCompressionEncodings { |
354 | | inner: [None, None, Some(CompressionEncoding::Zstd)], |
355 | | }; |
356 | | |
357 | | assert_eq!(encodings.into_accept_encoding_header_value().unwrap(), ZSTD); |
358 | | } |
359 | | |
360 | | #[test] |
361 | | #[cfg(all(feature = "gzip", feature = "deflate", feature = "zstd"))] |
362 | | fn convert_compression_encodings_into_header_value() { |
363 | | let encodings = EnabledCompressionEncodings { |
364 | | inner: [ |
365 | | Some(CompressionEncoding::Gzip), |
366 | | Some(CompressionEncoding::Deflate), |
367 | | Some(CompressionEncoding::Zstd), |
368 | | ], |
369 | | }; |
370 | | |
371 | | assert_eq!( |
372 | | encodings.into_accept_encoding_header_value().unwrap(), |
373 | | HeaderValue::from_static("gzip,deflate,zstd,identity"), |
374 | | ); |
375 | | |
376 | | let encodings = EnabledCompressionEncodings { |
377 | | inner: [ |
378 | | Some(CompressionEncoding::Zstd), |
379 | | Some(CompressionEncoding::Deflate), |
380 | | Some(CompressionEncoding::Gzip), |
381 | | ], |
382 | | }; |
383 | | |
384 | | assert_eq!( |
385 | | encodings.into_accept_encoding_header_value().unwrap(), |
386 | | HeaderValue::from_static("zstd,deflate,gzip,identity"), |
387 | | ); |
388 | | } |
389 | | } |