LCOV - code coverage report
Current view: top level - source/common/grpc - common.cc (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 136 232 58.6 %
Date: 2024-01-05 06:35:25 Functions: 21 24 87.5 %

          Line data    Source code
       1             : #include "source/common/grpc/common.h"
       2             : 
       3             : #include <atomic>
       4             : #include <cstdint>
       5             : #include <cstring>
       6             : #include <string>
       7             : 
       8             : #include "source/common/buffer/buffer_impl.h"
       9             : #include "source/common/buffer/zero_copy_input_stream_impl.h"
      10             : #include "source/common/common/assert.h"
      11             : #include "source/common/common/base64.h"
      12             : #include "source/common/common/empty_string.h"
      13             : #include "source/common/common/enum_to_int.h"
      14             : #include "source/common/common/fmt.h"
      15             : #include "source/common/common/macros.h"
      16             : #include "source/common/common/safe_memcpy.h"
      17             : #include "source/common/common/utility.h"
      18             : #include "source/common/grpc/codec.h"
      19             : #include "source/common/http/header_utility.h"
      20             : #include "source/common/http/headers.h"
      21             : #include "source/common/http/message_impl.h"
      22             : #include "source/common/http/utility.h"
      23             : #include "source/common/protobuf/protobuf.h"
      24             : 
      25             : #include "absl/container/fixed_array.h"
      26             : #include "absl/strings/match.h"
      27             : 
      28             : namespace Envoy {
      29             : namespace Grpc {
      30             : 
      31        2114 : bool Common::hasGrpcContentType(const Http::RequestOrResponseHeaderMap& headers) {
      32        2114 :   const absl::string_view content_type = headers.getContentTypeValue();
      33             :   // Content type is gRPC if it is exactly "application/grpc" or starts with
      34             :   // "application/grpc+". Specifically, something like application/grpc-web is not gRPC.
      35        2114 :   return absl::StartsWith(content_type, Http::Headers::get().ContentTypeValues.Grpc) &&
      36        2114 :          (content_type.size() == Http::Headers::get().ContentTypeValues.Grpc.size() ||
      37         292 :           content_type[Http::Headers::get().ContentTypeValues.Grpc.size()] == '+');
      38        2114 : }
      39             : 
      40           1 : bool Common::hasConnectProtocolVersionHeader(const Http::RequestOrResponseHeaderMap& headers) {
      41           1 :   return !headers.get(Http::CustomHeaders::get().ConnectProtocolVersion).empty();
      42           1 : }
      43             : 
      44           2 : bool Common::hasConnectStreamingContentType(const Http::RequestOrResponseHeaderMap& headers) {
      45             :   // Consider the request a connect request if the content type starts with "application/connect+".
      46           2 :   static constexpr absl::string_view connect_prefix{"application/connect+"};
      47           2 :   const absl::string_view content_type = headers.getContentTypeValue();
      48           2 :   return absl::StartsWith(content_type, connect_prefix);
      49           2 : }
      50             : 
      51           1 : bool Common::hasProtobufContentType(const Http::RequestOrResponseHeaderMap& headers) {
      52           1 :   return headers.getContentTypeValue() == Http::Headers::get().ContentTypeValues.Protobuf;
      53           1 : }
      54             : 
      55        2133 : bool Common::isGrpcRequestHeaders(const Http::RequestHeaderMap& headers) {
      56        2133 :   if (!headers.Path()) {
      57          32 :     return false;
      58          32 :   }
      59        2101 :   return hasGrpcContentType(headers);
      60        2133 : }
      61             : 
      62           1 : bool Common::isConnectRequestHeaders(const Http::RequestHeaderMap& headers) {
      63           1 :   if (!headers.Path()) {
      64           0 :     return false;
      65           0 :   }
      66           1 :   return hasConnectProtocolVersionHeader(headers);
      67           1 : }
      68             : 
      69           1 : bool Common::isConnectStreamingRequestHeaders(const Http::RequestHeaderMap& headers) {
      70           1 :   if (!headers.Path()) {
      71           0 :     return false;
      72           0 :   }
      73           1 :   return hasConnectStreamingContentType(headers);
      74           1 : }
      75             : 
      76           1 : bool Common::isProtobufRequestHeaders(const Http::RequestHeaderMap& headers) {
      77           1 :   if (!headers.Path()) {
      78           0 :     return false;
      79           0 :   }
      80           1 :   return hasProtobufContentType(headers);
      81           1 : }
      82             : 
      83          20 : bool Common::isGrpcResponseHeaders(const Http::ResponseHeaderMap& headers, bool end_stream) {
      84          20 :   if (end_stream) {
      85             :     // Trailers-only response, only grpc-status is required.
      86           7 :     return headers.GrpcStatus() != nullptr;
      87           7 :   }
      88          13 :   if (Http::Utility::getResponseStatus(headers) != enumToInt(Http::Code::OK)) {
      89           0 :     return false;
      90           0 :   }
      91          13 :   return hasGrpcContentType(headers);
      92          13 : }
      93             : 
      94           1 : bool Common::isConnectStreamingResponseHeaders(const Http::ResponseHeaderMap& headers) {
      95           1 :   if (Http::Utility::getResponseStatus(headers) != enumToInt(Http::Code::OK)) {
      96           0 :     return false;
      97           0 :   }
      98           1 :   return hasConnectStreamingContentType(headers);
      99           1 : }
     100             : 
     101             : absl::optional<Status::GrpcStatus>
     102         170 : Common::getGrpcStatus(const Http::ResponseHeaderOrTrailerMap& trailers, bool allow_user_defined) {
     103         170 :   const absl::string_view grpc_status_header = trailers.getGrpcStatusValue();
     104         170 :   uint64_t grpc_status_code;
     105             : 
     106         170 :   if (grpc_status_header.empty()) {
     107          98 :     return absl::nullopt;
     108          98 :   }
     109          72 :   if (!absl::SimpleAtoi(grpc_status_header, &grpc_status_code) ||
     110          72 :       (grpc_status_code > Status::WellKnownGrpcStatus::MaximumKnown && !allow_user_defined)) {
     111           0 :     return {Status::WellKnownGrpcStatus::InvalidCode};
     112           0 :   }
     113          72 :   return {static_cast<Status::GrpcStatus>(grpc_status_code)};
     114          72 : }
     115             : 
     116             : absl::optional<Status::GrpcStatus> Common::getGrpcStatus(const Http::ResponseTrailerMap& trailers,
     117             :                                                          const Http::ResponseHeaderMap& headers,
     118             :                                                          const StreamInfo::StreamInfo& info,
     119           0 :                                                          bool allow_user_defined) {
     120             :   // The gRPC specification does not guarantee a gRPC status code will be returned from a gRPC
     121             :   // request. When it is returned, it will be in the response trailers. With that said, Envoy will
     122             :   // treat a trailers-only response as a headers-only response, so we have to check the following
     123             :   // in order:
     124             :   //   1. trailers gRPC status, if it exists.
     125             :   //   2. headers gRPC status, if it exists.
     126             :   //   3. Inferred from info HTTP status, if it exists.
     127           0 :   absl::optional<Grpc::Status::GrpcStatus> optional_status;
     128           0 :   optional_status = Grpc::Common::getGrpcStatus(trailers, allow_user_defined);
     129           0 :   if (optional_status.has_value()) {
     130           0 :     return optional_status;
     131           0 :   }
     132           0 :   optional_status = Grpc::Common::getGrpcStatus(headers, allow_user_defined);
     133           0 :   if (optional_status.has_value()) {
     134           0 :     return optional_status;
     135           0 :   }
     136           0 :   return info.responseCode() ? absl::optional<Grpc::Status::GrpcStatus>(
     137           0 :                                    Grpc::Utility::httpToGrpcStatus(info.responseCode().value()))
     138           0 :                              : absl::nullopt;
     139           0 : }
     140             : 
     141          28 : std::string Common::getGrpcMessage(const Http::ResponseHeaderOrTrailerMap& trailers) {
     142          28 :   const auto entry = trailers.GrpcMessage();
     143          28 :   return entry ? std::string(entry->value().getStringView()) : EMPTY_STRING;
     144          28 : }
     145             : 
     146             : absl::optional<google::rpc::Status>
     147           0 : Common::getGrpcStatusDetailsBin(const Http::HeaderMap& trailers) {
     148           0 :   const auto details_header = trailers.get(Http::Headers::get().GrpcStatusDetailsBin);
     149           0 :   if (details_header.empty()) {
     150           0 :     return absl::nullopt;
     151           0 :   }
     152             : 
     153             :   // Some implementations use non-padded base64 encoding for grpc-status-details-bin.
     154             :   // This is effectively a trusted header so using the first value is fine.
     155           0 :   auto decoded_value = Base64::decodeWithoutPadding(details_header[0]->value().getStringView());
     156           0 :   if (decoded_value.empty()) {
     157           0 :     return absl::nullopt;
     158           0 :   }
     159             : 
     160           0 :   google::rpc::Status status;
     161           0 :   if (!status.ParseFromString(decoded_value)) {
     162           0 :     return absl::nullopt;
     163           0 :   }
     164             : 
     165           0 :   return {std::move(status)};
     166           0 : }
     167             : 
     168         253 : Buffer::InstancePtr Common::serializeToGrpcFrame(const Protobuf::Message& message) {
     169             :   // http://www.grpc.io/docs/guides/wire.html
     170             :   // Reserve enough space for the entire message and the 5 byte header.
     171             :   // NB: we do not use prependGrpcFrameHeader because that would add another BufferFragment and this
     172             :   // (using a single BufferFragment) is more efficient.
     173         253 :   Buffer::InstancePtr body(new Buffer::OwnedImpl());
     174         253 :   const uint32_t size = message.ByteSize();
     175         253 :   const uint32_t alloc_size = size + 5;
     176         253 :   auto reservation = body->reserveSingleSlice(alloc_size);
     177         253 :   ASSERT(reservation.slice().len_ >= alloc_size);
     178         253 :   uint8_t* current = reinterpret_cast<uint8_t*>(reservation.slice().mem_);
     179         253 :   *current++ = 0; // flags
     180         253 :   const uint32_t nsize = htonl(size);
     181         253 :   safeMemcpyUnsafeDst(current, &nsize);
     182         253 :   current += sizeof(uint32_t);
     183         253 :   Protobuf::io::ArrayOutputStream stream(current, size, -1);
     184         253 :   Protobuf::io::CodedOutputStream codec_stream(&stream);
     185         253 :   message.SerializeWithCachedSizes(&codec_stream);
     186         253 :   reservation.commit(alloc_size);
     187         253 :   return body;
     188         253 : }
     189             : 
     190         417 : Buffer::InstancePtr Common::serializeMessage(const Protobuf::Message& message) {
     191         417 :   auto body = std::make_unique<Buffer::OwnedImpl>();
     192         417 :   const uint32_t size = message.ByteSize();
     193         417 :   auto reservation = body->reserveSingleSlice(size);
     194         417 :   ASSERT(reservation.slice().len_ >= size);
     195         417 :   uint8_t* current = reinterpret_cast<uint8_t*>(reservation.slice().mem_);
     196         417 :   Protobuf::io::ArrayOutputStream stream(current, size, -1);
     197         417 :   Protobuf::io::CodedOutputStream codec_stream(&stream);
     198         417 :   message.SerializeWithCachedSizes(&codec_stream);
     199         417 :   reservation.commit(size);
     200         417 :   return body;
     201         417 : }
     202             : 
     203             : absl::optional<std::chrono::milliseconds>
     204         241 : Common::getGrpcTimeout(const Http::RequestHeaderMap& request_headers) {
     205         241 :   const Http::HeaderEntry* header_grpc_timeout_entry = request_headers.GrpcTimeout();
     206         241 :   std::chrono::milliseconds timeout;
     207         241 :   if (header_grpc_timeout_entry) {
     208           0 :     int64_t grpc_timeout;
     209           0 :     absl::string_view timeout_entry = header_grpc_timeout_entry->value().getStringView();
     210           0 :     if (timeout_entry.empty()) {
     211             :       // Must be of the form TimeoutValue TimeoutUnit. See
     212             :       // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md#requests.
     213           0 :       return absl::nullopt;
     214           0 :     }
     215             :     // TimeoutValue must be a positive integer of at most 8 digits.
     216           0 :     if (absl::SimpleAtoi(timeout_entry.substr(0, timeout_entry.size() - 1), &grpc_timeout) &&
     217           0 :         grpc_timeout >= 0 && static_cast<uint64_t>(grpc_timeout) <= MAX_GRPC_TIMEOUT_VALUE) {
     218           0 :       const char unit = timeout_entry[timeout_entry.size() - 1];
     219           0 :       switch (unit) {
     220           0 :       case 'H':
     221           0 :         return std::chrono::hours(grpc_timeout);
     222           0 :       case 'M':
     223           0 :         return std::chrono::minutes(grpc_timeout);
     224           0 :       case 'S':
     225           0 :         return std::chrono::seconds(grpc_timeout);
     226           0 :       case 'm':
     227           0 :         return std::chrono::milliseconds(grpc_timeout);
     228           0 :         break;
     229           0 :       case 'u':
     230           0 :         timeout = std::chrono::duration_cast<std::chrono::milliseconds>(
     231           0 :             std::chrono::microseconds(grpc_timeout));
     232           0 :         if (timeout < std::chrono::microseconds(grpc_timeout)) {
     233           0 :           timeout++;
     234           0 :         }
     235           0 :         return timeout;
     236           0 :       case 'n':
     237           0 :         timeout = std::chrono::duration_cast<std::chrono::milliseconds>(
     238           0 :             std::chrono::nanoseconds(grpc_timeout));
     239           0 :         if (timeout < std::chrono::nanoseconds(grpc_timeout)) {
     240           0 :           timeout++;
     241           0 :         }
     242           0 :         return timeout;
     243           0 :       }
     244           0 :     }
     245           0 :   }
     246         241 :   return absl::nullopt;
     247         241 : }
     248             : 
     249             : void Common::toGrpcTimeout(const std::chrono::milliseconds& timeout,
     250          26 :                            Http::RequestHeaderMap& headers) {
     251          26 :   uint64_t time = timeout.count();
     252          26 :   static const char units[] = "mSMH";
     253          26 :   const char* unit = units; // start with milliseconds
     254          26 :   if (time > MAX_GRPC_TIMEOUT_VALUE) {
     255           0 :     time /= 1000; // Convert from milliseconds to seconds
     256           0 :     unit++;
     257           0 :   }
     258          26 :   while (time > MAX_GRPC_TIMEOUT_VALUE) {
     259           0 :     if (*unit == 'H') {
     260           0 :       time = MAX_GRPC_TIMEOUT_VALUE; // No bigger unit available, clip to max 8 digit hours.
     261           0 :     } else {
     262           0 :       time /= 60; // Convert from seconds to minutes to hours
     263           0 :       unit++;
     264           0 :     }
     265           0 :   }
     266          26 :   headers.setGrpcTimeout(absl::StrCat(time, absl::string_view(unit, 1)));
     267          26 : }
     268             : 
     269             : Http::RequestMessagePtr
     270             : Common::prepareHeaders(const std::string& host_name, const std::string& service_full_name,
     271             :                        const std::string& method_name,
     272          94 :                        const absl::optional<std::chrono::milliseconds>& timeout) {
     273          94 :   Http::RequestMessagePtr message(new Http::RequestMessageImpl());
     274          94 :   message->headers().setReferenceMethod(Http::Headers::get().MethodValues.Post);
     275          94 :   message->headers().setPath(absl::StrCat("/", service_full_name, "/", method_name));
     276          94 :   message->headers().setHost(host_name);
     277             :   // According to https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md TE should appear
     278             :   // before Timeout and ContentType.
     279          94 :   message->headers().setReferenceTE(Http::Headers::get().TEValues.Trailers);
     280          94 :   if (timeout) {
     281           0 :     toGrpcTimeout(timeout.value(), message->headers());
     282           0 :   }
     283          94 :   message->headers().setReferenceContentType(Http::Headers::get().ContentTypeValues.Grpc);
     284             : 
     285          94 :   return message;
     286          94 : }
     287             : 
     288         673 : const std::string& Common::typeUrlPrefix() {
     289         673 :   CONSTRUCT_ON_FIRST_USE(std::string, "type.googleapis.com");
     290         673 : }
     291             : 
     292         665 : std::string Common::typeUrl(const std::string& qualified_name) {
     293         665 :   return typeUrlPrefix() + "/" + qualified_name;
     294         665 : }
     295             : 
     296         416 : void Common::prependGrpcFrameHeader(Buffer::Instance& buffer) {
     297         416 :   std::array<char, 5> header;
     298         416 :   header[0] = GRPC_FH_DEFAULT; // flags
     299         416 :   const uint32_t nsize = htonl(buffer.length());
     300         416 :   safeMemcpyUnsafeDst(&header[1], &nsize);
     301         416 :   buffer.prepend(absl::string_view(&header[0], 5));
     302         416 : }
     303             : 
     304           0 : bool Common::parseBufferInstance(Buffer::InstancePtr&& buffer, Protobuf::Message& proto) {
     305           0 :   Buffer::ZeroCopyInputStreamImpl stream(std::move(buffer));
     306           0 :   return proto.ParseFromZeroCopyStream(&stream);
     307           0 : }
     308             : 
     309             : absl::optional<Common::RequestNames>
     310           1 : Common::resolveServiceAndMethod(const Http::HeaderEntry* path) {
     311           1 :   absl::optional<RequestNames> request_names;
     312           1 :   if (path == nullptr) {
     313           0 :     return request_names;
     314           0 :   }
     315           1 :   absl::string_view str = path->value().getStringView();
     316           1 :   str = str.substr(0, str.find('?'));
     317           1 :   const auto parts = StringUtil::splitToken(str, "/");
     318           1 :   if (parts.size() != 2) {
     319           0 :     return request_names;
     320           0 :   }
     321           1 :   request_names = RequestNames{parts[0], parts[1]};
     322           1 :   return request_names;
     323           1 : }
     324             : 
     325             : } // namespace Grpc
     326             : } // namespace Envoy

Generated by: LCOV version 1.15