LCOV - code coverage report
Current view: top level - source/extensions/common/aws - utility.cc (source / functions) Hit Total Coverage
Test: coverage.dat Lines: 95 258 36.8 %
Date: 2024-01-05 06:35:25 Functions: 3 17 17.6 %

          Line data    Source code
       1             : #include "source/extensions/common/aws/utility.h"
       2             : 
       3             : #include "envoy/upstream/cluster_manager.h"
       4             : 
       5             : #include "source/common/common/empty_string.h"
       6             : #include "source/common/common/fmt.h"
       7             : #include "source/common/common/utility.h"
       8             : #include "source/common/protobuf/message_validator_impl.h"
       9             : #include "source/common/protobuf/utility.h"
      10             : 
      11             : #include "absl/strings/match.h"
      12             : #include "absl/strings/str_join.h"
      13             : #include "absl/strings/str_split.h"
      14             : #include "curl/curl.h"
      15             : #include "fmt/printf.h"
      16             : 
      17             : namespace Envoy {
      18             : namespace Extensions {
      19             : namespace Common {
      20             : namespace Aws {
      21             : 
      22             : constexpr absl::string_view PATH_SPLITTER = "/";
      23             : constexpr absl::string_view QUERY_PARAM_SEPERATOR = "=";
      24             : constexpr absl::string_view QUERY_SEPERATOR = "&";
      25             : constexpr absl::string_view QUERY_SPLITTER = "?";
      26             : constexpr absl::string_view RESERVED_CHARS = "-._~";
      27             : constexpr absl::string_view S3_SERVICE_NAME = "s3";
      28             : constexpr absl::string_view URI_ENCODE = "%{:02X}";
      29             : constexpr absl::string_view URI_DOUBLE_ENCODE = "%25{:02X}";
      30             : 
      31             : std::map<std::string, std::string>
      32             : Utility::canonicalizeHeaders(const Http::RequestHeaderMap& headers,
      33           0 :                              const std::vector<Matchers::StringMatcherPtr>& excluded_headers) {
      34           0 :   std::map<std::string, std::string> out;
      35           0 :   headers.iterate(
      36           0 :       [&out, &excluded_headers](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate {
      37             :         // Skip empty headers
      38           0 :         if (entry.key().empty() || entry.value().empty()) {
      39           0 :           return Http::HeaderMap::Iterate::Continue;
      40           0 :         }
      41             :         // Pseudo-headers should not be canonicalized
      42           0 :         if (!entry.key().getStringView().empty() && entry.key().getStringView()[0] == ':') {
      43           0 :           return Http::HeaderMap::Iterate::Continue;
      44           0 :         }
      45           0 :         const auto key = entry.key().getStringView();
      46           0 :         if (std::any_of(excluded_headers.begin(), excluded_headers.end(),
      47           0 :                         [&key](const Matchers::StringMatcherPtr& matcher) {
      48           0 :                           return matcher->match(key);
      49           0 :                         })) {
      50           0 :           return Http::HeaderMap::Iterate::Continue;
      51           0 :         }
      52             : 
      53           0 :         std::string value(entry.value().getStringView());
      54             :         // Remove leading, trailing, and deduplicate repeated ascii spaces
      55           0 :         absl::RemoveExtraAsciiWhitespace(&value);
      56           0 :         const auto iter = out.find(std::string(entry.key().getStringView()));
      57             :         // If the entry already exists, append the new value to the end
      58           0 :         if (iter != out.end()) {
      59           0 :           iter->second += fmt::format(",{}", value);
      60           0 :         } else {
      61           0 :           out.emplace(std::string(entry.key().getStringView()), value);
      62           0 :         }
      63           0 :         return Http::HeaderMap::Iterate::Continue;
      64           0 :       });
      65             :   // The AWS SDK has a quirk where it removes "default ports" (80, 443) from the host headers
      66             :   // Additionally, we canonicalize the :authority header as "host"
      67             :   // TODO(suniltheta): This may need to be tweaked to canonicalize :authority for HTTP/2 requests
      68           0 :   const absl::string_view authority_header = headers.getHostValue();
      69           0 :   if (!authority_header.empty()) {
      70           0 :     const auto parts = StringUtil::splitToken(authority_header, ":");
      71           0 :     if (parts.size() > 1 && (parts[1] == "80" || parts[1] == "443")) {
      72             :       // Has default port, so use only the host part
      73           0 :       out.emplace(Http::Headers::get().HostLegacy.get(), std::string(parts[0]));
      74           0 :     } else {
      75           0 :       out.emplace(Http::Headers::get().HostLegacy.get(), std::string(authority_header));
      76           0 :     }
      77           0 :   }
      78           0 :   return out;
      79           0 : }
      80             : 
      81             : std::string Utility::createCanonicalRequest(
      82             :     absl::string_view service_name, absl::string_view method, absl::string_view path,
      83           0 :     const std::map<std::string, std::string>& canonical_headers, absl::string_view content_hash) {
      84           0 :   std::vector<absl::string_view> parts;
      85           0 :   parts.emplace_back(method);
      86             :   // don't include the query part of the path
      87           0 :   const auto path_part = StringUtil::cropRight(path, QUERY_SPLITTER);
      88           0 :   const auto canonicalized_path = path_part.empty()
      89           0 :                                       ? std::string{PATH_SPLITTER}
      90           0 :                                       : canonicalizePathString(path_part, service_name);
      91           0 :   parts.emplace_back(canonicalized_path);
      92           0 :   const auto query_part = StringUtil::cropLeft(path, QUERY_SPLITTER);
      93             :   // if query_part == path_part, then there is no query
      94           0 :   const auto canonicalized_query =
      95           0 :       query_part == path_part ? EMPTY_STRING : Utility::canonicalizeQueryString(query_part);
      96           0 :   parts.emplace_back(absl::string_view(canonicalized_query));
      97           0 :   std::vector<std::string> formatted_headers;
      98           0 :   formatted_headers.reserve(canonical_headers.size());
      99           0 :   for (const auto& header : canonical_headers) {
     100           0 :     formatted_headers.emplace_back(fmt::format("{}:{}", header.first, header.second));
     101           0 :     parts.emplace_back(formatted_headers.back());
     102           0 :   }
     103             :   // need an extra blank space after the canonical headers
     104           0 :   parts.emplace_back(EMPTY_STRING);
     105           0 :   const auto signed_headers = Utility::joinCanonicalHeaderNames(canonical_headers);
     106           0 :   parts.emplace_back(signed_headers);
     107           0 :   parts.emplace_back(content_hash);
     108           0 :   return absl::StrJoin(parts, "\n");
     109           0 : }
     110             : 
     111             : /**
     112             :  * Normalizes the path string based on AWS requirements.
     113             :  * See step 2 in https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
     114             :  */
     115             : std::string Utility::canonicalizePathString(absl::string_view path_string,
     116           0 :                                             absl::string_view service_name) {
     117             :   // If service is S3, do not normalize but only encode the path
     118           0 :   if (absl::EqualsIgnoreCase(service_name, S3_SERVICE_NAME)) {
     119           0 :     return encodePathSegment(path_string, service_name);
     120           0 :   }
     121             :   // If service is not S3, normalize and encode the path
     122           0 :   const auto path_segments = StringUtil::splitToken(path_string, std::string{PATH_SPLITTER});
     123           0 :   std::vector<std::string> path_list;
     124           0 :   path_list.reserve(path_segments.size());
     125           0 :   for (const auto& path_segment : path_segments) {
     126           0 :     if (path_segment.empty()) {
     127           0 :       continue;
     128           0 :     }
     129           0 :     path_list.emplace_back(encodePathSegment(path_segment, service_name));
     130           0 :   }
     131           0 :   auto canonical_path_string =
     132           0 :       fmt::format("{}{}", PATH_SPLITTER, absl::StrJoin(path_list, PATH_SPLITTER));
     133             :   // Handle corner case when path ends with '/'
     134           0 :   if (absl::EndsWith(path_string, PATH_SPLITTER) && canonical_path_string.size() > 1) {
     135           0 :     canonical_path_string.push_back(PATH_SPLITTER[0]);
     136           0 :   }
     137           0 :   return canonical_path_string;
     138           0 : }
     139             : 
     140           0 : bool isReservedChar(const char c) {
     141           0 :   return std::isalnum(c) || RESERVED_CHARS.find(c) != std::string::npos;
     142           0 : }
     143             : 
     144           0 : void encodeS3Path(std::string& encoded, const char& c) {
     145             :   // Do not encode '/' for S3
     146           0 :   if (c == PATH_SPLITTER[0]) {
     147           0 :     encoded.push_back(c);
     148           0 :   } else {
     149           0 :     absl::StrAppend(&encoded, fmt::format(URI_ENCODE, c));
     150           0 :   }
     151           0 : }
     152             : 
     153           0 : std::string Utility::encodePathSegment(absl::string_view decoded, absl::string_view service_name) {
     154           0 :   std::string encoded;
     155           0 :   for (char c : decoded) {
     156           0 :     if (isReservedChar(c)) {
     157             :       // Escape unreserved chars from RFC 3986
     158           0 :       encoded.push_back(c);
     159           0 :     } else if (absl::EqualsIgnoreCase(service_name, S3_SERVICE_NAME)) {
     160           0 :       encodeS3Path(encoded, c);
     161           0 :     } else {
     162             :       // TODO: @aws, There is some inconsistency between AWS services if this should be double
     163             :       // encoded or not. We need to parameterize this and expose this in the config. Ref:
     164             :       // https://github.com/aws/aws-sdk-cpp/blob/main/aws-cpp-sdk-core/source/auth/AWSAuthSigner.cpp#L79-L93
     165           0 :       absl::StrAppend(&encoded, fmt::format(URI_ENCODE, c));
     166           0 :     }
     167           0 :   }
     168           0 :   return encoded;
     169           0 : }
     170             : 
     171             : /**
     172             :  * Normalizes the query string based on AWS requirements.
     173             :  * See step 3 in https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
     174             :  */
     175           0 : std::string Utility::canonicalizeQueryString(absl::string_view query_string) {
     176             :   // Sort query string based on param name and append "=" if value is missing
     177           0 :   const auto query_fragments = StringUtil::splitToken(query_string, QUERY_SEPERATOR);
     178           0 :   std::vector<std::pair<std::string, std::string>> query_list;
     179           0 :   for (const auto& query_fragment : query_fragments) {
     180             :     // Only split at the first "=" and encode the rest
     181           0 :     const std::vector<std::string> query =
     182           0 :         absl::StrSplit(query_fragment, absl::MaxSplits(QUERY_PARAM_SEPERATOR, 1));
     183           0 :     if (!query.empty()) {
     184           0 :       const absl::string_view param = query[0];
     185           0 :       const absl::string_view value = query.size() > 1 ? query[1] : EMPTY_STRING;
     186           0 :       query_list.emplace_back(std::make_pair(param, value));
     187           0 :     }
     188           0 :   }
     189             :   // Sort query params by name and value
     190           0 :   std::sort(query_list.begin(), query_list.end());
     191             :   // Encode query params name and value separately
     192           0 :   for (auto& query : query_list) {
     193           0 :     query = std::make_pair(Utility::encodeQueryParam(query.first),
     194           0 :                            Utility::encodeQueryParam(query.second));
     195           0 :   }
     196           0 :   return absl::StrJoin(query_list, QUERY_SEPERATOR, absl::PairFormatter(QUERY_PARAM_SEPERATOR));
     197           0 : }
     198             : 
     199           0 : std::string Utility::encodeQueryParam(absl::string_view decoded) {
     200           0 :   std::string encoded;
     201           0 :   for (char c : decoded) {
     202           0 :     if (isReservedChar(c) || c == '%') {
     203             :       // Escape unreserved chars from RFC 3986
     204           0 :       encoded.push_back(c);
     205           0 :     } else if (c == '+') {
     206             :       // Encode '+' as space
     207           0 :       absl::StrAppend(&encoded, "%20");
     208           0 :     } else if (c == QUERY_PARAM_SEPERATOR[0]) {
     209             :       // Double encode '='
     210           0 :       absl::StrAppend(&encoded, fmt::format(URI_DOUBLE_ENCODE, c));
     211           0 :     } else {
     212           0 :       absl::StrAppend(&encoded, fmt::format(URI_ENCODE, c));
     213           0 :     }
     214           0 :   }
     215           0 :   return encoded;
     216           0 : }
     217             : 
     218             : std::string
     219           0 : Utility::joinCanonicalHeaderNames(const std::map<std::string, std::string>& canonical_headers) {
     220           0 :   return absl::StrJoin(canonical_headers, ";", [](auto* out, const auto& pair) {
     221           0 :     return absl::StrAppend(out, pair.first);
     222           0 :   });
     223           0 : }
     224             : 
     225           0 : std::string Utility::getSTSEndpoint(absl::string_view region) {
     226           0 :   if (region == "cn-northwest-1" || region == "cn-north-1") {
     227           0 :     return fmt::format("sts.{}.amazonaws.com.cn", region);
     228           0 :   }
     229             : #ifdef ENVOY_SSL_FIPS
     230             :   // Use AWS STS FIPS endpoints in FIPS mode https://docs.aws.amazon.com/general/latest/gr/sts.html.
     231             :   // Note: AWS GovCloud doesn't have separate fips endpoints.
     232             :   // TODO(suniltheta): Include `ca-central-1` when sts supports a dedicated FIPS endpoint.
     233             :   if (region == "us-east-1" || region == "us-east-2" || region == "us-west-1" ||
     234             :       region == "us-west-2") {
     235             :     return fmt::format("sts-fips.{}.amazonaws.com", region);
     236             :   }
     237             : #endif
     238           0 :   return fmt::format("sts.{}.amazonaws.com", region);
     239           0 : }
     240             : 
     241           0 : static size_t curlCallback(char* ptr, size_t, size_t nmemb, void* data) {
     242           0 :   auto buf = static_cast<std::string*>(data);
     243           0 :   buf->append(ptr, nmemb);
     244           0 :   return nmemb;
     245           0 : }
     246             : 
     247          10 : absl::optional<std::string> Utility::fetchMetadata(Http::RequestMessage& message) {
     248          10 :   static const size_t MAX_RETRIES = 4;
     249          10 :   static const std::chrono::milliseconds RETRY_DELAY{1000};
     250          10 :   static const std::chrono::seconds TIMEOUT{5};
     251             : 
     252          10 :   CURL* const curl = curl_easy_init();
     253          10 :   if (!curl) {
     254           0 :     return absl::nullopt;
     255          10 :   };
     256             : 
     257          10 :   const auto host = message.headers().getHostValue();
     258          10 :   const auto path = message.headers().getPathValue();
     259          10 :   const auto method = message.headers().getMethodValue();
     260          10 :   const auto scheme = message.headers().getSchemeValue();
     261             : 
     262          10 :   const std::string url = fmt::format("{}://{}{}", scheme, host, path);
     263          10 :   curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
     264          10 :   curl_easy_setopt(curl, CURLOPT_TIMEOUT, TIMEOUT.count());
     265          10 :   curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L);
     266          10 :   curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L);
     267             : 
     268          10 :   std::string buffer;
     269          10 :   curl_easy_setopt(curl, CURLOPT_WRITEDATA, &buffer);
     270          10 :   curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, curlCallback);
     271             : 
     272          10 :   struct curl_slist* headers = nullptr;
     273          45 :   message.headers().iterate([&headers](const Http::HeaderEntry& entry) -> Http::HeaderMap::Iterate {
     274             :     // Skip pseudo-headers
     275          45 :     if (!entry.key().getStringView().empty() && entry.key().getStringView()[0] == ':') {
     276          40 :       return Http::HeaderMap::Iterate::Continue;
     277          40 :     }
     278           5 :     const std::string header =
     279           5 :         fmt::format("{}: {}", entry.key().getStringView(), entry.value().getStringView());
     280           5 :     headers = curl_slist_append(headers, header.c_str());
     281           5 :     return Http::HeaderMap::Iterate::Continue;
     282          45 :   });
     283             : 
     284             :   // This function only support doing PUT(UPLOAD) other than GET(_default_) operation.
     285          10 :   if (Http::Headers::get().MethodValues.Put == method) {
     286             :     // https://curl.se/libcurl/c/CURLOPT_PUT.html is deprecated
     287             :     // so using https://curl.se/libcurl/c/CURLOPT_UPLOAD.html.
     288           5 :     curl_easy_setopt(curl, CURLOPT_UPLOAD, 1L);
     289             :     // To call PUT on HTTP 1.0 we must specify a value for the upload size
     290             :     // since some old EC2's metadata service will be serving on HTTP 1.0.
     291             :     // https://curl.se/libcurl/c/CURLOPT_INFILESIZE.html
     292           5 :     curl_easy_setopt(curl, CURLOPT_INFILESIZE, 0);
     293             :     // Disabling `Expect: 100-continue` header to get a response
     294             :     // in the first attempt as the put size is zero.
     295             :     // https://everything.curl.dev/http/post/expect100
     296           5 :     headers = curl_slist_append(headers, "Expect:");
     297           5 :   }
     298             : 
     299          10 :   if (headers != nullptr) {
     300           5 :     curl_easy_setopt(curl, CURLOPT_HTTPHEADER, headers);
     301           5 :   }
     302             : 
     303          50 :   for (size_t retry = 0; retry < MAX_RETRIES; retry++) {
     304          40 :     const CURLcode res = curl_easy_perform(curl);
     305          40 :     if (res == CURLE_OK) {
     306           0 :       break;
     307           0 :     }
     308          40 :     ENVOY_LOG_MISC(debug, "Could not fetch AWS metadata: {}", curl_easy_strerror(res));
     309          40 :     buffer.clear();
     310          40 :     std::this_thread::sleep_for(RETRY_DELAY);
     311          40 :   }
     312             : 
     313          10 :   curl_easy_cleanup(curl);
     314          10 :   curl_slist_free_all(headers);
     315             : 
     316          10 :   return buffer.empty() ? absl::nullopt : absl::optional<std::string>(buffer);
     317          10 : }
     318             : 
     319             : bool Utility::addInternalClusterStatic(
     320             :     Upstream::ClusterManager& cm, absl::string_view cluster_name,
     321           6 :     const envoy::config::cluster::v3::Cluster::DiscoveryType cluster_type, absl::string_view uri) {
     322             :   // Check if local cluster exists with that name.
     323           6 :   if (cm.getThreadLocalCluster(cluster_name) == nullptr) {
     324             :     // Make sure we run this on main thread.
     325           6 :     TRY_ASSERT_MAIN_THREAD {
     326           6 :       envoy::config::cluster::v3::Cluster cluster;
     327           6 :       absl::string_view host_port;
     328           6 :       absl::string_view path;
     329           6 :       Http::Utility::extractHostPathFromUri(uri, host_port, path);
     330           6 :       const auto host_attributes = Http::Utility::parseAuthority(host_port);
     331           6 :       const auto host = host_attributes.host_;
     332           6 :       const auto port = host_attributes.port_ ? host_attributes.port_.value() : 80;
     333             : 
     334           6 :       cluster.set_name(cluster_name);
     335           6 :       cluster.set_type(cluster_type);
     336           6 :       cluster.mutable_connect_timeout()->set_seconds(5);
     337           6 :       cluster.mutable_load_assignment()->set_cluster_name(cluster_name);
     338           6 :       auto* endpoint = cluster.mutable_load_assignment()
     339           6 :                            ->add_endpoints()
     340           6 :                            ->add_lb_endpoints()
     341           6 :                            ->mutable_endpoint();
     342           6 :       auto* addr = endpoint->mutable_address();
     343           6 :       addr->mutable_socket_address()->set_address(host);
     344           6 :       addr->mutable_socket_address()->set_port_value(port);
     345           6 :       cluster.set_lb_policy(envoy::config::cluster::v3::Cluster::ROUND_ROBIN);
     346           6 :       envoy::extensions::upstreams::http::v3::HttpProtocolOptions protocol_options;
     347           6 :       auto* http_protocol_options =
     348           6 :           protocol_options.mutable_explicit_http_config()->mutable_http_protocol_options();
     349           6 :       http_protocol_options->set_accept_http_10(true);
     350           6 :       (*cluster.mutable_typed_extension_protocol_options())
     351           6 :           ["envoy.extensions.upstreams.http.v3.HttpProtocolOptions"]
     352           6 :               .PackFrom(protocol_options);
     353             : 
     354             :       // Add tls transport socket if cluster supports https over port 443.
     355           6 :       if (port == 443) {
     356           0 :         auto* socket = cluster.mutable_transport_socket();
     357           0 :         envoy::extensions::transport_sockets::tls::v3::UpstreamTlsContext tls_socket;
     358           0 :         socket->set_name("envoy.transport_sockets.tls");
     359           0 :         socket->mutable_typed_config()->PackFrom(tls_socket);
     360           0 :       }
     361             : 
     362             :       // TODO(suniltheta): use random number generator here for cluster version.
     363             :       // While adding multiple clusters make sure that change in random version number across
     364             :       // multiple clusters won't make Envoy delete/replace previously registered internal cluster.
     365           6 :       cm.addOrUpdateCluster(cluster, "12345");
     366             : 
     367           6 :       const auto cluster_type_str = envoy::config::cluster::v3::Cluster::DiscoveryType_descriptor()
     368           6 :                                         ->FindValueByNumber(cluster_type)
     369           6 :                                         ->name();
     370           6 :       ENVOY_LOG_MISC(info,
     371           6 :                      "Added a {} internal cluster [name: {}, address:{}] to fetch aws "
     372           6 :                      "credentials",
     373           6 :                      cluster_type_str, cluster_name, host_port);
     374           6 :     }
     375           6 :     END_TRY
     376           6 :     CATCH(const EnvoyException& e, {
     377           6 :       ENVOY_LOG_MISC(error, "Failed to add internal cluster {}: {}", cluster_name, e.what());
     378           6 :       return false;
     379           6 :     });
     380           6 :   }
     381           6 :   return true;
     382           6 : }
     383             : 
     384             : } // namespace Aws
     385             : } // namespace Common
     386             : } // namespace Extensions
     387             : } // namespace Envoy

Generated by: LCOV version 1.15