DefaultOAuth2TokenCustomizers.java

/*
 * Copyright 2004-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;

import java.security.MessageDigest;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

import com.nimbusds.jose.jwk.JWK;

import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeActor;
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenExchangeCompositeAuthenticationToken;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenClaimNames;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenClaimsContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.util.CollectionUtils;

/**
 * @author Joe Grandja
 * @author Steve Riesenberg
 * @since 7.0
 */
final class DefaultOAuth2TokenCustomizers {

	private DefaultOAuth2TokenCustomizers() {
	}

	static OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
		return (context) -> context.getClaims().claims((claims) -> customize(context, claims));
	}

	static OAuth2TokenCustomizer<OAuth2TokenClaimsContext> accessTokenCustomizer() {
		return (context) -> context.getClaims().claims((claims) -> customize(context, claims));
	}

	private static void customize(OAuth2TokenContext tokenContext, Map<String, Object> claims) {
		Map<String, Object> cnfClaims = null;

		// Add 'cnf' claim for Mutual-TLS Client Certificate-Bound Access Tokens
		if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenContext.getTokenType())
				&& tokenContext.getAuthorizationGrant() != null && tokenContext.getAuthorizationGrant()
					.getPrincipal() instanceof OAuth2ClientAuthenticationToken clientAuthentication) {

			if ((ClientAuthenticationMethod.TLS_CLIENT_AUTH.equals(clientAuthentication.getClientAuthenticationMethod())
					|| ClientAuthenticationMethod.SELF_SIGNED_TLS_CLIENT_AUTH
						.equals(clientAuthentication.getClientAuthenticationMethod()))
					&& tokenContext.getRegisteredClient().getTokenSettings().isX509CertificateBoundAccessTokens()) {

				X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
				try {
					String sha256Thumbprint = computeSHA256Thumbprint(clientCertificateChain[0]);
					cnfClaims = new HashMap<>();
					cnfClaims.put("x5t#S256", sha256Thumbprint);
				}
				catch (Exception ex) {
					OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
							"Failed to compute SHA-256 Thumbprint for client X509Certificate.", null);
					throw new OAuth2AuthenticationException(error, ex);
				}
			}
		}

		// Add 'cnf' claim for OAuth 2.0 Demonstrating Proof of Possession (DPoP)
		Jwt dPoPProofJwt = tokenContext.get(OAuth2TokenContext.DPOP_PROOF_KEY);
		if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenContext.getTokenType()) && dPoPProofJwt != null) {
			JWK jwk = null;
			@SuppressWarnings("unchecked")
			Map<String, Object> jwkJson = (Map<String, Object>) dPoPProofJwt.getHeaders().get("jwk");
			try {
				jwk = JWK.parse(jwkJson);
			}
			catch (Exception ignored) {
			}
			if (jwk == null) {
				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
						"jwk header is missing or invalid.", null);
				throw new OAuth2AuthenticationException(error);
			}

			try {
				String sha256Thumbprint = jwk.computeThumbprint().toString();
				if (cnfClaims == null) {
					cnfClaims = new HashMap<>();
				}
				cnfClaims.put("jkt", sha256Thumbprint);
			}
			catch (Exception ex) {
				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
						"Failed to compute SHA-256 Thumbprint for DPoP Proof PublicKey.", null);
				throw new OAuth2AuthenticationException(error, ex);
			}
		}

		if (!CollectionUtils.isEmpty(cnfClaims)) {
			claims.put("cnf", cnfClaims);
		}

		// Add 'act' claim for delegation use case of Token Exchange Grant.
		// If more than one actor is present, we create a chain of delegation by nesting
		// "act" claims.
		if (tokenContext
			.getPrincipal() instanceof OAuth2TokenExchangeCompositeAuthenticationToken compositeAuthenticationToken) {
			Map<String, Object> currentClaims = claims;
			for (OAuth2TokenExchangeActor actor : compositeAuthenticationToken.getActors()) {
				Map<String, Object> actorClaims = actor.getClaims();
				Map<String, Object> actClaim = new LinkedHashMap<>();
				actClaim.put(OAuth2TokenClaimNames.ISS, actorClaims.get(OAuth2TokenClaimNames.ISS));
				actClaim.put(OAuth2TokenClaimNames.SUB, actorClaims.get(OAuth2TokenClaimNames.SUB));
				currentClaims.put("act", Collections.unmodifiableMap(actClaim));
				currentClaims = actClaim;
			}
		}
	}

	private static String computeSHA256Thumbprint(X509Certificate x509Certificate) throws Exception {
		MessageDigest md = MessageDigest.getInstance("SHA-256");
		byte[] digest = md.digest(x509Certificate.getEncoded());
		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
	}

}