JWTVerifier.java
package com.auth0.jwt;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.exceptions.*;
import com.auth0.jwt.impl.JWTParser;
import com.auth0.jwt.interfaces.Claim;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.impl.ExpectedCheckHolder;
import com.auth0.jwt.interfaces.Verification;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.function.BiPredicate;
/**
 * The JWTVerifier class holds the verify method to assert that a given Token has not only a proper JWT format,
 * but also its signature matches.
 * <p>
 * This class is thread-safe. Each instance keeps its own unmodifiable snapshot of the claim checks it was built
 * with, so later changes to the {@link BaseVerification} that created it do not affect it.
 *
 * @see com.auth0.jwt.interfaces.JWTVerifier
 */
public final class JWTVerifier implements com.auth0.jwt.interfaces.JWTVerifier {
    private final Algorithm algorithm;
    final List<ExpectedCheckHolder> expectedChecks;
    private final JWTParser parser;

    /**
     * Creates a verifier for the given algorithm and claim checks.
     *
     * @param algorithm      the Algorithm used to verify the token's signature.
     * @param expectedChecks the claim checks to run on every verification. The list is copied, so the caller
     *                       may keep modifying its own list without affecting this verifier.
     */
    JWTVerifier(Algorithm algorithm, List<ExpectedCheckHolder> expectedChecks) {
        this.algorithm = algorithm;
        // Copy before wrapping: unmodifiableList on its own is only a read-only *view* of the caller's list,
        // so mutations of that list (e.g. the builder adding more checks) would otherwise leak in here.
        this.expectedChecks = Collections.unmodifiableList(new ArrayList<>(expectedChecks));
        this.parser = new JWTParser();
    }

    /**
     * Initialize a {@link Verification} instance using the given Algorithm.
     *
     * @param algorithm the Algorithm to use on the JWT verification.
     * @return a {@link Verification} instance to configure.
     * @throws IllegalArgumentException if the provided algorithm is null.
     */
    static Verification init(Algorithm algorithm) throws IllegalArgumentException {
        return new BaseVerification(algorithm);
    }

    /**
     * {@link Verification} implementation that accepts all the expected Claim values for verification, and
     * builds a {@link com.auth0.jwt.interfaces.JWTVerifier} used to verify a JWT's signature and expected claims.
     * <p>
     * Note that this class is <strong>not</strong> thread-safe. Calling {@link #build()} returns an instance of
     * {@link com.auth0.jwt.interfaces.JWTVerifier} which can be reused. Every built verifier is independent of
     * this builder and of any other verifier built from it.
     */
    public static class BaseVerification implements Verification {
        private final Algorithm algorithm;
        private final List<ExpectedCheckHolder> expectedChecks;
        private long defaultLeeway;
        private final Map<String, Long> customLeeways;
        private boolean ignoreIssuedAt;

        BaseVerification(Algorithm algorithm) throws IllegalArgumentException {
            if (algorithm == null) {
                throw new IllegalArgumentException("The Algorithm cannot be null.");
            }
            this.algorithm = algorithm;
            this.expectedChecks = new ArrayList<>();
            this.customLeeways = new HashMap<>();
            this.defaultLeeway = 0;
        }

        @Override
        public Verification withIssuer(String... issuer) {
            List<String> value = isNullOrEmpty(issuer) ? null : Arrays.asList(issuer);
            addCheck(RegisteredClaims.ISSUER, ((claim, decodedJWT) -> {
                if (verifyNull(claim, value)) {
                    return true;
                }
                if (value == null || !value.contains(claim.asString())) {
                    throw new IncorrectClaimException(
                            "The Claim 'iss' value doesn't match the required issuer.", RegisteredClaims.ISSUER, claim);
                }
                return true;
            }));
            return this;
        }

        @Override
        public Verification withSubject(String subject) {
            addCheck(RegisteredClaims.SUBJECT, (claim, decodedJWT) ->
                    verifyNull(claim, subject) || isEqualToExpected(subject, claim.asString()));
            return this;
        }

        @Override
        public Verification withAudience(String... audience) {
            List<String> value = isNullOrEmpty(audience) ? null : Arrays.asList(audience);
            addCheck(RegisteredClaims.AUDIENCE, ((claim, decodedJWT) -> {
                if (verifyNull(claim, value)) {
                    return true;
                }
                if (!assertValidAudienceClaim(decodedJWT.getAudience(), value, true)) {
                    throw new IncorrectClaimException("The Claim 'aud' value doesn't contain the required audience.",
                            RegisteredClaims.AUDIENCE, claim);
                }
                return true;
            }));
            return this;
        }

        @Override
        public Verification withAnyOfAudience(String... audience) {
            List<String> value = isNullOrEmpty(audience) ? null : Arrays.asList(audience);
            addCheck(RegisteredClaims.AUDIENCE, ((claim, decodedJWT) -> {
                if (verifyNull(claim, value)) {
                    return true;
                }
                if (!assertValidAudienceClaim(decodedJWT.getAudience(), value, false)) {
                    throw new IncorrectClaimException("The Claim 'aud' value doesn't contain the required audience.",
                            RegisteredClaims.AUDIENCE, claim);
                }
                return true;
            }));
            return this;
        }

        @Override
        public Verification acceptLeeway(long leeway) throws IllegalArgumentException {
            assertPositive(leeway);
            this.defaultLeeway = leeway;
            return this;
        }

        @Override
        public Verification acceptExpiresAt(long leeway) throws IllegalArgumentException {
            assertPositive(leeway);
            customLeeways.put(RegisteredClaims.EXPIRES_AT, leeway);
            return this;
        }

        @Override
        public Verification acceptNotBefore(long leeway) throws IllegalArgumentException {
            assertPositive(leeway);
            customLeeways.put(RegisteredClaims.NOT_BEFORE, leeway);
            return this;
        }

        @Override
        public Verification acceptIssuedAt(long leeway) throws IllegalArgumentException {
            assertPositive(leeway);
            customLeeways.put(RegisteredClaims.ISSUED_AT, leeway);
            return this;
        }

        @Override
        public Verification ignoreIssuedAt() {
            this.ignoreIssuedAt = true;
            return this;
        }

        @Override
        public Verification withJWTId(String jwtId) {
            addCheck(RegisteredClaims.JWT_ID, ((claim, decodedJWT) ->
                    verifyNull(claim, jwtId) || isEqualToExpected(jwtId, claim.asString())));
            return this;
        }

        @Override
        public Verification withClaimPresence(String name) throws IllegalArgumentException {
            assertNonNull(name);
            // addCheck already rejects a missing claim, so the predicate itself can accept anything
            withClaim(name, ((claim, decodedJWT) -> true));
            return this;
        }

        @Override
        public Verification withNullClaim(String name) throws IllegalArgumentException {
            assertNonNull(name);
            withClaim(name, ((claim, decodedJWT) -> claim.isNull()));
            return this;
        }

        @Override
        public Verification withClaim(String name, Boolean value) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
                    || isEqualToExpected(value, claim.asBoolean())));
            return this;
        }

        @Override
        public Verification withClaim(String name, Integer value) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
                    || isEqualToExpected(value, claim.asInt())));
            return this;
        }

        @Override
        public Verification withClaim(String name, Long value) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
                    || isEqualToExpected(value, claim.asLong())));
            return this;
        }

        @Override
        public Verification withClaim(String name, Double value) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
                    || isEqualToExpected(value, claim.asDouble())));
            return this;
        }

        @Override
        public Verification withClaim(String name, String value) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, value)
                    || isEqualToExpected(value, claim.asString())));
            return this;
        }

        @Override
        public Verification withClaim(String name, Date value) throws IllegalArgumentException {
            return withClaim(name, value != null ? value.toInstant() : null);
        }

        @Override
        public Verification withClaim(String name, Instant value) throws IllegalArgumentException {
            assertNonNull(name);
            // Since date-time claims are serialized as epoch seconds,
            // we need to compare them with only seconds-granularity
            addCheck(name,
                    ((claim, decodedJWT) -> verifyNull(claim, value)
                            || (value != null && value.truncatedTo(ChronoUnit.SECONDS).equals(claim.asInstant()))));
            return this;
        }

        @Override
        public Verification withClaim(String name, BiPredicate<Claim, DecodedJWT> predicate)
                throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, predicate)
                    || (predicate != null && predicate.test(claim, decodedJWT))));
            return this;
        }

        @Override
        public Verification withArrayClaim(String name, String... items) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
                    || (items != null && assertValidCollectionClaim(claim, items))));
            return this;
        }

        @Override
        public Verification withArrayClaim(String name, Integer... items) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
                    || (items != null && assertValidCollectionClaim(claim, items))));
            return this;
        }

        @Override
        public Verification withArrayClaim(String name, Long... items) throws IllegalArgumentException {
            assertNonNull(name);
            addCheck(name, ((claim, decodedJWT) -> verifyNull(claim, items)
                    || (items != null && assertValidCollectionClaim(claim, items))));
            return this;
        }

        @Override
        public JWTVerifier build() {
            return this.build(Clock.systemUTC());
        }

        /**
         * Creates a new and reusable instance of the JWTVerifier with the configuration already provided.
         * ONLY FOR TEST PURPOSES.
         * <p>
         * The builder's own list of checks is not modified. Calling this method more than once therefore
         * neither duplicates the time-based (exp/nbf/iat) checks nor changes the clock used by verifiers
         * built earlier.
         *
         * @param clock the instance that will handle the current time.
         * @return a new JWTVerifier instance with a custom {@link java.time.Clock}
         */
        public JWTVerifier build(Clock clock) {
            List<ExpectedCheckHolder> checks = new ArrayList<>(expectedChecks);
            addMandatoryClaimChecks(checks, clock);
            return new JWTVerifier(algorithm, checks);
        }

        /**
         * Fetches the Leeway set for claim or returns the {@link BaseVerification#defaultLeeway}.
         *
         * @param name Claim for which leeway is fetched
         * @return Leeway value set for the claim
         */
        public long getLeewayFor(String name) {
            return customLeeways.getOrDefault(name, defaultLeeway);
        }

        /**
         * Appends the exp / nbf / (unless ignored) iat checks to {@code checks}.
         * Leeways and the ignore-iat flag are read now, so later builder changes do not alter these checks.
         *
         * @param checks the per-verifier list to append to (never the builder's own list).
         * @param clock  the time source captured by the appended checks.
         */
        private void addMandatoryClaimChecks(List<ExpectedCheckHolder> checks, Clock clock) {
            long expiresAtLeeway = getLeewayFor(RegisteredClaims.EXPIRES_AT);
            long notBeforeLeeway = getLeewayFor(RegisteredClaims.NOT_BEFORE);
            long issuedAtLeeway = getLeewayFor(RegisteredClaims.ISSUED_AT);

            checks.add(constructExpectedCheck(RegisteredClaims.EXPIRES_AT, (claim, decodedJWT) ->
                    assertValidInstantClaim(RegisteredClaims.EXPIRES_AT, claim, expiresAtLeeway, true, clock)));
            checks.add(constructExpectedCheck(RegisteredClaims.NOT_BEFORE, (claim, decodedJWT) ->
                    assertValidInstantClaim(RegisteredClaims.NOT_BEFORE, claim, notBeforeLeeway, false, clock)));
            if (!ignoreIssuedAt) {
                checks.add(constructExpectedCheck(RegisteredClaims.ISSUED_AT, (claim, decodedJWT) ->
                        assertValidInstantClaim(RegisteredClaims.ISSUED_AT, claim, issuedAtLeeway, false, clock)));
            }
        }

        /**
         * Returns true if every expected item is contained in the claim's array value.
         *
         * @param claim              the claim read from the token.
         * @param expectedClaimValue the items that must all be present; never null here.
         * @return whether the claim array contains all expected items.
         */
        private boolean assertValidCollectionClaim(Claim claim, Object[] expectedClaimValue) {
            Object[] claimAsObject = claim.as(Object[].class);
            if (claimAsObject == null) {
                // NOTE(review): Claim#as may yield null (presumably for a JSON-null claim); treat that as a
                // mismatch instead of failing with a NullPointerException.
                return false;
            }

            List<Object> claimArr;
            // Jackson uses 'natural' mapping which uses Integer if value fits in 32 bits.
            if (expectedClaimValue instanceof Long[]) {
                // convert Integers to Longs for comparison with equals
                claimArr = new ArrayList<>(claimAsObject.length);
                for (Object cao : claimAsObject) {
                    if (cao instanceof Integer) {
                        claimArr.add(((Integer) cao).longValue());
                    } else {
                        claimArr.add(cao);
                    }
                }
            } else {
                claimArr = Arrays.asList(claimAsObject);
            }
            List<Object> valueArr = Arrays.asList(expectedClaimValue);
            return claimArr.containsAll(valueArr);
        }

        /**
         * Validates a date-time claim against the current time of {@code clock}, allowing {@code leeway} seconds.
         *
         * @param claimName      name of the claim, used in the exception.
         * @param claim          the claim read from the token; a missing/non-date value passes.
         * @param leeway         tolerance in seconds.
         * @param shouldBeFuture true for exp (must be after now), false for nbf/iat (must not be after now).
         * @param clock          source of the current time.
         * @return always true when valid.
         * @throws TokenExpiredException   if an exp claim is in the past.
         * @throws IncorrectClaimException if an nbf/iat claim is in the future.
         */
        private boolean assertValidInstantClaim(String claimName, Claim claim, long leeway, boolean shouldBeFuture,
                                                Clock clock) {
            Instant claimVal = claim.asInstant();
            // Claims are serialized with seconds precision, so compare against a seconds-truncated "now"
            Instant now = clock.instant().truncatedTo(ChronoUnit.SECONDS);
            if (shouldBeFuture) {
                if (!assertInstantIsFuture(claimVal, leeway, now)) {
                    throw new TokenExpiredException(String.format("The Token has expired on %s.", claimVal), claimVal);
                }
            } else {
                if (!assertInstantIsLessThanOrEqualToNow(claimVal, leeway, now)) {
                    throw new IncorrectClaimException(
                            String.format("The Token can't be used before %s.", claimVal), claimName, claim);
                }
            }
            return true;
        }

        private boolean assertInstantIsFuture(Instant claimVal, long leeway, Instant now) {
            return claimVal == null || now.minus(Duration.ofSeconds(leeway)).isBefore(claimVal);
        }

        private boolean assertInstantIsLessThanOrEqualToNow(Instant claimVal, long leeway, Instant now) {
            return !(claimVal != null && now.plus(Duration.ofSeconds(leeway)).isBefore(claimVal));
        }

        private boolean assertValidAudienceClaim(
                List<String> actualAudience,
                List<String> expectedAudience,
                boolean shouldContainAll
        ) {
            if (actualAudience == null || expectedAudience == null) {
                return false;
            }
            if (shouldContainAll) {
                return actualAudience.containsAll(expectedAudience);
            } else {
                return !Collections.disjoint(actualAudience, expectedAudience);
            }
        }

        private void assertPositive(long leeway) {
            if (leeway < 0) {
                throw new IllegalArgumentException("Leeway value can't be negative.");
            }
        }

        private void assertNonNull(String name) {
            if (name == null) {
                throw new IllegalArgumentException("The Custom Claim's name can't be null.");
            }
        }

        /**
         * Registers a check that first rejects a missing claim, then delegates to {@code predicate}.
         */
        private void addCheck(String name, BiPredicate<Claim, DecodedJWT> predicate) {
            expectedChecks.add(constructExpectedCheck(name, (claim, decodedJWT) -> {
                if (claim.isMissing()) {
                    throw new MissingClaimException(name);
                }
                return predicate.test(claim, decodedJWT);
            }));
        }

        private ExpectedCheckHolder constructExpectedCheck(String claimName, BiPredicate<Claim, DecodedJWT> check) {
            return new ExpectedCheckHolder() {
                @Override
                public String getClaimName() {
                    return claimName;
                }

                @Override
                public boolean verify(Claim claim, DecodedJWT decodedJWT) {
                    return check.test(claim, decodedJWT);
                }
            };
        }

        /**
         * True only when a null value was expected and the claim is null.
         */
        private boolean verifyNull(Claim claim, Object value) {
            return value == null && claim.isNull();
        }

        /**
         * Null-safe equality for single-value claim checks. A null expectation never matches here, because the
         * "expected null and claim is null" case is handled by {@link #verifyNull(Claim, Object)}. Without this
         * guard a null expectation against a non-null claim would throw a NullPointerException instead of
         * failing the check.
         */
        private boolean isEqualToExpected(Object expected, Object actual) {
            return expected != null && expected.equals(actual);
        }

        /**
         * True when {@code args} is null, empty, or contains only null elements.
         */
        private boolean isNullOrEmpty(String[] args) {
            if (args == null || args.length == 0) {
                return true;
            }
            boolean isAllNull = true;
            for (String arg : args) {
                if (arg != null) {
                    isAllNull = false;
                    break;
                }
            }
            return isAllNull;
        }
    }

    /**
     * Perform the verification against the given Token, using any previous configured options.
     *
     * @param token to verify.
     * @return a verified and decoded JWT.
     * @throws AlgorithmMismatchException     if the algorithm stated in the token's header is not equal to
     *                                        the one defined in the {@link JWTVerifier}.
     * @throws SignatureVerificationException if the signature is invalid.
     * @throws TokenExpiredException          if the token has expired.
     * @throws MissingClaimException          if a claim to be verified is missing.
     * @throws IncorrectClaimException        if a claim contained a different value than the expected one.
     */
    @Override
    public DecodedJWT verify(String token) throws JWTVerificationException {
        DecodedJWT jwt = new JWTDecoder(parser, token);
        return verify(jwt);
    }

    /**
     * Perform the verification against the given decoded JWT, using any previous configured options.
     *
     * @param jwt to verify.
     * @return a verified and decoded JWT.
     * @throws AlgorithmMismatchException     if the algorithm stated in the token's header is not equal to
     *                                        the one defined in the {@link JWTVerifier}.
     * @throws SignatureVerificationException if the signature is invalid.
     * @throws TokenExpiredException          if the token has expired.
     * @throws MissingClaimException          if a claim to be verified is missing.
     * @throws IncorrectClaimException        if a claim contained a different value than the expected one.
     */
    @Override
    public DecodedJWT verify(DecodedJWT jwt) throws JWTVerificationException {
        verifyAlgorithm(jwt, algorithm);
        algorithm.verify(jwt);
        verifyClaims(jwt, expectedChecks);
        return jwt;
    }

    private void verifyAlgorithm(DecodedJWT jwt, Algorithm expectedAlgorithm) throws AlgorithmMismatchException {
        if (!expectedAlgorithm.getName().equals(jwt.getAlgorithm())) {
            throw new AlgorithmMismatchException(
                    "The provided Algorithm doesn't match the one defined in the JWT's Header.");
        }
    }

    /**
     * Runs every check in order. A check may throw its own exception; a check that returns false is
     * reported as a generic IncorrectClaimException for that claim.
     */
    private void verifyClaims(DecodedJWT jwt, List<ExpectedCheckHolder> expectedChecks)
            throws TokenExpiredException, InvalidClaimException {
        for (ExpectedCheckHolder expectedCheck : expectedChecks) {
            String claimName = expectedCheck.getClaimName();
            Claim claim = jwt.getClaim(claimName);
            if (!expectedCheck.verify(claim, jwt)) {
                throw new IncorrectClaimException(
                        String.format("The Claim '%s' value doesn't match the required one.", claimName),
                        claimName,
                        claim
                );
            }
        }
    }
}