1
// Copyright 2018 Google LLC
2
// Copyright Envoy Project Authors
3
// SPDX-License-Identifier: Apache-2.0
4

            
5
#include "source/common/jwt/verify.h"
6

            
7
#include "source/common/jwt/check_audience.h"
8

            
9
#include "absl/strings/string_view.h"
#include "absl/time/clock.h"
#include "openssl/bn.h"
#include "openssl/curve25519.h"
#include "openssl/ec.h"
#include "openssl/ecdsa.h"
#include "openssl/err.h"
#include "openssl/evp.h"
#include "openssl/hmac.h"
#include "openssl/mem.h"
#include "openssl/rsa.h"
#include "openssl/sha.h"
20

            
21
namespace Envoy {
22
namespace JwtVerify {
23
namespace {
24

            
25
// A convenience inline cast function.
26
96985
inline const uint8_t* castToUChar(const absl::string_view& str) {
27
96985
  return reinterpret_cast<const uint8_t*>(str.data());
28
96985
}
29

            
30
// Verifies an RSA signature over `signed_data` using public key `key` and
// digest `md`. The caller uses this for the JWS `RS*` algorithms; no padding
// mode is set explicitly, so the EVP default for RSA keys applies.
//
// Returns true only when the signature is valid. Every failure path, including
// allocation and key-setup failures, clears the OpenSSL error queue. This
// keeps the PSS/EC/HMAC helpers consistent and avoids leaving stale errors
// behind for unrelated later calls.
bool verifySignatureRSA(RSA* key, const EVP_MD* md, const uint8_t* signature, size_t signature_len,
                        const uint8_t* signed_data, size_t signed_data_len) {
  if (key == nullptr || md == nullptr || signature == nullptr || signed_data == nullptr) {
    return false;
  }

  bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
  bssl::UniquePtr<EVP_MD_CTX> md_ctx(EVP_MD_CTX_create());

  // Short-circuit evaluation: each step runs only if every previous one
  // succeeded. EVP_DigestVerify is the one-shot equivalent of the
  // Update + Final pair (and matches verifySignatureRSAPSS).
  const bool verified =
      evp_pkey != nullptr && md_ctx != nullptr && EVP_PKEY_set1_RSA(evp_pkey.get(), key) == 1 &&
      EVP_DigestVerifyInit(md_ctx.get(), nullptr, md, nullptr, evp_pkey.get()) == 1 &&
      EVP_DigestVerify(md_ctx.get(), signature, signature_len, signed_data, signed_data_len) == 1;

  if (!verified) {
    ERR_clear_error();
  }
  return verified;
}
51

            
52
// string_view convenience overload: forwards the raw bytes and lengths of
// `signature` and `signed_data` to the pointer-based verifySignatureRSA.
bool verifySignatureRSA(RSA* key, const EVP_MD* md, absl::string_view signature,
                        absl::string_view signed_data) {
  const uint8_t* const sig_bytes = castToUChar(signature);
  const uint8_t* const data_bytes = castToUChar(signed_data);
  return verifySignatureRSA(key, md, sig_bytes, signature.size(), data_bytes,
                            signed_data.size());
}
57

            
58
bool verifySignatureRSAPSS(RSA* key, const EVP_MD* md, const uint8_t* signature,
59
                           size_t signature_len, const uint8_t* signed_data,
60
6914
                           size_t signed_data_len) {
61
6914
  if (key == nullptr || md == nullptr || signature == nullptr || signed_data == nullptr) {
62
    return false;
63
  }
64
6914
  bssl::UniquePtr<EVP_PKEY> evp_pkey(EVP_PKEY_new());
65
6914
  if (EVP_PKEY_set1_RSA(evp_pkey.get(), key) != 1) {
66
    return false;
67
  }
68

            
69
6914
  bssl::UniquePtr<EVP_MD_CTX> md_ctx(EVP_MD_CTX_create());
70
  // ``pctx`` is owned by ``md_ctx``, no need to free it separately.
71
6914
  EVP_PKEY_CTX* pctx;
72
6914
  if (EVP_DigestVerifyInit(md_ctx.get(), &pctx, md, nullptr, evp_pkey.get()) == 1 &&
73
6914
      EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) == 1 &&
74
6914
      EVP_PKEY_CTX_set_rsa_mgf1_md(pctx, md) == 1 &&
75
6914
      EVP_DigestVerify(md_ctx.get(), signature, signature_len, signed_data, signed_data_len) == 1) {
76
4
    return true;
77
4
  }
78

            
79
6910
  ERR_clear_error();
80
6910
  return false;
81
6914
}
82

            
83
// string_view convenience overload: forwards the raw bytes and lengths of
// `signature` and `signed_data` to the pointer-based verifySignatureRSAPSS.
bool verifySignatureRSAPSS(RSA* key, const EVP_MD* md, absl::string_view signature,
                           absl::string_view signed_data) {
  const uint8_t* const sig_bytes = castToUChar(signature);
  const uint8_t* const data_bytes = castToUChar(signed_data);
  return verifySignatureRSAPSS(key, md, sig_bytes, signature.size(), data_bytes,
                               signed_data.size());
}
88

            
89
bool verifySignatureEC(EC_KEY* key, const EVP_MD* md, const uint8_t* signature,
90
7002
                       size_t signature_len, const uint8_t* signed_data, size_t signed_data_len) {
91
7002
  if (key == nullptr || md == nullptr || signature == nullptr || signed_data == nullptr) {
92
    return false;
93
  }
94
7002
  bssl::UniquePtr<EVP_MD_CTX> md_ctx(EVP_MD_CTX_create());
95
7002
  std::vector<uint8_t> digest(EVP_MAX_MD_SIZE);
96
7002
  unsigned int digest_len = 0;
97

            
98
7002
  if (EVP_DigestInit(md_ctx.get(), md) == 0) {
99
    return false;
100
  }
101

            
102
7002
  if (EVP_DigestUpdate(md_ctx.get(), signed_data, signed_data_len) == 0) {
103
    return false;
104
  }
105

            
106
7002
  if (EVP_DigestFinal(md_ctx.get(), digest.data(), &digest_len) == 0) {
107
    return false;
108
  }
109

            
110
7002
  bssl::UniquePtr<ECDSA_SIG> ecdsa_sig(ECDSA_SIG_new());
111
7002
  if (!ecdsa_sig) {
112
    return false;
113
  }
114

            
115
7002
  bssl::UniquePtr<BIGNUM> ecdsa_sig_r{BN_bin2bn(signature, signature_len / 2, nullptr)};
116
7002
  bssl::UniquePtr<BIGNUM> ecdsa_sig_s{
117
7002
      BN_bin2bn(signature + (signature_len / 2), signature_len / 2, nullptr)};
118

            
119
  // Short-circuit evaluation ensures `ECDSA_SIG_set0` is only called if both `BIGNUMs` are valid.
120
  // On `ECDSA_SIG_set0` success, ownership transfers to `ecdsa_sig`; on failure, `unique_ptrs`
121
  // clean up.
122
7002
  if (ecdsa_sig_r == nullptr || ecdsa_sig_s == nullptr ||
123
7002
      ECDSA_SIG_set0(ecdsa_sig.get(), ecdsa_sig_r.get(), ecdsa_sig_s.get()) == 0) {
124
    return false;
125
  }
126
7002
  ecdsa_sig_r.release();
127
7002
  ecdsa_sig_s.release();
128

            
129
7002
  if (ECDSA_do_verify(digest.data(), digest_len, ecdsa_sig.get(), key) == 1) {
130
12
    return true;
131
12
  }
132

            
133
6990
  ERR_clear_error();
134
6990
  return false;
135
7002
}
136

            
137
// string_view convenience overload: forwards the raw bytes and lengths of
// `signature` and `signed_data` to the pointer-based verifySignatureEC.
bool verifySignatureEC(EC_KEY* key, const EVP_MD* md, absl::string_view signature,
                       absl::string_view signed_data) {
  const uint8_t* const sig_bytes = castToUChar(signature);
  const uint8_t* const data_bytes = castToUChar(signed_data);
  return verifySignatureEC(key, md, sig_bytes, signature.size(), data_bytes, signed_data.size());
}
142

            
143
bool verifySignatureOct(const uint8_t* key, size_t key_len, const EVP_MD* md,
144
                        const uint8_t* signature, size_t signature_len, const uint8_t* signed_data,
145
2452
                        size_t signed_data_len) {
146
2452
  if (key == nullptr || md == nullptr || signature == nullptr || signed_data == nullptr) {
147
    return false;
148
  }
149

            
150
2452
  std::vector<uint8_t> out(EVP_MAX_MD_SIZE);
151
2452
  unsigned int out_len = 0;
152
2452
  if (HMAC(md, key, key_len, signed_data, signed_data_len, out.data(), &out_len) == nullptr) {
153
    ERR_clear_error();
154
    return false;
155
  }
156

            
157
2452
  if (out_len != signature_len) {
158
265
    return false;
159
265
  }
160

            
161
2187
  if (CRYPTO_memcmp(out.data(), signature, signature_len) == 0) {
162
7
    return true;
163
7
  }
164

            
165
2180
  ERR_clear_error();
166
2180
  return false;
167
2187
}
168

            
169
// string_view convenience overload: forwards the raw bytes and lengths of the
// HMAC `key`, `signature` and `signed_data` to the pointer-based
// verifySignatureOct.
bool verifySignatureOct(absl::string_view key, const EVP_MD* md, absl::string_view signature,
                        absl::string_view signed_data) {
  const uint8_t* const key_bytes = castToUChar(key);
  const uint8_t* const sig_bytes = castToUChar(signature);
  const uint8_t* const data_bytes = castToUChar(signed_data);
  return verifySignatureOct(key_bytes, key.size(), md, sig_bytes, signature.size(), data_bytes,
                            signed_data.size());
}
174

            
175
// Verifies an Ed25519 signature over `signed_data` with the raw public key
// `key`.
//
// Returns:
//   - Status::Ok when the signature is valid;
//   - Status::JwtEd25519SignatureWrongLength when `signature` is not exactly
//     ED25519_SIGNATURE_LEN bytes (a defect of the token itself);
//   - Status::JwtVerificationFail otherwise, including when `key` is not
//     exactly ED25519_PUBLIC_KEY_LEN bytes.
Status verifySignatureEd25519(absl::string_view key, absl::string_view signature,
                              absl::string_view signed_data) {
  if (signature.length() != ED25519_SIGNATURE_LEN) {
    return Status::JwtEd25519SignatureWrongLength;
  }

  // ED25519_verify reads exactly ED25519_PUBLIC_KEY_LEN bytes through the key
  // pointer, so a shorter key would be over-read. Treat it as a non-matching key.
  if (key.length() != ED25519_PUBLIC_KEY_LEN) {
    return Status::JwtVerificationFail;
  }

  // Pass `key` itself (not `key.data()`): converting the raw char pointer back
  // into a string_view would run strlen over binary key bytes.
  if (ED25519_verify(castToUChar(signed_data), signed_data.length(), castToUChar(signature),
                     castToUChar(key)) == 1) {
    return Status::Ok;
  }

  ERR_clear_error();
  return Status::JwtVerificationFail;
}
189

            
190
} // namespace
191

            
192
38417
// Verifies the signature of `jwt` against the keys in `jwks`, without checking
// any time constraints.
//
// A key is eligible when its `kid` does not conflict with the JWT's (a missing
// `kid` on either side matches any key) and, if the key pins an `alg`, that
// `alg` equals the JWT's. Eligible keys are tried in order, dispatched on the
// key's `kty`.
//
// Returns:
//   - Status::Ok as soon as an eligible key verifies the signature;
//   - Status::JwtEd25519SignatureWrongLength as soon as an Ed25519 key reports
//     a wrongly sized signature;
//   - Status::JwtVerificationFail when some key was eligible but none verified;
//   - Status::JwksKidAlgMismatch when no key was eligible at all.
Status verifyJwtWithoutTimeChecking(const Jwt& jwt, const Jwks& jwks) {
  // The bytes covered by the signature: base64url header '.' base64url payload.
  const std::string signing_input = jwt.header_str_base64url_ + '.' + jwt.payload_str_base64url_;
  const auto& alg = jwt.alg_;
  bool found_eligible_key = false;

  for (const auto& jwk : jwks.keys()) {
    const bool kid_conflict = !jwt.kid_.empty() && !jwk->kid_.empty() && jwk->kid_ != jwt.kid_;
    if (kid_conflict) {
      continue;
    }
    // A key that pins an algorithm only verifies tokens using that algorithm.
    if (!jwk->alg_.empty() && jwk->alg_ != alg) {
      continue;
    }
    found_eligible_key = true;

    if (jwk->kty_ == "EC") {
      // Any alg other than ES384/ES512 falls back to SHA-256.
      const EVP_MD* digest = alg == "ES384"   ? EVP_sha384()
                             : alg == "ES512" ? EVP_sha512()
                                              : EVP_sha256();
      if (verifySignatureEC(jwk->ec_key_.get(), digest, jwt.signature_, signing_input)) {
        return Status::Ok;
      }
      continue;
    }

    if (jwk->kty_ == "RSA") {
      // Any alg other than the 384/512 variants falls back to SHA-256.
      const EVP_MD* digest = (alg == "RS384" || alg == "PS384")   ? EVP_sha384()
                             : (alg == "RS512" || alg == "PS512") ? EVP_sha512()
                                                                  : EVP_sha256();
      // "RS*" goes to verifySignatureRSA and "PS*" to verifySignatureRSAPSS;
      // any other prefix leaves this key unused.
      if (alg.compare(0, 2, "RS") == 0) {
        if (verifySignatureRSA(jwk->rsa_.get(), digest, jwt.signature_, signing_input)) {
          return Status::Ok;
        }
      } else if (alg.compare(0, 2, "PS") == 0) {
        if (verifySignatureRSAPSS(jwk->rsa_.get(), digest, jwt.signature_, signing_input)) {
          return Status::Ok;
        }
      }
      continue;
    }

    if (jwk->kty_ == "oct") {
      // Any alg other than HS384/HS512 falls back to SHA-256.
      const EVP_MD* digest = alg == "HS384"   ? EVP_sha384()
                             : alg == "HS512" ? EVP_sha512()
                                              : EVP_sha256();
      if (verifySignatureOct(jwk->hmac_key_, digest, jwt.signature_, signing_input)) {
        return Status::Ok;
      }
      continue;
    }

    if (jwk->kty_ == "OKP" && jwk->crv_ == "Ed25519") {
      const Status ed25519_status =
          verifySignatureEd25519(jwk->okp_key_raw_, jwt.signature_, signing_input);
      // A plain verification failure moves on to the next key. Success, or a
      // defect in the JWT itself, ends the search immediately.
      if (ed25519_status == Status::Ok ||
          ed25519_status == Status::JwtEd25519SignatureWrongLength) {
        return ed25519_status;
      }
    }
  }

  return found_eligible_key ? Status::JwtVerificationFail : Status::JwksKidAlgMismatch;
}
276

            
277
2314
// Verifies `jwt` against `jwks`, evaluating time constraints at the current
// wall-clock time in whole seconds since the Unix epoch.
Status verifyJwt(const Jwt& jwt, const Jwks& jwks) {
  const auto now_unix_seconds = absl::ToUnixSeconds(absl::Now());
  return verifyJwt(jwt, jwks, now_unix_seconds);
}
280

            
281
38267
// Verifies `jwt` against `jwks` at time `now`, allowing `clock_skew` of
// tolerance. Time constraints are checked first, and their failure status is
// returned without doing any signature work.
Status verifyJwt(const Jwt& jwt, const Jwks& jwks, uint64_t now, uint64_t clock_skew) {
  const Status time_status = jwt.verifyTimeConstraint(now, clock_skew);
  return time_status == Status::Ok ? verifyJwtWithoutTimeChecking(jwt, jwks) : time_status;
}
289

            
290
3
// Verifies `jwt` against `jwks` and the allowed `audiences`, using the current
// wall-clock time in whole seconds since the Unix epoch.
Status verifyJwt(const Jwt& jwt, const Jwks& jwks, const std::vector<std::string>& audiences) {
  const auto now_unix_seconds = absl::ToUnixSeconds(absl::Now());
  return verifyJwt(jwt, jwks, audiences, now_unix_seconds);
}
293

            
294
// Verifies `jwt` against `jwks` at time `now`, after first requiring that its
// audiences are accepted by `audiences`. An audience rejection returns
// Status::JwtAudienceNotAllowed before any time or signature checks run.
Status verifyJwt(const Jwt& jwt, const Jwks& jwks, const std::vector<std::string>& audiences,
                 uint64_t now) {
  CheckAudience audience_checker(audiences);
  const bool audience_ok = audience_checker.areAudiencesAllowed(jwt.audiences_);
  return audience_ok ? verifyJwt(jwt, jwks, now) : Status::JwtAudienceNotAllowed;
}
302

            
303
} // namespace JwtVerify
304
} // namespace Envoy