1
#pragma once
2

            
3
#include "envoy/common/pure.h"
4
#include "envoy/common/time.h"
5

            
6
#include "source/common/common/cleanup.h"
7
#include "source/common/common/thread.h"
8

            
9
#include "absl/types/optional.h"
10

            
11
namespace Envoy {
12
namespace Extensions {
13
namespace Common {
14
namespace Aws {
15

            
16
// Names of the standard AWS credential environment variables.
constexpr char AWS_ACCESS_KEY_ID[] = "AWS_ACCESS_KEY_ID";
constexpr char AWS_SECRET_ACCESS_KEY[] = "AWS_SECRET_ACCESS_KEY";
constexpr char AWS_SESSION_TOKEN[] = "AWS_SESSION_TOKEN";

// Field names of credential components. NOTE(review): presumably keys in credential documents
// returned by metadata/STS endpoints — confirm against the parsers that use them.
constexpr char ACCESS_KEY_ID[] = "AccessKeyId";
constexpr char SECRET_ACCESS_KEY[] = "SecretAccessKey";
constexpr char TOKEN[] = "Token";
constexpr char SESSION_TOKEN[] = "SessionToken";
constexpr char EXPIRATION[] = "Expiration";
constexpr char CREDENTIALS[] = "Credentials";

// Service name of AWS STS.
constexpr char STS_SERVICE_NAME[] = "sts";

// Credential refresh / caching timing parameters.
constexpr std::chrono::hours REFRESH_INTERVAL = std::chrono::hours(1);
constexpr std::chrono::seconds REFRESH_GRACE_PERIOD = std::chrono::seconds(60);
constexpr std::chrono::seconds MAX_CACHE_JITTER = std::chrono::seconds(30);
29

            
30
/**
31
 * AWS credentials containers
32
 *
33
 * If a credential component was not found in the execution environment, it's getter method will
34
 * return absl::nullopt. Credential components with the empty string value are treated as not found.
35
 */
36
class Credentials {
37
public:
38
  // Access Key Credentials
39
  explicit Credentials(absl::string_view access_key_id = absl::string_view(),
40
                       absl::string_view secret_access_key = absl::string_view(),
41
680
                       absl::string_view session_token = absl::string_view()) {
42
    // TODO(suniltheta): Move credential expiration date in here
43
680
    if (!access_key_id.empty()) {
44
207
      access_key_id_ = std::string(access_key_id);
45
207
      if (!secret_access_key.empty()) {
46
206
        secret_access_key_ = std::string(secret_access_key);
47
206
        if (!session_token.empty()) {
48
149
          session_token_ = std::string(session_token);
49
149
        }
50
206
      }
51
207
    }
52
680
  }
53

            
54
285
  const absl::optional<std::string>& accessKeyId() const { return access_key_id_; }
55

            
56
199
  const absl::optional<std::string>& secretAccessKey() const { return secret_access_key_; }
57

            
58
182
  const absl::optional<std::string>& sessionToken() const { return session_token_; }
59

            
60
224
  bool hasCredentials() const {
61
224
    return access_key_id_.has_value() && secret_access_key_.has_value();
62
224
  }
63

            
64
3
  bool operator==(const Credentials& other) const {
65
3
    return access_key_id_ == other.access_key_id_ &&
66
3
           secret_access_key_ == other.secret_access_key_ && session_token_ == other.session_token_;
67
3
  }
68

            
69
private:
70
  absl::optional<std::string> access_key_id_;
71
  absl::optional<std::string> secret_access_key_;
72
  absl::optional<std::string> session_token_;
73
};
74

            
75
using CredentialsPendingCallback = std::function<void()>;
76

            
77
/*
78
 * X509 Credentials used for IAM Roles Anywhere
79
 */
80

            
81
class X509Credentials {
82
public:
83
  enum class PublicKeySignatureAlgorithm {
84
    RSA,
85
    ECDSA,
86
  };
87

            
88
  // X509 Credentials
89
  X509Credentials(absl::string_view certificate_b64,
90
                  PublicKeySignatureAlgorithm certificate_signature_algorithm,
91
                  absl::string_view certificate_serial,
92
                  absl::optional<absl::string_view> certificate_chain_b64,
93
                  absl::string_view certificate_private_key_pem,
94
                  SystemTime certificate_expiration_time)
95
52
      : certificate_b64_(certificate_b64),
96
52
        certificate_private_key_pem_(certificate_private_key_pem),
97
52
        certificate_serial_(certificate_serial),
98
52
        certificate_expiration_(certificate_expiration_time),
99
52
        certificate_signature_algorithm_(certificate_signature_algorithm) {
100
52
    if (certificate_chain_b64.has_value()) {
101
37
      certificate_chain_b64_ = certificate_chain_b64.value();
102
37
    }
103
52
  }
104

            
105
63
  X509Credentials() = default;
106

            
107
56
  const absl::optional<std::string>& certificateDerB64() const { return certificate_b64_; }
108

            
109
29
  const absl::optional<std::string>& certificateSerial() const { return certificate_serial_; }
110

            
111
3
  const absl::optional<SystemTime>& certificateExpiration() const {
112
3
    return certificate_expiration_;
113
3
  }
114

            
115
46
  const absl::optional<std::string>& certificateChainDerB64() const {
116
46
    return certificate_chain_b64_;
117
46
  }
118

            
119
70
  const absl::optional<PublicKeySignatureAlgorithm>& publicKeySignatureAlgorithm() const {
120
70
    return certificate_signature_algorithm_;
121
70
  }
122

            
123
75
  const absl::optional<std::string> certificatePrivateKey() const {
124
75
    return certificate_private_key_pem_;
125
75
  }
126

            
127
private:
128
  // RolesAnywhere certificate based credentials
129
  absl::optional<std::string> certificate_b64_ = absl::nullopt;
130
  absl::optional<std::string> certificate_chain_b64_ = absl::nullopt;
131
  absl::optional<std::string> certificate_private_key_pem_ = absl::nullopt;
132
  absl::optional<std::string> certificate_serial_ = absl::nullopt;
133
  absl::optional<SystemTime> certificate_expiration_ = absl::nullopt;
134
  absl::optional<PublicKeySignatureAlgorithm> certificate_signature_algorithm_ = absl::nullopt;
135
};
136

            
137
/*
138
 * Interface for classes able to fetch AWS credentials from the execution environment.
139
 */
140
class CredentialsProvider {
141
public:
142
634
  virtual ~CredentialsProvider() = default;
143

            
144
  /**
145
   * Get credentials from the environment.
146
   *
147
   * @return AWS credentials
148
   */
149
  virtual std::string providerName() PURE;
150
  virtual Credentials getCredentials() PURE;
151
  /**
152
   * @return true if credentials are pending from this provider, false if credentials are available
153
   */
154
  virtual bool credentialsPending() PURE;
155
};
156

            
157
using CredentialsConstSharedPtr = std::shared_ptr<const Credentials>;
158
using CredentialsConstUniquePtr = std::unique_ptr<const Credentials>;
159
using CredentialsProviderSharedPtr = std::shared_ptr<CredentialsProvider>;
160

            
161
class CredentialSubscriberCallbacks {
162
public:
163
337
  virtual ~CredentialSubscriberCallbacks() = default;
164

            
165
  virtual void onCredentialUpdate() PURE;
166
};
167

            
168
using CredentialSubscriberCallbacksSharedPtr = std::shared_ptr<CredentialSubscriberCallbacks>;
169

            
170
// Subscription model allowing CredentialsProviderChains to be notified of credential provider
171
// updates. A credential provider chain will call credential_provider->subscribeToCredentialUpdates
172
// to register itself for updates via onCredentialUpdate callback. When a credential provider has
173
// successfully updated all threads with new credentials, via the setCredentialsToAllThreads method
174
// it will notify all subscribers that credentials have been retrieved.
175
//
176
// Subscription is only relevant for metadata credentials providers, as these are the only
177
// credential providers that implement async credential retrieval functionality.
178
//
179
// RAII is used, as credential providers may be instantiated as singletons, as such they may outlive
180
// the credential provider chain.
181
//
182
// Uses weak_ptr to safely handle subscriber lifetime without dangling pointers.
183
class CredentialSubscriberCallbacksHandle
184
    : public RaiiListElement<std::weak_ptr<CredentialSubscriberCallbacks>> {
185
public:
186
  CredentialSubscriberCallbacksHandle(
187
      CredentialSubscriberCallbacksSharedPtr cb,
188
      std::list<std::weak_ptr<CredentialSubscriberCallbacks>>& parent)
189
88
      : RaiiListElement<std::weak_ptr<CredentialSubscriberCallbacks>>(parent, cb) {}
190
};
191

            
192
using CredentialSubscriberCallbacksHandlePtr = std::unique_ptr<CredentialSubscriberCallbacksHandle>;
193

            
194
/**
195
 * AWS credentials provider chain, able to fallback between multiple credential providers.
196
 */
197
class CredentialsProviderChain : public CredentialSubscriberCallbacks,
198
                                 public Logger::Loggable<Logger::Id::aws>,
199
                                 public std::enable_shared_from_this<CredentialsProviderChain> {
200
public:
201
337
  ~CredentialsProviderChain() override {
202
355
    for (auto& subscriber_handle : subscriber_handles_) {
203
84
      if (subscriber_handle) {
204
84
        subscriber_handle->cancel();
205
84
      }
206
84
    }
207
337
  }
208

            
209
565
  void add(const CredentialsProviderSharedPtr& credentials_provider) {
210
565
    if (credentials_provider != nullptr) {
211
557
      providers_.emplace_back(credentials_provider);
212
557
    }
213
565
  }
214

            
215
  // Store a callback if credentials are pending from a credential provider, to be called when
216
  // credentials are available
217
  virtual bool addCallbackIfChainCredentialsPending(CredentialsPendingCallback&&);
218

            
219
  // Loop through all credential providers in a chain and return credentials from the first one that
220
  // has credentials available
221
  Credentials chainGetCredentials();
222

            
223
  // Store the RAII handle for a subscription to credential provider notification
224
  void storeSubscription(CredentialSubscriberCallbacksHandlePtr);
225

            
226
  // Returns the size of the credential provider chain
227
17
  size_t getNumProviders() { return providers_.size(); }
228

            
229
private:
230
  // Callback to notify on credential updates occurring from a chain member
231
  void onCredentialUpdate() override;
232

            
233
  bool chainProvidersPending();
234

            
235
protected:
236
  std::list<CredentialsProviderSharedPtr> providers_;
237
  Thread::MutexBasicLockable mu_;
238
  std::vector<CredentialsPendingCallback> credential_pending_callbacks_ ABSL_GUARDED_BY(mu_);
239
  std::list<CredentialSubscriberCallbacksHandlePtr> subscriber_handles_;
240
};
241

            
242
using CredentialsProviderChainSharedPtr = std::shared_ptr<CredentialsProviderChain>;
243

            
244
/*
245
 * X509 credential provider used for IAM Roles Anywhere
246
 */
247
class X509CredentialsProvider {
248
public:
249
54
  virtual ~X509CredentialsProvider() = default;
250

            
251
  /**
252
   * Get credentials from the environment.
253
   *
254
   * @return AWS credentials
255
   */
256
  virtual X509Credentials getCredentials() PURE;
257
};
258

            
259
using X509CredentialsProviderSharedPtr = std::shared_ptr<X509CredentialsProvider>;
260

            
261
} // namespace Aws
262
} // namespace Common
263
} // namespace Extensions
264
} // namespace Envoy