1
#include "source/extensions/http/injected_credentials/oauth2/oauth_client.h"
2

            
3
#include <chrono>
4

            
5
#include "envoy/http/async_client.h"
6
#include "envoy/http/message.h"
7
#include "envoy/upstream/cluster_manager.h"
8

            
9
#include "source/common/common/empty_string.h"
10
#include "source/common/common/fmt.h"
11
#include "source/common/common/logger.h"
12
#include "source/common/http/message_impl.h"
13
#include "source/common/http/utility.h"
14
#include "source/common/protobuf/message_validator_impl.h"
15
#include "source/common/protobuf/utility.h"
16
#include "source/extensions/http/injected_credentials/oauth2/oauth_response.pb.h"
17

            
18
namespace Envoy {
19
namespace Extensions {
20
namespace Http {
21
namespace InjectedCredentials {
22
namespace OAuth2 {
23

            
24
namespace {

// x-www-form-urlencoded body templates for the OAuth2 client-credentials grant.
// Positional fields: {0} = client_id, {1} = client_secret, {2} = scope.
// asyncGetAccessToken() percent-encodes every value before substituting it here.
constexpr const char* GetAccessTokenBodyFormatString = "grant_type=client_credentials"
                                                       "&client_id={0}&client_secret={1}";

// Same template with a trailing `scope` field, used when a non-empty scope string is requested.
constexpr const char* GetAccessTokenBodyFormatStringWithScopes =
    "grant_type=client_credentials"
    "&client_id={0}&client_secret={1}"
    "&scope={2}";

} // namespace
31

            
32
/**
 * Builds an application/x-www-form-urlencoded client-credentials token request and dispatches it.
 *
 * @param client_id the OAuth2 client identifier.
 * @param secret the OAuth2 client secret.
 * @param scopes scope string; the `scope` field is omitted from the body when it is empty.
 * @param endpoint_params extra name/value pairs appended to the body in std::map key order.
 * @return NotDispatchedAlreadyInFlight if a previous request is still outstanding
 *         (in_flight_request_ is non-null); otherwise the result of dispatchRequest().
 */
OAuth2Client::GetTokenResult
OAuth2ClientImpl::asyncGetAccessToken(const std::string& client_id, const std::string& secret,
                                      const std::string& scopes,
                                      const std::map<std::string, std::string>& endpoint_params) {
  if (in_flight_request_ != nullptr) {
    return GetTokenResult::NotDispatchedAlreadyInFlight;
  }

  // In an x-www-form-urlencoded body every printable ASCII character other than
  // ALPHA / DIGIT / '-' / '.' / '_' must be escaped. Leaving e.g. '+' raw makes the server decode
  // it as a space, a raw '%' starts a bogus escape sequence, and a raw space (as in a
  // space-separated scope list) is not valid form encoding. PercentEncoding::encode() already
  // escapes control characters and bytes >= '~' on its own; this list covers the rest.
  static constexpr const char* kFormUrlEncodeReservedChars = " !\"#$%&'()*+,/:;<=>?@[\\]^`{|}";
  const auto form_encode = [](const std::string& value) {
    return Envoy::Http::Utility::PercentEncoding::encode(value, kFormUrlEncodeReservedChars);
  };

  const std::string encoded_client_id = form_encode(client_id);
  const std::string encoded_secret = form_encode(secret);

  Envoy::Http::RequestMessagePtr request = createPostRequest();
  std::string body;
  if (scopes.empty()) {
    body = fmt::format(GetAccessTokenBodyFormatString, encoded_client_id, encoded_secret);
  } else {
    body = fmt::format(GetAccessTokenBodyFormatStringWithScopes, encoded_client_id, encoded_secret,
                       form_encode(scopes));
  }

  // Extra endpoint parameters are encoded with the same rules as the standard fields.
  for (const auto& [param_name, param_value] : endpoint_params) {
    body += fmt::format("&{}={}", form_encode(param_name), form_encode(param_value));
  }

  request->body().add(body);
  request->headers().setContentLength(body.length());
  return dispatchRequest(std::move(request));
}
62

            
63
/**
 * Sends `msg` through the async HTTP client of the cluster named by uri_.
 *
 * The handle returned by send() is stored in in_flight_request_, which is what
 * asyncGetAccessToken() checks to reject concurrent requests. The request timeout comes from
 * uri_.timeout.
 *
 * @param msg the fully built token request.
 * @return NotDispatchedClusterNotFound when cm_ has no thread-local cluster for uri_.cluster();
 *         DispatchedRequest otherwise.
 */
OAuth2Client::GetTokenResult
OAuth2ClientImpl::dispatchRequest(Envoy::Http::RequestMessagePtr&& msg) {
  const auto cluster = cm_.getThreadLocalCluster(uri_.cluster());
  if (cluster == nullptr) {
    return GetTokenResult::NotDispatchedClusterNotFound;
  }

  const std::chrono::milliseconds request_timeout(PROTOBUF_GET_MS_REQUIRED(uri_, timeout));
  in_flight_request_ = cluster->httpAsyncClient().send(
      std::move(msg), *this, Envoy::Http::AsyncClient::RequestOptions().setTimeout(request_timeout));
  return GetTokenResult::DispatchedRequest;
}
76

            
77
/**
 * Async-client callback for a token request that received a response.
 *
 * - A status other than "200" is logged and reported as BadResponseCode.
 * - A body that fails JSON parsing into OAuthResponse is logged and reported as BadToken.
 * - A response missing access_token or expires_in is also reported as BadToken.
 * - Otherwise the token and its lifetime (in seconds) are passed to
 *   parent_->onGetAccessTokenSuccess().
 */
void OAuth2ClientImpl::onSuccess(const Envoy::Http::AsyncClient::Request&,
                                 Envoy::Http::ResponseMessagePtr&& message) {
  // The request has completed, so drop the handle before notifying anyone.
  in_flight_request_ = nullptr;

  const auto status = message->headers().Status()->value().getStringView();
  if (status != "200") {
    ENVOY_LOG(error, "Oauth response code: {}", status);
    ENVOY_LOG(error, "Oauth response body: {}", message->bodyAsString());
    parent_->onGetAccessTokenFailure(FilterCallbacks::FailureReason::BadResponseCode);
    return;
  }

  const std::string json_body = message->bodyAsString();

  envoy::extensions::http::injected_credentials::oauth2::OAuthResponse token_response;
  TRY_NEEDS_AUDIT {
    MessageUtil::loadFromJson(json_body, token_response,
                              ProtobufMessage::getNullValidationVisitor());
  }
  END_TRY catch (const EnvoyException& ex) {
    ENVOY_LOG(error, "Error parsing response body, received exception: {}", ex.what());
    ENVOY_LOG(error, "Response body: {}", json_body);
    parent_->onGetAccessTokenFailure(FilterCallbacks::FailureReason::BadToken);
    return;
  }

  // Both fields are required; either one missing makes the response unusable.
  const bool has_required_fields =
      token_response.has_access_token() && token_response.has_expires_in();
  if (!has_required_fields) {
    ENVOY_LOG(error, "No access token or expiration after asyncGetAccessToken");
    parent_->onGetAccessTokenFailure(FilterCallbacks::FailureReason::BadToken);
    return;
  }

  const std::string token{PROTOBUF_GET_WRAPPED_REQUIRED(token_response, access_token)};
  const std::chrono::seconds lifetime{PROTOBUF_GET_WRAPPED_REQUIRED(token_response, expires_in)};
  parent_->onGetAccessTokenSuccess(token, lifetime);
}
112

            
113
/**
 * Async-client callback for a token request that ended without a response.
 *
 * The FailureReason argument is not inspected. The failure is always logged and reported to
 * parent_ as StreamReset, after the in-flight handle is cleared.
 */
void OAuth2ClientImpl::onFailure(const Envoy::Http::AsyncClient::Request&,
                                 Envoy::Http::AsyncClient::FailureReason) {
  in_flight_request_ = nullptr;
  ENVOY_LOG(error, "OAuth request failed: stream reset");
  parent_->onGetAccessTokenFailure(FilterCallbacks::FailureReason::StreamReset);
}
119

            
120
} // namespace OAuth2
121
} // namespace InjectedCredentials
122
} // namespace Http
123
} // namespace Extensions
124
} // namespace Envoy