/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.13.0/src/codec/prost.rs
Line | Count | Source |
1 | | use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder}; |
2 | | use crate::codec::EncodeBuf; |
3 | | use crate::Status; |
4 | | use prost::Message; |
5 | | use std::marker::PhantomData; |
6 | | |
7 | | /// A [`Codec`] that implements `application/grpc+proto` via the prost library.. |
8 | | #[derive(Debug, Clone)] |
9 | | pub struct ProstCodec<T, U> { |
10 | | _pd: PhantomData<(T, U)>, |
11 | | } |
12 | | |
13 | | impl<T, U> ProstCodec<T, U> { |
14 | | /// Configure a ProstCodec with encoder/decoder buffer settings. This is used to control |
15 | | /// how memory is allocated and grows per RPC. |
16 | 0 | pub fn new() -> Self { |
17 | 0 | Self { _pd: PhantomData } |
18 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse>>::new Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse>>::new Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _>>::new |
19 | | } |
20 | | |
21 | | impl<T, U> Default for ProstCodec<T, U> { |
22 | 0 | fn default() -> Self { |
23 | 0 | Self::new() |
24 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse> as core::default::Default>::default Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as core::default::Default>::default Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _> as core::default::Default>::default |
25 | | } |
26 | | |
27 | | impl<T, U> ProstCodec<T, U> |
28 | | where |
29 | | T: Message + Send + 'static, |
30 | | U: Message + Default + Send + 'static, |
31 | | { |
32 | | /// A tool for building custom codecs based on prost encoding and decoding. |
33 | | /// See the codec_buffers example for one possible way to use this. |
34 | 0 | pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder { |
35 | 0 | ProstEncoder { |
36 | 0 | _pd: PhantomData, |
37 | 0 | buffer_settings, |
38 | 0 | } |
39 | 0 | } |
40 | | |
41 | | /// A tool for building custom codecs based on prost encoding and decoding. |
42 | | /// See the codec_buffers example for one possible way to use this. |
43 | 0 | pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder { |
44 | 0 | ProstDecoder { |
45 | 0 | _pd: PhantomData, |
46 | 0 | buffer_settings, |
47 | 0 | } |
48 | 0 | } |
49 | | } |
50 | | |
51 | | impl<T, U> Codec for ProstCodec<T, U> |
52 | | where |
53 | | T: Message + Send + 'static, |
54 | | U: Message + Default + Send + 'static, |
55 | | { |
56 | | type Encode = T; |
57 | | type Decode = U; |
58 | | |
59 | | type Encoder = ProstEncoder<T>; |
60 | | type Decoder = ProstDecoder<U>; |
61 | | |
62 | 0 | fn encoder(&mut self) -> Self::Encoder { |
63 | 0 | ProstEncoder { |
64 | 0 | _pd: PhantomData, |
65 | 0 | buffer_settings: BufferSettings::default(), |
66 | 0 | } |
67 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Codec>::encoder Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Codec>::encoder Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _> as tonic::codec::Codec>::encoder |
68 | | |
69 | 0 | fn decoder(&mut self) -> Self::Decoder { |
70 | 0 | ProstDecoder { |
71 | 0 | _pd: PhantomData, |
72 | 0 | buffer_settings: BufferSettings::default(), |
73 | 0 | } |
74 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Codec>::decoder Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Codec>::decoder Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _> as tonic::codec::Codec>::decoder |
75 | | } |
76 | | |
77 | | /// A [`Encoder`] that knows how to encode `T`. |
78 | | #[derive(Debug, Clone, Default)] |
79 | | pub struct ProstEncoder<T> { |
80 | | _pd: PhantomData<T>, |
81 | | buffer_settings: BufferSettings, |
82 | | } |
83 | | |
84 | | impl<T> ProstEncoder<T> { |
85 | | /// Get a new encoder with explicit buffer settings |
86 | 0 | pub fn new(buffer_settings: BufferSettings) -> Self { |
87 | 0 | Self { |
88 | 0 | _pd: PhantomData, |
89 | 0 | buffer_settings, |
90 | 0 | } |
91 | 0 | } |
92 | | } |
93 | | |
94 | | impl<T: Message> Encoder for ProstEncoder<T> { |
95 | | type Item = T; |
96 | | type Error = Status; |
97 | | |
98 | 0 | fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { |
99 | 0 | item.encode(buf) |
100 | 0 | .expect("Message only errors if not enough space"); |
101 | | |
102 | 0 | Ok(()) |
103 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::istio::ca::IstioCertificateRequest> as tonic::codec::Encoder>::encode Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest> as tonic::codec::Encoder>::encode Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<_> as tonic::codec::Encoder>::encode |
104 | | |
105 | 0 | fn buffer_settings(&self) -> BufferSettings { |
106 | 0 | self.buffer_settings |
107 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::istio::ca::IstioCertificateRequest> as tonic::codec::Encoder>::buffer_settings Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest> as tonic::codec::Encoder>::buffer_settings Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<_> as tonic::codec::Encoder>::buffer_settings |
108 | | } |
109 | | |
110 | | /// A [`Decoder`] that knows how to decode `U`. |
111 | | #[derive(Debug, Clone, Default)] |
112 | | pub struct ProstDecoder<U> { |
113 | | _pd: PhantomData<U>, |
114 | | buffer_settings: BufferSettings, |
115 | | } |
116 | | |
117 | | impl<U> ProstDecoder<U> { |
118 | | /// Get a new decoder with explicit buffer settings |
119 | 0 | pub fn new(buffer_settings: BufferSettings) -> Self { |
120 | 0 | Self { |
121 | 0 | _pd: PhantomData, |
122 | 0 | buffer_settings, |
123 | 0 | } |
124 | 0 | } |
125 | | } |
126 | | |
127 | | impl<U: Message + Default> Decoder for ProstDecoder<U> { |
128 | | type Item = U; |
129 | | type Error = Status; |
130 | | |
131 | 0 | fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> { |
132 | 0 | let item = Message::decode(buf) |
133 | 0 | .map(Option::Some) |
134 | 0 | .map_err(from_decode_error)?; |
135 | | |
136 | 0 | Ok(item) |
137 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Decoder>::decode Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Decoder>::decode Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<_> as tonic::codec::Decoder>::decode |
138 | | |
139 | 0 | fn buffer_settings(&self) -> BufferSettings { |
140 | 0 | self.buffer_settings |
141 | 0 | } Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Decoder>::buffer_settings Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Decoder>::buffer_settings Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<_> as tonic::codec::Decoder>::buffer_settings |
142 | | } |
143 | | |
144 | 0 | fn from_decode_error(error: prost::DecodeError) -> crate::Status { |
145 | | // Map Protobuf parse errors to an INTERNAL status code, as per |
146 | | // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md |
147 | 0 | Status::internal(error.to_string()) |
148 | 0 | } |
149 | | |
150 | | #[cfg(test)] |
151 | | mod tests { |
152 | | use crate::codec::compression::SingleMessageCompressionOverride; |
153 | | use crate::codec::{ |
154 | | DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE, |
155 | | }; |
156 | | use crate::Status; |
157 | | use bytes::{Buf, BufMut, BytesMut}; |
158 | | use http_body::Body; |
159 | | use http_body_util::BodyExt as _; |
160 | | use std::pin::pin; |
161 | | |
162 | | const LEN: usize = 10000; |
163 | | // The maximum uncompressed size in bytes for a message. Set to 2MB. |
164 | | const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024; |
165 | | |
166 | | #[tokio::test] |
167 | | async fn decode() { |
168 | | let decoder = MockDecoder::default(); |
169 | | |
170 | | let msg = vec![0u8; LEN]; |
171 | | |
172 | | let mut buf = BytesMut::new(); |
173 | | |
174 | | buf.reserve(msg.len() + HEADER_SIZE); |
175 | | buf.put_u8(0); |
176 | | buf.put_u32(msg.len() as u32); |
177 | | |
178 | | buf.put(&msg[..]); |
179 | | |
180 | | let body = body::MockBody::new(&buf[..], 10005, 0); |
181 | | |
182 | | let mut stream = Streaming::new_request(decoder, body, None, None); |
183 | | |
184 | | let mut i = 0usize; |
185 | | while let Some(output_msg) = stream.message().await.unwrap() { |
186 | | assert_eq!(output_msg.len(), msg.len()); |
187 | | i += 1; |
188 | | } |
189 | | assert_eq!(i, 1); |
190 | | } |
191 | | |
192 | | #[tokio::test] |
193 | | async fn decode_max_message_size_exceeded() { |
194 | | let decoder = MockDecoder::default(); |
195 | | |
196 | | let msg = vec![0u8; MAX_MESSAGE_SIZE + 1]; |
197 | | |
198 | | let mut buf = BytesMut::new(); |
199 | | |
200 | | buf.reserve(msg.len() + HEADER_SIZE); |
201 | | buf.put_u8(0); |
202 | | buf.put_u32(msg.len() as u32); |
203 | | |
204 | | buf.put(&msg[..]); |
205 | | |
206 | | let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0); |
207 | | |
208 | | let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE)); |
209 | | |
210 | | let actual = stream.message().await.unwrap_err(); |
211 | | |
212 | | let expected = Status::out_of_range(format!( |
213 | | "Error, decoded message length too large: found {} bytes, the limit is: {} bytes", |
214 | | msg.len(), |
215 | | MAX_MESSAGE_SIZE |
216 | | )); |
217 | | |
218 | | assert_eq!(actual.code(), expected.code()); |
219 | | assert_eq!(actual.message(), expected.message()); |
220 | | } |
221 | | |
222 | | #[tokio::test] |
223 | | async fn encode() { |
224 | | let encoder = MockEncoder::default(); |
225 | | |
226 | | let msg = Vec::from(&[0u8; 1024][..]); |
227 | | |
228 | | let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000); |
229 | | let source = tokio_stream::iter(messages); |
230 | | |
231 | | let mut body = pin!(EncodeBody::new_server( |
232 | | encoder, |
233 | | source, |
234 | | None, |
235 | | SingleMessageCompressionOverride::default(), |
236 | | None, |
237 | | )); |
238 | | |
239 | | while let Some(r) = body.frame().await { |
240 | | r.unwrap(); |
241 | | } |
242 | | } |
243 | | |
244 | | #[tokio::test] |
245 | | async fn encode_max_message_size_exceeded() { |
246 | | let encoder = MockEncoder::default(); |
247 | | |
248 | | let msg = vec![0u8; MAX_MESSAGE_SIZE + 1]; |
249 | | |
250 | | let messages = std::iter::once(Ok::<_, Status>(msg)); |
251 | | let source = tokio_stream::iter(messages); |
252 | | |
253 | | let mut body = pin!(EncodeBody::new_server( |
254 | | encoder, |
255 | | source, |
256 | | None, |
257 | | SingleMessageCompressionOverride::default(), |
258 | | Some(MAX_MESSAGE_SIZE), |
259 | | )); |
260 | | |
261 | | let frame = body |
262 | | .frame() |
263 | | .await |
264 | | .expect("at least one frame") |
265 | | .expect("no error polling frame"); |
266 | | assert_eq!( |
267 | | frame |
268 | | .into_trailers() |
269 | | .expect("got trailers") |
270 | | .get(Status::GRPC_STATUS) |
271 | | .expect("grpc-status header"), |
272 | | "11" |
273 | | ); |
274 | | assert!(body.is_end_stream()); |
275 | | } |
276 | | |
277 | | // skip on windows because CI stumbles over our 4GB allocation |
278 | | #[cfg(not(target_family = "windows"))] |
279 | | #[tokio::test] |
280 | | async fn encode_too_big() { |
281 | | use crate::codec::EncodeBody; |
282 | | |
283 | | let encoder = MockEncoder::default(); |
284 | | |
285 | | let msg = vec![0u8; u32::MAX as usize + 1]; |
286 | | |
287 | | let messages = std::iter::once(Ok::<_, Status>(msg)); |
288 | | let source = tokio_stream::iter(messages); |
289 | | |
290 | | let mut body = pin!(EncodeBody::new_server( |
291 | | encoder, |
292 | | source, |
293 | | None, |
294 | | SingleMessageCompressionOverride::default(), |
295 | | Some(usize::MAX), |
296 | | )); |
297 | | |
298 | | let frame = body |
299 | | .frame() |
300 | | .await |
301 | | .expect("at least one frame") |
302 | | .expect("no error polling frame"); |
303 | | assert_eq!( |
304 | | frame |
305 | | .into_trailers() |
306 | | .expect("got trailers") |
307 | | .get(Status::GRPC_STATUS) |
308 | | .expect("grpc-status header"), |
309 | | "8" |
310 | | ); |
311 | | assert!(body.is_end_stream()); |
312 | | } |
313 | | |
314 | | #[derive(Debug, Clone, Default)] |
315 | | struct MockEncoder {} |
316 | | |
317 | | impl Encoder for MockEncoder { |
318 | | type Item = Vec<u8>; |
319 | | type Error = Status; |
320 | | |
321 | | fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> { |
322 | | buf.put(&item[..]); |
323 | | Ok(()) |
324 | | } |
325 | | |
326 | | fn buffer_settings(&self) -> crate::codec::BufferSettings { |
327 | | Default::default() |
328 | | } |
329 | | } |
330 | | |
331 | | #[derive(Debug, Clone, Default)] |
332 | | struct MockDecoder {} |
333 | | |
334 | | impl Decoder for MockDecoder { |
335 | | type Item = Vec<u8>; |
336 | | type Error = Status; |
337 | | |
338 | | fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> { |
339 | | let out = Vec::from(buf.chunk()); |
340 | | buf.advance(LEN); |
341 | | Ok(Some(out)) |
342 | | } |
343 | | |
344 | | fn buffer_settings(&self) -> crate::codec::BufferSettings { |
345 | | Default::default() |
346 | | } |
347 | | } |
348 | | |
349 | | mod body { |
350 | | use crate::Status; |
351 | | use bytes::Bytes; |
352 | | use http_body::{Body, Frame}; |
353 | | use std::{ |
354 | | pin::Pin, |
355 | | task::{Context, Poll}, |
356 | | }; |
357 | | |
358 | | #[derive(Debug)] |
359 | | pub(super) struct MockBody { |
360 | | data: Bytes, |
361 | | |
362 | | // the size of the partial message to send |
363 | | partial_len: usize, |
364 | | |
365 | | // the number of times we've sent |
366 | | count: usize, |
367 | | } |
368 | | |
369 | | impl MockBody { |
370 | | pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self { |
371 | | MockBody { |
372 | | data: Bytes::copy_from_slice(b), |
373 | | partial_len, |
374 | | count, |
375 | | } |
376 | | } |
377 | | } |
378 | | |
379 | | impl Body for MockBody { |
380 | | type Data = Bytes; |
381 | | type Error = Status; |
382 | | |
383 | | fn poll_frame( |
384 | | mut self: Pin<&mut Self>, |
385 | | cx: &mut Context<'_>, |
386 | | ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { |
387 | | // every other call to poll_data returns data |
388 | | let should_send = self.count % 2 == 0; |
389 | | let data_len = self.data.len(); |
390 | | let partial_len = self.partial_len; |
391 | | let count = self.count; |
392 | | if data_len > 0 { |
393 | | let result = if should_send { |
394 | | let response = |
395 | | self.data |
396 | | .split_to(if count == 0 { partial_len } else { data_len }); |
397 | | Poll::Ready(Some(Ok(Frame::data(response)))) |
398 | | } else { |
399 | | cx.waker().wake_by_ref(); |
400 | | Poll::Pending |
401 | | }; |
402 | | // make some fake progress |
403 | | self.count += 1; |
404 | | result |
405 | | } else { |
406 | | Poll::Ready(None) |
407 | | } |
408 | | } |
409 | | } |
410 | | } |
411 | | } |