/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.12.3/src/codec/compression.rs
Line | Count | Source |
1 | | use crate::{metadata::MetadataValue, Status}; |
2 | | use bytes::{Buf, BufMut, BytesMut}; |
3 | | #[cfg(feature = "gzip")] |
4 | | use flate2::read::{GzDecoder, GzEncoder}; |
5 | | use std::fmt; |
6 | | #[cfg(feature = "zstd")] |
7 | | use zstd::stream::read::{Decoder, Encoder}; |
8 | | |
/// Name of the header that states which encoding the message payloads use.
pub(crate) const ENCODING_HEADER: &str = "grpc-encoding";
/// Name of the header that lists the encodings a peer is able to accept.
pub(crate) const ACCEPT_ENCODING_HEADER: &str = "grpc-accept-encoding";
11 | | |
12 | | /// Struct used to configure which encodings are enabled on a server or channel. |
13 | | /// |
14 | | /// Represents an ordered list of compression encodings that are enabled. |
15 | | #[derive(Debug, Default, Clone, Copy)] |
16 | | pub struct EnabledCompressionEncodings { |
17 | | inner: [Option<CompressionEncoding>; 2], |
18 | | } |
19 | | |
20 | | impl EnabledCompressionEncodings { |
21 | | /// Enable a [`CompressionEncoding`]. |
22 | | /// |
23 | | /// Adds the new encoding to the end of the encoding list. |
24 | 0 | pub fn enable(&mut self, encoding: CompressionEncoding) { |
25 | 0 | for e in self.inner.iter_mut() { |
26 | 0 | match e { |
27 | 0 | Some(e) if *e == encoding => return, |
28 | | None => { |
29 | 0 | *e = Some(encoding); |
30 | 0 | return; |
31 | | } |
32 | | _ => continue, |
33 | | } |
34 | | } |
35 | 0 | } |
36 | | |
37 | | /// Remove the last [`CompressionEncoding`]. |
38 | 0 | pub fn pop(&mut self) -> Option<CompressionEncoding> { |
39 | 0 | self.inner |
40 | 0 | .iter_mut() |
41 | 0 | .rev() |
42 | 0 | .find(|entry| entry.is_some())? |
43 | 0 | .take() |
44 | 0 | } |
45 | | |
46 | 0 | pub(crate) fn into_accept_encoding_header_value(self) -> Option<http::HeaderValue> { |
47 | 0 | let mut value = BytesMut::new(); |
48 | 0 | for encoding in self.inner.into_iter().flatten() { |
49 | | value.put_slice(encoding.as_str().as_bytes()); |
50 | | value.put_u8(b','); |
51 | | } |
52 | | |
53 | 0 | if value.is_empty() { |
54 | 0 | return None; |
55 | 0 | } |
56 | | |
57 | 0 | value.put_slice(b"identity"); |
58 | 0 | Some(http::HeaderValue::from_maybe_shared(value).unwrap()) |
59 | 0 | } |
60 | | |
61 | | /// Check if a [`CompressionEncoding`] is enabled. |
62 | 0 | pub fn is_enabled(&self, encoding: CompressionEncoding) -> bool { |
63 | 0 | self.inner.contains(&Some(encoding)) |
64 | 0 | } |
65 | | |
66 | | /// Check if any [`CompressionEncoding`]s are enabled. |
67 | 0 | pub fn is_empty(&self) -> bool { |
68 | 0 | self.inner.iter().all(|e| e.is_none()) |
69 | 0 | } |
70 | | } |
71 | | |
72 | | #[derive(Clone, Copy, Debug, PartialEq, Eq)] |
73 | | pub(crate) struct CompressionSettings { |
74 | | pub(crate) encoding: CompressionEncoding, |
75 | | /// buffer_growth_interval controls memory growth for internal buffers to balance resizing cost against memory waste. |
76 | | /// The default buffer growth interval is 8 kilobytes. |
77 | | pub(crate) buffer_growth_interval: usize, |
78 | | } |
79 | | |
/// The compression encodings Tonic supports.
///
/// A variant is only compiled in when the cargo feature named in its `cfg`
/// attribute is enabled.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompressionEncoding {
    #[cfg(feature = "gzip")]
    #[allow(missing_docs)]
    Gzip,
    #[cfg(feature = "zstd")]
    #[allow(missing_docs)]
    Zstd,
}
91 | | |
92 | | impl CompressionEncoding { |
93 | | pub(crate) const ENCODINGS: &'static [CompressionEncoding] = &[ |
94 | | #[cfg(feature = "gzip")] |
95 | | CompressionEncoding::Gzip, |
96 | | #[cfg(feature = "zstd")] |
97 | | CompressionEncoding::Zstd, |
98 | | ]; |
99 | | |
100 | | /// Based on the `grpc-accept-encoding` header, pick an encoding to use. |
101 | 0 | pub(crate) fn from_accept_encoding_header( |
102 | 0 | map: &http::HeaderMap, |
103 | 0 | enabled_encodings: EnabledCompressionEncodings, |
104 | 0 | ) -> Option<Self> { |
105 | 0 | if enabled_encodings.is_empty() { |
106 | 0 | return None; |
107 | 0 | } |
108 | | |
109 | 0 | let header_value = map.get(ACCEPT_ENCODING_HEADER)?; |
110 | 0 | let header_value_str = header_value.to_str().ok()?; |
111 | | |
112 | 0 | split_by_comma(header_value_str).find_map(|value| match value { |
113 | | #[cfg(feature = "gzip")] |
114 | | "gzip" => Some(CompressionEncoding::Gzip), |
115 | | #[cfg(feature = "zstd")] |
116 | | "zstd" => Some(CompressionEncoding::Zstd), |
117 | 0 | _ => None, |
118 | 0 | }) |
119 | 0 | } |
120 | | |
121 | | /// Get the value of `grpc-encoding` header. Returns an error if the encoding isn't supported. |
122 | 0 | pub(crate) fn from_encoding_header( |
123 | 0 | map: &http::HeaderMap, |
124 | 0 | enabled_encodings: EnabledCompressionEncodings, |
125 | 0 | ) -> Result<Option<Self>, Status> { |
126 | 0 | let Some(header_value) = map.get(ENCODING_HEADER) else { |
127 | 0 | return Ok(None); |
128 | | }; |
129 | | |
130 | 0 | match header_value.as_bytes() { |
131 | | #[cfg(feature = "gzip")] |
132 | | b"gzip" if enabled_encodings.is_enabled(CompressionEncoding::Gzip) => { |
133 | | Ok(Some(CompressionEncoding::Gzip)) |
134 | | } |
135 | | #[cfg(feature = "zstd")] |
136 | | b"zstd" if enabled_encodings.is_enabled(CompressionEncoding::Zstd) => { |
137 | | Ok(Some(CompressionEncoding::Zstd)) |
138 | | } |
139 | 0 | b"identity" => Ok(None), |
140 | 0 | other => { |
141 | | // NOTE: Workaround for lifetime limitation. Resolved at Rust 1.79. |
142 | | // https://blog.rust-lang.org/2024/06/13/Rust-1.79.0.html#extending-automatic-temporary-lifetime-extension |
143 | | let other_debug_string; |
144 | | |
145 | 0 | let mut status = Status::unimplemented(format!( |
146 | 0 | "Content is compressed with `{}` which isn't supported", |
147 | 0 | match std::str::from_utf8(other) { |
148 | 0 | Ok(s) => s, |
149 | | Err(_) => { |
150 | 0 | other_debug_string = format!("{other:?}"); |
151 | 0 | &other_debug_string |
152 | | } |
153 | | } |
154 | | )); |
155 | | |
156 | 0 | let header_value = enabled_encodings |
157 | 0 | .into_accept_encoding_header_value() |
158 | 0 | .map(MetadataValue::unchecked_from_header_value) |
159 | 0 | .unwrap_or_else(|| MetadataValue::from_static("identity")); |
160 | 0 | status |
161 | 0 | .metadata_mut() |
162 | 0 | .insert(ACCEPT_ENCODING_HEADER, header_value); |
163 | | |
164 | 0 | Err(status) |
165 | | } |
166 | | } |
167 | 0 | } |
168 | | |
169 | | pub(crate) fn as_str(self) -> &'static str { |
170 | | match self { |
171 | | #[cfg(feature = "gzip")] |
172 | | CompressionEncoding::Gzip => "gzip", |
173 | | #[cfg(feature = "zstd")] |
174 | | CompressionEncoding::Zstd => "zstd", |
175 | | } |
176 | | } |
177 | | |
178 | | #[cfg(any(feature = "gzip", feature = "zstd"))] |
179 | | pub(crate) fn into_header_value(self) -> http::HeaderValue { |
180 | | http::HeaderValue::from_static(self.as_str()) |
181 | | } |
182 | | } |
183 | | |
184 | | impl fmt::Display for CompressionEncoding { |
185 | 0 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
186 | 0 | f.write_str(self.as_str()) |
187 | 0 | } |
188 | | } |
189 | | |
/// Splits `s` at every comma and yields each piece with the surrounding
/// whitespace trimmed. Empty pieces are yielded as empty strings.
fn split_by_comma(s: &str) -> impl Iterator<Item = &str> {
    s.split(',').map(str::trim)
}
193 | | |
194 | | /// Compress `len` bytes from `decompressed_buf` into `out_buf`. |
195 | | /// buffer_size_increment is a hint to control the growth of out_buf versus the cost of resizing it. |
196 | | #[allow(unused_variables, unreachable_code)] |
197 | 0 | pub(crate) fn compress( |
198 | 0 | settings: CompressionSettings, |
199 | 0 | decompressed_buf: &mut BytesMut, |
200 | 0 | out_buf: &mut BytesMut, |
201 | 0 | len: usize, |
202 | 0 | ) -> Result<(), std::io::Error> { |
203 | 0 | let buffer_growth_interval = settings.buffer_growth_interval; |
204 | 0 | let capacity = ((len / buffer_growth_interval) + 1) * buffer_growth_interval; |
205 | 0 | out_buf.reserve(capacity); |
206 | | |
207 | | #[cfg(any(feature = "gzip", feature = "zstd"))] |
208 | | let mut out_writer = out_buf.writer(); |
209 | | |
210 | | match settings.encoding { |
211 | | #[cfg(feature = "gzip")] |
212 | | CompressionEncoding::Gzip => { |
213 | | let mut gzip_encoder = GzEncoder::new( |
214 | | &decompressed_buf[0..len], |
215 | | // FIXME: support customizing the compression level |
216 | | flate2::Compression::new(6), |
217 | | ); |
218 | | std::io::copy(&mut gzip_encoder, &mut out_writer)?; |
219 | | } |
220 | | #[cfg(feature = "zstd")] |
221 | | CompressionEncoding::Zstd => { |
222 | | let mut zstd_encoder = Encoder::new( |
223 | | &decompressed_buf[0..len], |
224 | | // FIXME: support customizing the compression level |
225 | | zstd::DEFAULT_COMPRESSION_LEVEL, |
226 | | )?; |
227 | | std::io::copy(&mut zstd_encoder, &mut out_writer)?; |
228 | | } |
229 | | } |
230 | | |
231 | | decompressed_buf.advance(len); |
232 | | |
233 | | Ok(()) |
234 | | } |
235 | | |
236 | | /// Decompress `len` bytes from `compressed_buf` into `out_buf`. |
237 | | #[allow(unused_variables, unreachable_code)] |
238 | 0 | pub(crate) fn decompress( |
239 | 0 | settings: CompressionSettings, |
240 | 0 | compressed_buf: &mut BytesMut, |
241 | 0 | out_buf: &mut BytesMut, |
242 | 0 | len: usize, |
243 | 0 | ) -> Result<(), std::io::Error> { |
244 | 0 | let buffer_growth_interval = settings.buffer_growth_interval; |
245 | 0 | let estimate_decompressed_len = len * 2; |
246 | 0 | let capacity = |
247 | 0 | ((estimate_decompressed_len / buffer_growth_interval) + 1) * buffer_growth_interval; |
248 | 0 | out_buf.reserve(capacity); |
249 | | |
250 | | #[cfg(any(feature = "gzip", feature = "zstd"))] |
251 | | let mut out_writer = out_buf.writer(); |
252 | | |
253 | | match settings.encoding { |
254 | | #[cfg(feature = "gzip")] |
255 | | CompressionEncoding::Gzip => { |
256 | | let mut gzip_decoder = GzDecoder::new(&compressed_buf[0..len]); |
257 | | std::io::copy(&mut gzip_decoder, &mut out_writer)?; |
258 | | } |
259 | | #[cfg(feature = "zstd")] |
260 | | CompressionEncoding::Zstd => { |
261 | | let mut zstd_decoder = Decoder::new(&compressed_buf[0..len])?; |
262 | | std::io::copy(&mut zstd_decoder, &mut out_writer)?; |
263 | | } |
264 | | } |
265 | | |
266 | | compressed_buf.advance(len); |
267 | | |
268 | | Ok(()) |
269 | | } |
270 | | |
/// Per-message choice between following the stream's compression and
/// switching compression off for that one message.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SingleMessageCompressionOverride {
    /// Inherit whatever compression is already configured. If the stream is compressed this
    /// message will also be compressed.
    ///
    /// This is the default.
    #[default]
    Inherit,
    /// Don't compress this message, even if compression is enabled on the stream.
    Disable,
}
282 | | |
#[cfg(test)]
mod tests {
    use http::HeaderValue;

    use super::*;

    /// Renders the given slots as an accept-encoding header value, panicking
    /// when no value is produced.
    #[cfg(any(feature = "gzip", feature = "zstd"))]
    fn accept_header(inner: [Option<CompressionEncoding>; 2]) -> HeaderValue {
        EnabledCompressionEncodings { inner }
            .into_accept_encoding_header_value()
            .unwrap()
    }

    #[test]
    fn convert_none_into_header_value() {
        let header = EnabledCompressionEncodings::default().into_accept_encoding_header_value();

        assert!(header.is_none());
    }

    #[test]
    #[cfg(feature = "gzip")]
    fn convert_gzip_into_header_value() {
        let expected = HeaderValue::from_static("gzip,identity");

        // The position of the occupied slot must not matter.
        let first_slot = [Some(CompressionEncoding::Gzip), None];
        assert_eq!(accept_header(first_slot), expected);

        let second_slot = [None, Some(CompressionEncoding::Gzip)];
        assert_eq!(accept_header(second_slot), expected);
    }

    #[test]
    #[cfg(feature = "zstd")]
    fn convert_zstd_into_header_value() {
        let expected = HeaderValue::from_static("zstd,identity");

        // The position of the occupied slot must not matter.
        let first_slot = [Some(CompressionEncoding::Zstd), None];
        assert_eq!(accept_header(first_slot), expected);

        let second_slot = [None, Some(CompressionEncoding::Zstd)];
        assert_eq!(accept_header(second_slot), expected);
    }

    #[test]
    #[cfg(all(feature = "gzip", feature = "zstd"))]
    fn convert_gzip_and_zstd_into_header_value() {
        // Slot order is preserved in the rendered value.
        let gzip_first = [
            Some(CompressionEncoding::Gzip),
            Some(CompressionEncoding::Zstd),
        ];
        assert_eq!(
            accept_header(gzip_first),
            HeaderValue::from_static("gzip,zstd,identity"),
        );

        let zstd_first = [
            Some(CompressionEncoding::Zstd),
            Some(CompressionEncoding::Gzip),
        ];
        assert_eq!(
            accept_header(zstd_first),
            HeaderValue::from_static("zstd,gzip,identity"),
        );
    }
}