Coverage Report

Created: 2025-10-29 07:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/rust/registry/src/index.crates.io-1949cf8c6b5b557f/tonic-0.13.0/src/codec/prost.rs
Line
Count
Source
1
use super::{BufferSettings, Codec, DecodeBuf, Decoder, Encoder};
2
use crate::codec::EncodeBuf;
3
use crate::Status;
4
use prost::Message;
5
use std::marker::PhantomData;
6
7
/// A [`Codec`] that implements `application/grpc+proto` via the prost library.
8
#[derive(Debug, Clone)]
9
pub struct ProstCodec<T, U> {
10
    _pd: PhantomData<(T, U)>,
11
}
12
13
impl<T, U> ProstCodec<T, U> {
14
    /// Create a new `ProstCodec` whose encoder and decoder use the default buffer settings.
15
    /// Use `raw_encoder`/`raw_decoder` to control how memory is allocated and grows per RPC.
16
0
    pub fn new() -> Self {
17
0
        Self { _pd: PhantomData }
18
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse>>::new
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse>>::new
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _>>::new
19
}
20
21
impl<T, U> Default for ProstCodec<T, U> {
22
0
    fn default() -> Self {
23
0
        Self::new()
24
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse> as core::default::Default>::default
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as core::default::Default>::default
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _> as core::default::Default>::default
25
}
26
27
impl<T, U> ProstCodec<T, U>
28
where
29
    T: Message + Send + 'static,
30
    U: Message + Default + Send + 'static,
31
{
32
    /// A tool for building custom codecs based on prost encoding and decoding.
33
    /// See the codec_buffers example for one possible way to use this.
34
0
    pub fn raw_encoder(buffer_settings: BufferSettings) -> <Self as Codec>::Encoder {
35
0
        ProstEncoder {
36
0
            _pd: PhantomData,
37
0
            buffer_settings,
38
0
        }
39
0
    }
40
41
    /// A tool for building custom codecs based on prost encoding and decoding.
42
    /// See the codec_buffers example for one possible way to use this.
43
0
    pub fn raw_decoder(buffer_settings: BufferSettings) -> <Self as Codec>::Decoder {
44
0
        ProstDecoder {
45
0
            _pd: PhantomData,
46
0
            buffer_settings,
47
0
        }
48
0
    }
49
}
50
51
impl<T, U> Codec for ProstCodec<T, U>
52
where
53
    T: Message + Send + 'static,
54
    U: Message + Default + Send + 'static,
55
{
56
    type Encode = T;
57
    type Decode = U;
58
59
    type Encoder = ProstEncoder<T>;
60
    type Decoder = ProstDecoder<U>;
61
62
0
    fn encoder(&mut self) -> Self::Encoder {
63
0
        ProstEncoder {
64
0
            _pd: PhantomData,
65
0
            buffer_settings: BufferSettings::default(),
66
0
        }
67
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Codec>::encoder
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Codec>::encoder
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _> as tonic::codec::Codec>::encoder
68
69
0
    fn decoder(&mut self) -> Self::Decoder {
70
0
        ProstDecoder {
71
0
            _pd: PhantomData,
72
0
            buffer_settings: BufferSettings::default(),
73
0
        }
74
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::istio::ca::IstioCertificateRequest, ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Codec>::decoder
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest, ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Codec>::decoder
Unexecuted instantiation: <tonic::codec::prost::ProstCodec<_, _> as tonic::codec::Codec>::decoder
75
}
76
77
/// An [`Encoder`] that knows how to encode `T`.
78
#[derive(Debug, Clone, Default)]
79
pub struct ProstEncoder<T> {
80
    _pd: PhantomData<T>,
81
    buffer_settings: BufferSettings,
82
}
83
84
impl<T> ProstEncoder<T> {
85
    /// Get a new encoder with explicit buffer settings
86
0
    pub fn new(buffer_settings: BufferSettings) -> Self {
87
0
        Self {
88
0
            _pd: PhantomData,
89
0
            buffer_settings,
90
0
        }
91
0
    }
92
}
93
94
impl<T: Message> Encoder for ProstEncoder<T> {
95
    type Item = T;
96
    type Error = Status;
97
98
0
    fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
99
0
        item.encode(buf)
100
0
            .expect("Message only errors if not enough space");
101
102
0
        Ok(())
103
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::istio::ca::IstioCertificateRequest> as tonic::codec::Encoder>::encode
Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest> as tonic::codec::Encoder>::encode
Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<_> as tonic::codec::Encoder>::encode
104
105
0
    fn buffer_settings(&self) -> BufferSettings {
106
0
        self.buffer_settings
107
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::istio::ca::IstioCertificateRequest> as tonic::codec::Encoder>::buffer_settings
Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryRequest> as tonic::codec::Encoder>::buffer_settings
Unexecuted instantiation: <tonic::codec::prost::ProstEncoder<_> as tonic::codec::Encoder>::buffer_settings
108
}
109
110
/// A [`Decoder`] that knows how to decode `U`.
111
#[derive(Debug, Clone, Default)]
112
pub struct ProstDecoder<U> {
113
    _pd: PhantomData<U>,
114
    buffer_settings: BufferSettings,
115
}
116
117
impl<U> ProstDecoder<U> {
118
    /// Get a new decoder with explicit buffer settings
119
0
    pub fn new(buffer_settings: BufferSettings) -> Self {
120
0
        Self {
121
0
            _pd: PhantomData,
122
0
            buffer_settings,
123
0
        }
124
0
    }
125
}
126
127
impl<U: Message + Default> Decoder for ProstDecoder<U> {
128
    type Item = U;
129
    type Error = Status;
130
131
0
    fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
132
0
        let item = Message::decode(buf)
133
0
            .map(Option::Some)
134
0
            .map_err(from_decode_error)?;
135
136
0
        Ok(item)
137
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Decoder>::decode
Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Decoder>::decode
Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<_> as tonic::codec::Decoder>::decode
138
139
0
    fn buffer_settings(&self) -> BufferSettings {
140
0
        self.buffer_settings
141
0
    }
Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::istio::ca::IstioCertificateResponse> as tonic::codec::Decoder>::buffer_settings
Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<ztunnel::xds::types::service::discovery::v3::DeltaDiscoveryResponse> as tonic::codec::Decoder>::buffer_settings
Unexecuted instantiation: <tonic::codec::prost::ProstDecoder<_> as tonic::codec::Decoder>::buffer_settings
142
}
143
144
0
fn from_decode_error(error: prost::DecodeError) -> crate::Status {
145
    // Map Protobuf parse errors to an INTERNAL status code, as per
146
    // https://github.com/grpc/grpc/blob/master/doc/statuscodes.md
147
0
    Status::internal(error.to_string())
148
0
}
149
150
#[cfg(test)]
151
mod tests {
152
    use crate::codec::compression::SingleMessageCompressionOverride;
153
    use crate::codec::{
154
        DecodeBuf, Decoder, EncodeBody, EncodeBuf, Encoder, Streaming, HEADER_SIZE,
155
    };
156
    use crate::Status;
157
    use bytes::{Buf, BufMut, BytesMut};
158
    use http_body::Body;
159
    use http_body_util::BodyExt as _;
160
    use std::pin::pin;
161
162
    const LEN: usize = 10000;
163
    // The maximum uncompressed size in bytes for a message. Set to 2MB.
164
    const MAX_MESSAGE_SIZE: usize = 2 * 1024 * 1024;
165
166
    #[tokio::test]
167
    async fn decode() {
168
        let decoder = MockDecoder::default();
169
170
        let msg = vec![0u8; LEN];
171
172
        let mut buf = BytesMut::new();
173
174
        buf.reserve(msg.len() + HEADER_SIZE);
175
        buf.put_u8(0);
176
        buf.put_u32(msg.len() as u32);
177
178
        buf.put(&msg[..]);
179
180
        let body = body::MockBody::new(&buf[..], 10005, 0);
181
182
        let mut stream = Streaming::new_request(decoder, body, None, None);
183
184
        let mut i = 0usize;
185
        while let Some(output_msg) = stream.message().await.unwrap() {
186
            assert_eq!(output_msg.len(), msg.len());
187
            i += 1;
188
        }
189
        assert_eq!(i, 1);
190
    }
191
192
    #[tokio::test]
193
    async fn decode_max_message_size_exceeded() {
194
        let decoder = MockDecoder::default();
195
196
        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
197
198
        let mut buf = BytesMut::new();
199
200
        buf.reserve(msg.len() + HEADER_SIZE);
201
        buf.put_u8(0);
202
        buf.put_u32(msg.len() as u32);
203
204
        buf.put(&msg[..]);
205
206
        let body = body::MockBody::new(&buf[..], MAX_MESSAGE_SIZE + HEADER_SIZE + 1, 0);
207
208
        let mut stream = Streaming::new_request(decoder, body, None, Some(MAX_MESSAGE_SIZE));
209
210
        let actual = stream.message().await.unwrap_err();
211
212
        let expected = Status::out_of_range(format!(
213
            "Error, decoded message length too large: found {} bytes, the limit is: {} bytes",
214
            msg.len(),
215
            MAX_MESSAGE_SIZE
216
        ));
217
218
        assert_eq!(actual.code(), expected.code());
219
        assert_eq!(actual.message(), expected.message());
220
    }
221
222
    #[tokio::test]
223
    async fn encode() {
224
        let encoder = MockEncoder::default();
225
226
        let msg = Vec::from(&[0u8; 1024][..]);
227
228
        let messages = std::iter::repeat_with(move || Ok::<_, Status>(msg.clone())).take(10000);
229
        let source = tokio_stream::iter(messages);
230
231
        let mut body = pin!(EncodeBody::new_server(
232
            encoder,
233
            source,
234
            None,
235
            SingleMessageCompressionOverride::default(),
236
            None,
237
        ));
238
239
        while let Some(r) = body.frame().await {
240
            r.unwrap();
241
        }
242
    }
243
244
    #[tokio::test]
245
    async fn encode_max_message_size_exceeded() {
246
        let encoder = MockEncoder::default();
247
248
        let msg = vec![0u8; MAX_MESSAGE_SIZE + 1];
249
250
        let messages = std::iter::once(Ok::<_, Status>(msg));
251
        let source = tokio_stream::iter(messages);
252
253
        let mut body = pin!(EncodeBody::new_server(
254
            encoder,
255
            source,
256
            None,
257
            SingleMessageCompressionOverride::default(),
258
            Some(MAX_MESSAGE_SIZE),
259
        ));
260
261
        let frame = body
262
            .frame()
263
            .await
264
            .expect("at least one frame")
265
            .expect("no error polling frame");
266
        assert_eq!(
267
            frame
268
                .into_trailers()
269
                .expect("got trailers")
270
                .get(Status::GRPC_STATUS)
271
                .expect("grpc-status header"),
272
            "11"
273
        );
274
        assert!(body.is_end_stream());
275
    }
276
277
    // skip on windows because CI stumbles over our 4GB allocation
278
    #[cfg(not(target_family = "windows"))]
279
    #[tokio::test]
280
    async fn encode_too_big() {
281
        use crate::codec::EncodeBody;
282
283
        let encoder = MockEncoder::default();
284
285
        let msg = vec![0u8; u32::MAX as usize + 1];
286
287
        let messages = std::iter::once(Ok::<_, Status>(msg));
288
        let source = tokio_stream::iter(messages);
289
290
        let mut body = pin!(EncodeBody::new_server(
291
            encoder,
292
            source,
293
            None,
294
            SingleMessageCompressionOverride::default(),
295
            Some(usize::MAX),
296
        ));
297
298
        let frame = body
299
            .frame()
300
            .await
301
            .expect("at least one frame")
302
            .expect("no error polling frame");
303
        assert_eq!(
304
            frame
305
                .into_trailers()
306
                .expect("got trailers")
307
                .get(Status::GRPC_STATUS)
308
                .expect("grpc-status header"),
309
            "8"
310
        );
311
        assert!(body.is_end_stream());
312
    }
313
314
    #[derive(Debug, Clone, Default)]
315
    struct MockEncoder {}
316
317
    impl Encoder for MockEncoder {
318
        type Item = Vec<u8>;
319
        type Error = Status;
320
321
        fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
322
            buf.put(&item[..]);
323
            Ok(())
324
        }
325
326
        fn buffer_settings(&self) -> crate::codec::BufferSettings {
327
            Default::default()
328
        }
329
    }
330
331
    #[derive(Debug, Clone, Default)]
332
    struct MockDecoder {}
333
334
    impl Decoder for MockDecoder {
335
        type Item = Vec<u8>;
336
        type Error = Status;
337
338
        fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
339
            let out = Vec::from(buf.chunk());
340
            buf.advance(LEN);
341
            Ok(Some(out))
342
        }
343
344
        fn buffer_settings(&self) -> crate::codec::BufferSettings {
345
            Default::default()
346
        }
347
    }
348
349
    mod body {
350
        use crate::Status;
351
        use bytes::Bytes;
352
        use http_body::{Body, Frame};
353
        use std::{
354
            pin::Pin,
355
            task::{Context, Poll},
356
        };
357
358
        #[derive(Debug)]
359
        pub(super) struct MockBody {
360
            data: Bytes,
361
362
            // the size of the partial message to send
363
            partial_len: usize,
364
365
            // the number of times poll_frame has been called while data remained (sends and pendings)
366
            count: usize,
367
        }
368
369
        impl MockBody {
370
            pub(super) fn new(b: &[u8], partial_len: usize, count: usize) -> Self {
371
                MockBody {
372
                    data: Bytes::copy_from_slice(b),
373
                    partial_len,
374
                    count,
375
                }
376
            }
377
        }
378
379
        impl Body for MockBody {
380
            type Data = Bytes;
381
            type Error = Status;
382
383
            fn poll_frame(
384
                mut self: Pin<&mut Self>,
385
                cx: &mut Context<'_>,
386
            ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
387
                // every other call to poll_frame returns data; the calls in between return Pending
388
                let should_send = self.count % 2 == 0;
389
                let data_len = self.data.len();
390
                let partial_len = self.partial_len;
391
                let count = self.count;
392
                if data_len > 0 {
393
                    let result = if should_send {
394
                        let response =
395
                            self.data
396
                                .split_to(if count == 0 { partial_len } else { data_len });
397
                        Poll::Ready(Some(Ok(Frame::data(response))))
398
                    } else {
399
                        cx.waker().wake_by_ref();
400
                        Poll::Pending
401
                    };
402
                    // make some fake progress
403
                    self.count += 1;
404
                    result
405
                } else {
406
                    Poll::Ready(None)
407
                }
408
            }
409
        }
410
    }
411
}