Coverage Report

Created: 2024-12-17 06:15

/rust/registry/src/index.crates.io-6f17d22bba15001f/tonic-0.10.2/src/server/grpc.rs
Line
Count
Source (jump to first uncovered line)
1
use crate::codec::compression::{
2
    CompressionEncoding, EnabledCompressionEncodings, SingleMessageCompressionOverride,
3
};
4
use crate::{
5
    body::BoxBody,
6
    codec::{encode_server, Codec, Streaming},
7
    server::{ClientStreamingService, ServerStreamingService, StreamingService, UnaryService},
8
    Code, Request, Status,
9
};
10
use http_body::Body;
11
use std::fmt;
12
use tokio_stream::{Stream, StreamExt};
13
14
macro_rules! t {
15
    ($result:expr) => {
16
        match $result {
17
            Ok(value) => value,
18
            Err(status) => return status.to_http(),
19
        }
20
    };
21
}
22
23
/// A gRPC Server handler.
24
///
25
/// This will wrap some inner [`Codec`] and provide utilities to handle
26
/// inbound unary, client side streaming, server side streaming, and
27
/// bi-directional streaming.
28
///
29
/// Each request handler method accepts some service that implements the
30
/// corresponding service trait and an HTTP request that contains some body that
31
/// implements some [`Body`].
32
pub struct Grpc<T> {
33
    codec: T,
34
    /// Which compression encodings does the server accept for requests?
35
    accept_compression_encodings: EnabledCompressionEncodings,
36
    /// Which compression encodings might the server use for responses?
37
    send_compression_encodings: EnabledCompressionEncodings,
38
    /// Limits the maximum size of a decoded message.
39
    max_decoding_message_size: Option<usize>,
40
    /// Limits the maximum size of an encoded message.
41
    max_encoding_message_size: Option<usize>,
42
}
43
44
impl<T> Grpc<T>
45
where
46
    T: Codec,
47
{
48
    /// Creates a new gRPC server with the provided [`Codec`].
49
0
    pub fn new(codec: T) -> Self {
50
0
        Self {
51
0
            codec,
52
0
            accept_compression_encodings: EnabledCompressionEncodings::default(),
53
0
            send_compression_encodings: EnabledCompressionEncodings::default(),
54
0
            max_decoding_message_size: None,
55
0
            max_encoding_message_size: None,
56
0
        }
57
0
    }
58
59
    /// Enable accepting compressed requests.
60
    ///
61
    /// If a request with an unsupported encoding is received the server will respond with
62
    /// [`Code::Unimplemented`](crate::Code).
63
    ///
64
    /// # Example
65
    ///
66
    /// The most common way of using this is through a server generated by tonic-build:
67
    ///
68
    /// ```rust
69
    /// # enum CompressionEncoding { Gzip }
70
    /// # struct Svc;
71
    /// # struct ExampleServer<T>(T);
72
    /// # impl<T> ExampleServer<T> {
73
    /// #     fn new(svc: T) -> Self { Self(svc) }
74
    /// #     fn accept_compressed(self, _: CompressionEncoding) -> Self { self }
75
    /// # }
76
    /// # #[tonic::async_trait]
77
    /// # trait Example {}
78
    ///
79
    /// #[tonic::async_trait]
80
    /// impl Example for Svc {
81
    ///     // ...
82
    /// }
83
    ///
84
    /// let service = ExampleServer::new(Svc).accept_compressed(CompressionEncoding::Gzip);
85
    /// ```
86
0
    pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
87
0
        self.accept_compression_encodings.enable(encoding);
88
0
        self
89
0
    }
90
91
    /// Enable sending compressed responses.
92
    ///
93
    /// Requires the client to also support receiving compressed responses.
94
    ///
95
    /// # Example
96
    ///
97
    /// The most common way of using this is through a server generated by tonic-build:
98
    ///
99
    /// ```rust
100
    /// # enum CompressionEncoding { Gzip }
101
    /// # struct Svc;
102
    /// # struct ExampleServer<T>(T);
103
    /// # impl<T> ExampleServer<T> {
104
    /// #     fn new(svc: T) -> Self { Self(svc) }
105
    /// #     fn send_compressed(self, _: CompressionEncoding) -> Self { self }
106
    /// # }
107
    /// # #[tonic::async_trait]
108
    /// # trait Example {}
109
    ///
110
    /// #[tonic::async_trait]
111
    /// impl Example for Svc {
112
    ///     // ...
113
    /// }
114
    ///
115
    /// let service = ExampleServer::new(Svc).send_compressed(CompressionEncoding::Gzip);
116
    /// ```
117
0
    pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
118
0
        self.send_compression_encodings.enable(encoding);
119
0
        self
120
0
    }
121
122
    /// Limits the maximum size of a decoded message.
123
    ///
124
    /// # Example
125
    ///
126
    /// The most common way of using this is through a server generated by tonic-build:
127
    ///
128
    /// ```rust
129
    /// # struct Svc;
130
    /// # struct ExampleServer<T>(T);
131
    /// # impl<T> ExampleServer<T> {
132
    /// #     fn new(svc: T) -> Self { Self(svc) }
133
    /// #     fn max_decoding_message_size(self, _: usize) -> Self { self }
134
    /// # }
135
    /// # #[tonic::async_trait]
136
    /// # trait Example {}
137
    ///
138
    /// #[tonic::async_trait]
139
    /// impl Example for Svc {
140
    ///     // ...
141
    /// }
142
    ///
143
    /// // Set the limit to 2MB, Defaults to 4MB.
144
    /// let limit = 2 * 1024 * 1024;
145
    /// let service = ExampleServer::new(Svc).max_decoding_message_size(limit);
146
    /// ```
147
0
    pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
148
0
        self.max_decoding_message_size = Some(limit);
149
0
        self
150
0
    }
151
152
    /// Limits the maximum size of an encoded message.
153
    ///
154
    /// # Example
155
    ///
156
    /// The most common way of using this is through a server generated by tonic-build:
157
    ///
158
    /// ```rust
159
    /// # struct Svc;
160
    /// # struct ExampleServer<T>(T);
161
    /// # impl<T> ExampleServer<T> {
162
    /// #     fn new(svc: T) -> Self { Self(svc) }
163
    /// #     fn max_encoding_message_size(self, _: usize) -> Self { self }
164
    /// # }
165
    /// # #[tonic::async_trait]
166
    /// # trait Example {}
167
    ///
168
    /// #[tonic::async_trait]
169
    /// impl Example for Svc {
170
    ///     // ...
171
    /// }
172
    ///
173
    /// // Set the limit to 2MB, Defaults to 4MB.
174
    /// let limit = 2 * 1024 * 1024;
175
    /// let service = ExampleServer::new(Svc).max_encoding_message_size(limit);
176
    /// ```
177
0
    pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
178
0
        self.max_encoding_message_size = Some(limit);
179
0
        self
180
0
    }
181
182
    #[doc(hidden)]
183
0
    pub fn apply_compression_config(
184
0
        self,
185
0
        accept_encodings: EnabledCompressionEncodings,
186
0
        send_encodings: EnabledCompressionEncodings,
187
0
    ) -> Self {
188
0
        let mut this = self;
189
190
0
        for &encoding in CompressionEncoding::encodings() {
191
0
            if accept_encodings.is_enabled(encoding) {
192
0
                this = this.accept_compressed(encoding);
193
0
            }
194
0
            if send_encodings.is_enabled(encoding) {
195
0
                this = this.send_compressed(encoding);
196
0
            }
197
        }
198
199
0
        this
200
0
    }
201
202
    #[doc(hidden)]
203
0
    pub fn apply_max_message_size_config(
204
0
        self,
205
0
        max_decoding_message_size: Option<usize>,
206
0
        max_encoding_message_size: Option<usize>,
207
0
    ) -> Self {
208
0
        let mut this = self;
209
210
0
        if let Some(limit) = max_decoding_message_size {
211
0
            this = this.max_decoding_message_size(limit);
212
0
        }
213
0
        if let Some(limit) = max_encoding_message_size {
214
0
            this = this.max_encoding_message_size(limit);
215
0
        }
216
217
0
        this
218
0
    }
219
220
    /// Handle a single unary gRPC request.
221
0
    pub async fn unary<S, B>(
222
0
        &mut self,
223
0
        mut service: S,
224
0
        req: http::Request<B>,
225
0
    ) -> http::Response<BoxBody>
226
0
    where
227
0
        S: UnaryService<T::Decode, Response = T::Encode>,
228
0
        B: Body + Send + 'static,
229
0
        B::Error: Into<crate::Error> + Send,
230
0
    {
231
0
        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
232
0
            req.headers(),
233
0
            self.send_compression_encodings,
234
0
        );
235
236
0
        let request = match self.map_request_unary(req).await {
237
0
            Ok(r) => r,
238
0
            Err(status) => {
239
0
                return self.map_response::<tokio_stream::Once<Result<T::Encode, Status>>>(
240
0
                    Err(status),
241
0
                    accept_encoding,
242
0
                    SingleMessageCompressionOverride::default(),
243
0
                    self.max_encoding_message_size,
244
0
                );
245
            }
246
        };
247
248
0
        let response = service
249
0
            .call(request)
250
0
            .await
251
0
            .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
252
0
253
0
        let compression_override = compression_override_from_response(&response);
254
0
255
0
        self.map_response(
256
0
            response,
257
0
            accept_encoding,
258
0
            compression_override,
259
0
            self.max_encoding_message_size,
260
0
        )
261
0
    }
262
263
    /// Handle a server side streaming request.
264
0
    pub async fn server_streaming<S, B>(
265
0
        &mut self,
266
0
        mut service: S,
267
0
        req: http::Request<B>,
268
0
    ) -> http::Response<BoxBody>
269
0
    where
270
0
        S: ServerStreamingService<T::Decode, Response = T::Encode>,
271
0
        S::ResponseStream: Send + 'static,
272
0
        B: Body + Send + 'static,
273
0
        B::Error: Into<crate::Error> + Send,
274
0
    {
275
0
        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
276
0
            req.headers(),
277
0
            self.send_compression_encodings,
278
0
        );
279
280
0
        let request = match self.map_request_unary(req).await {
281
0
            Ok(r) => r,
282
0
            Err(status) => {
283
0
                return self.map_response::<S::ResponseStream>(
284
0
                    Err(status),
285
0
                    accept_encoding,
286
0
                    SingleMessageCompressionOverride::default(),
287
0
                    self.max_encoding_message_size,
288
0
                );
289
            }
290
        };
291
292
0
        let response = service.call(request).await;
293
294
0
        self.map_response(
295
0
            response,
296
0
            accept_encoding,
297
0
            // disabling compression of individual stream items must be done on
298
0
            // the items themselves
299
0
            SingleMessageCompressionOverride::default(),
300
0
            self.max_encoding_message_size,
301
0
        )
302
0
    }
303
304
    /// Handle a client side streaming gRPC request.
305
0
    pub async fn client_streaming<S, B>(
306
0
        &mut self,
307
0
        mut service: S,
308
0
        req: http::Request<B>,
309
0
    ) -> http::Response<BoxBody>
310
0
    where
311
0
        S: ClientStreamingService<T::Decode, Response = T::Encode>,
312
0
        B: Body + Send + 'static,
313
0
        B::Error: Into<crate::Error> + Send + 'static,
314
0
    {
315
0
        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
316
0
            req.headers(),
317
0
            self.send_compression_encodings,
318
0
        );
319
320
0
        let request = t!(self.map_request_streaming(req));
321
322
0
        let response = service
323
0
            .call(request)
324
0
            .await
325
0
            .map(|r| r.map(|m| tokio_stream::once(Ok(m))));
326
0
327
0
        let compression_override = compression_override_from_response(&response);
328
0
329
0
        self.map_response(
330
0
            response,
331
0
            accept_encoding,
332
0
            compression_override,
333
0
            self.max_encoding_message_size,
334
0
        )
335
0
    }
336
337
    /// Handle a bi-directional streaming gRPC request.
338
0
    pub async fn streaming<S, B>(
339
0
        &mut self,
340
0
        mut service: S,
341
0
        req: http::Request<B>,
342
0
    ) -> http::Response<BoxBody>
343
0
    where
344
0
        S: StreamingService<T::Decode, Response = T::Encode> + Send,
345
0
        S::ResponseStream: Send + 'static,
346
0
        B: Body + Send + 'static,
347
0
        B::Error: Into<crate::Error> + Send,
348
0
    {
349
0
        let accept_encoding = CompressionEncoding::from_accept_encoding_header(
350
0
            req.headers(),
351
0
            self.send_compression_encodings,
352
0
        );
353
354
0
        let request = t!(self.map_request_streaming(req));
355
356
0
        let response = service.call(request).await;
357
358
0
        self.map_response(
359
0
            response,
360
0
            accept_encoding,
361
0
            SingleMessageCompressionOverride::default(),
362
0
            self.max_encoding_message_size,
363
0
        )
364
0
    }
365
366
0
    async fn map_request_unary<B>(
367
0
        &mut self,
368
0
        request: http::Request<B>,
369
0
    ) -> Result<Request<T::Decode>, Status>
370
0
    where
371
0
        B: Body + Send + 'static,
372
0
        B::Error: Into<crate::Error> + Send,
373
0
    {
374
0
        let request_compression_encoding = self.request_encoding_if_supported(&request)?;
375
376
0
        let (parts, body) = request.into_parts();
377
0
378
0
        let stream = Streaming::new_request(
379
0
            self.codec.decoder(),
380
0
            body,
381
0
            request_compression_encoding,
382
0
            self.max_decoding_message_size,
383
0
        );
384
0
385
0
        tokio::pin!(stream);
386
387
0
        let message = stream
388
0
            .try_next()
389
0
            .await?
390
0
            .ok_or_else(|| Status::new(Code::Internal, "Missing request message."))?;
391
392
0
        let mut req = Request::from_http_parts(parts, message);
393
394
0
        if let Some(trailers) = stream.trailers().await? {
395
0
            req.metadata_mut().merge(trailers);
396
0
        }
397
398
0
        Ok(req)
399
0
    }
400
401
0
    fn map_request_streaming<B>(
402
0
        &mut self,
403
0
        request: http::Request<B>,
404
0
    ) -> Result<Request<Streaming<T::Decode>>, Status>
405
0
    where
406
0
        B: Body + Send + 'static,
407
0
        B::Error: Into<crate::Error> + Send,
408
0
    {
409
0
        let encoding = self.request_encoding_if_supported(&request)?;
410
411
0
        let request = request.map(|body| {
412
0
            Streaming::new_request(
413
0
                self.codec.decoder(),
414
0
                body,
415
0
                encoding,
416
0
                self.max_decoding_message_size,
417
0
            )
418
0
        });
419
0
420
0
        Ok(Request::from_http(request))
421
0
    }
422
423
0
    fn map_response<B>(
424
0
        &mut self,
425
0
        response: Result<crate::Response<B>, Status>,
426
0
        accept_encoding: Option<CompressionEncoding>,
427
0
        compression_override: SingleMessageCompressionOverride,
428
0
        max_message_size: Option<usize>,
429
0
    ) -> http::Response<BoxBody>
430
0
    where
431
0
        B: Stream<Item = Result<T::Encode, Status>> + Send + 'static,
432
0
    {
433
0
        let response = match response {
434
0
            Ok(r) => r,
435
0
            Err(status) => return status.to_http(),
436
        };
437
438
0
        let (mut parts, body) = response.into_http().into_parts();
439
0
440
0
        // Set the content type
441
0
        parts.headers.insert(
442
0
            http::header::CONTENT_TYPE,
443
0
            http::header::HeaderValue::from_static("application/grpc"),
444
0
        );
445
446
0
        if let Some(encoding) = accept_encoding {
447
0
            // Set the content encoding
448
0
            parts.headers.insert(
449
0
                crate::codec::compression::ENCODING_HEADER,
450
0
                encoding.into_header_value(),
451
0
            );
452
0
        }
453
454
0
        let body = encode_server(
455
0
            self.codec.encoder(),
456
0
            body,
457
0
            accept_encoding,
458
0
            compression_override,
459
0
            max_message_size,
460
0
        );
461
0
462
0
        http::Response::from_parts(parts, BoxBody::new(body))
463
0
    }
464
465
0
    fn request_encoding_if_supported<B>(
466
0
        &self,
467
0
        request: &http::Request<B>,
468
0
    ) -> Result<Option<CompressionEncoding>, Status> {
469
0
        CompressionEncoding::from_encoding_header(
470
0
            request.headers(),
471
0
            self.accept_compression_encodings,
472
0
        )
473
0
    }
474
}
475
476
impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
477
0
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478
0
        let mut f = f.debug_struct("Grpc");
479
0
480
0
        f.field("codec", &self.codec);
481
0
482
0
        f.field(
483
0
            "accept_compression_encodings",
484
0
            &self.accept_compression_encodings,
485
0
        );
486
0
487
0
        f.field(
488
0
            "send_compression_encodings",
489
0
            &self.send_compression_encodings,
490
0
        );
491
0
492
0
        f.finish()
493
0
    }
494
}
495
496
0
fn compression_override_from_response<B, E>(
497
0
    res: &Result<crate::Response<B>, E>,
498
0
) -> SingleMessageCompressionOverride {
499
0
    res.as_ref()
500
0
        .ok()
501
0
        .and_then(|response| {
502
0
            response
503
0
                .extensions()
504
0
                .get::<SingleMessageCompressionOverride>()
505
0
                .copied()
506
0
        })
507
0
        .unwrap_or_default()
508
0
}