1
#include "source/extensions/http/injected_credentials/oauth2/token_provider.h"

#include <algorithm>
#include <chrono>
#include <memory>
#include <utility>

            
5
namespace Envoy {
6
namespace Extensions {
7
namespace Http {
8
namespace InjectedCredentials {
9
namespace OAuth2 {
10

            
11
namespace {
12

            
13
constexpr absl::string_view DEFAULT_AUTH_SCOPE = "";
14
// Transforms the proto list of 'auth_scopes' into a vector of std::string, also
15
// handling the default value logic.
16
25
std::string oauthScopesList(const Protobuf::RepeatedPtrField<std::string>& auth_scopes_protos) {
17
25
  std::vector<std::string> scopes;
18

            
19
  // If 'auth_scopes' is empty it must return a list with the default value.
20
25
  if (auth_scopes_protos.empty()) {
21
21
    scopes.emplace_back(DEFAULT_AUTH_SCOPE);
22
21
  } else {
23
4
    scopes.reserve(auth_scopes_protos.size());
24

            
25
4
    for (const auto& scope : auth_scopes_protos) {
26
4
      scopes.emplace_back(scope);
27
4
    }
28
4
  }
29
25
  return absl::StrJoin(scopes, " ");
30
25
}
31

            
32
// Transforms the proto list of 'endpoint_params' into a map of string key-value pairs.
33
std::map<std::string, std::string> endpointParamsMap(
34
    const Protobuf::RepeatedPtrField<
35
        envoy::extensions::http::injected_credentials::oauth2::v3::OAuth2::EndpointParameter>&
36
25
        endpoint_params_protos) {
37
25
  std::map<std::string, std::string> params;
38
25
  for (const auto& param : endpoint_params_protos) {
39
4
    params[param.name()] = param.value();
40
4
  }
41
25
  return params;
42
25
}
43
} // namespace
44

            
45
// TokenProvider Constructor
46
TokenProvider::TokenProvider(Common::SecretReaderConstSharedPtr secret_reader,
47
                             ThreadLocal::SlotAllocator& tls, Upstream::ClusterManager& cm,
48
                             const OAuth2& proto_config, Event::Dispatcher& dispatcher,
49
                             const std::string& stats_prefix, Stats::Scope& scope)
50
25
    : secret_reader_(secret_reader), tls_(tls.allocateSlot()),
51
25
      client_id_(proto_config.client_credentials().client_id()),
52
25
      oauth_scopes_(oauthScopesList(proto_config.scopes())),
53
25
      endpoint_params_(endpointParamsMap(proto_config.endpoint_params())), dispatcher_(&dispatcher),
54
25
      stats_(generateStats(stats_prefix + "oauth2.", scope)),
55
      retry_interval_(
56
25
          proto_config.token_fetch_retry_interval().seconds() > 0
57
25
              ? std::chrono::seconds(proto_config.token_fetch_retry_interval().seconds())
58
25
              : std::chrono::seconds(2)) {
59
33
  timer_ = dispatcher_->createTimer([this]() -> void { asyncGetAccessToken(); });
60
25
  ThreadLocalOauth2ClientCredentialsTokenSharedPtr empty(
61
25
      new ThreadLocalOauth2ClientCredentialsToken(""));
62
25
  tls_->set(
63
49
      [empty](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr { return empty; });
64
  // initialize oauth2 client
65
25
  oauth2_client_ = std::make_unique<OAuth2ClientImpl>(cm, proto_config.token_endpoint());
66
  // set the callback for the oauth2 client
67
25
  oauth2_client_->setCallbacks(*this);
68
25
  asyncGetAccessToken();
69
25
}
70

            
71
// TokenProvider asyncGetAccessToken
72
58
void TokenProvider::asyncGetAccessToken() {
73
  // get the access token from the oauth2 client
74
58
  if (timer_->enabled()) {
75
1
    timer_->disableTimer();
76
1
  }
77
58
  if (secret_reader_->credential().empty()) {
78
26
    ENVOY_LOG(error, "asyncGetAccessToken: client secret is empty, retrying in {} seconds.",
79
26
              retry_interval_.count());
80
26
    timer_->enableTimer(std::chrono::seconds(retry_interval_));
81
26
    stats_.token_fetch_failed_on_client_secret_.inc();
82
26
    return;
83
26
  }
84
32
  auto result = oauth2_client_->asyncGetAccessToken(client_id_, secret_reader_->credential(),
85
32
                                                    oauth_scopes_, endpoint_params_);
86
32
  if (result == OAuth2Client::GetTokenResult::NotDispatchedAlreadyInFlight) {
87
    return;
88
  }
89
32
  if (result == OAuth2Client::GetTokenResult::NotDispatchedClusterNotFound) {
90
4
    ENVOY_LOG(error, "asyncGetAccessToken: OAuth cluster not found. Retrying in {} seconds.",
91
4
              retry_interval_.count());
92
4
    timer_->enableTimer(std::chrono::seconds(retry_interval_));
93
4
    stats_.token_fetch_failed_on_cluster_not_found_.inc();
94
4
    return;
95
4
  }
96

            
97
28
  stats_.token_requested_.inc();
98
28
  ENVOY_LOG(debug, "asyncGetAccessToken: Dispatched OAuth request for access token.");
99
28
}
100

            
101
// FilterCallbacks
102
void TokenProvider::onGetAccessTokenSuccess(const std::string& access_token,
103
21
                                            std::chrono::seconds expires_in) {
104
  // set the token
105
21
  auto token = absl::StrCat("Bearer ", access_token);
106
21
  ThreadLocalOauth2ClientCredentialsTokenSharedPtr value(
107
21
      new ThreadLocalOauth2ClientCredentialsToken(token));
108

            
109
21
  tls_->set(
110
41
      [value](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr { return value; });
111

            
112
21
  stats_.token_fetched_.inc();
113
21
  ENVOY_LOG(debug, "onGetAccessTokenSuccess: Token fetched successfully, expires in {} seconds.",
114
21
            expires_in.count());
115
21
  if (timer_->enabled()) {
116
1
    return;
117
1
  }
118

            
119
20
  timer_->enableTimer(expires_in / 2);
120
20
}
121

            
122
9
// Called by the OAuth2 client when a token request fails. Logs the failure, increments the stat
// for the given reason, and, except for FailureReason::BadToken (which is never retried), schedules
// another attempt after retry_interval_ unless the timer is already armed.
void TokenProvider::onGetAccessTokenFailure(FailureReason failure_reason) {
  ENVOY_LOG(error, "onGetAccessTokenFailure: Failed to get access token");

  switch (failure_reason) {
  case FailureReason::BadToken:
    // Counted, but deliberately not retried.
    stats_.token_fetch_failed_on_bad_token_.inc();
    return;
  case FailureReason::StreamReset:
    stats_.token_fetch_failed_on_stream_reset_.inc();
    break;
  case FailureReason::BadResponseCode:
    stats_.token_fetch_failed_on_bad_response_code_.inc();
    break;
  }

  // Respect an already-armed timer; otherwise retry after the configured interval.
  if (!timer_->enabled()) {
    timer_->enableTimer(retry_interval_);
  }
}
146

            
147
16
// Returns the token string held in the thread-local slot: empty until the first successful fetch,
// then the "Bearer <access_token>" value published by onGetAccessTokenSuccess().
// NOTE(review): assumes threadLocal() returns the object stored in tls_ for the calling thread.
const std::string& TokenProvider::credential() const {
  return threadLocal().token();
}
148

            
149
} // namespace OAuth2
150
} // namespace InjectedCredentials
151
} // namespace Http
152
} // namespace Extensions
153
} // namespace Envoy