JwtDecoderProviderConfigurationUtils.java

/*
 * Copyright 2004-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.oauth2.jwt;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.JWSKeySelector;
import com.nimbusds.jose.proc.JWSVerificationKeySelector;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jwt.proc.ConfigurableJWTProcessor;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.util.Assert;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * Allows resolving configuration from an <a href=
 * "https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderConfig">OpenID
 * Provider Configuration</a> or
 * <a href="https://tools.ietf.org/html/rfc8414#section-3.1">Authorization Server Metadata
 * Request</a> based on the provided issuer and the method invoked. It also derives the
 * set of JWS algorithms supported by a JWK source.
 *
 * @author Thomas Vitale
 * @author Rafiullah Hamedy
 * @since 5.2
 */
final class JwtDecoderProviderConfigurationUtils {

	private static final String OIDC_METADATA_PATH = "/.well-known/openid-configuration";

	private static final String OAUTH_METADATA_PATH = "/.well-known/oauth-authorization-server";

	/**
	 * Timeout (in milliseconds) used when the corresponding system property is absent or
	 * cannot be parsed as an integer.
	 */
	private static final int DEFAULT_TIMEOUT = 30000;

	private static final RestTemplate rest = new RestTemplate();

	static {
		// Honor the JDK-wide timeout properties, but never let a malformed value break
		// class initialization (which would make this class permanently unusable).
		int connectTimeout = getTimeout("sun.net.client.defaultConnectTimeout", DEFAULT_TIMEOUT);
		int readTimeout = getTimeout("sun.net.client.defaultReadTimeout", DEFAULT_TIMEOUT);
		SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
		requestFactory.setConnectTimeout(connectTimeout);
		requestFactory.setReadTimeout(readTimeout);
		rest.setRequestFactory(requestFactory);
	}

	private static final ParameterizedTypeReference<Map<String, Object>> STRING_OBJECT_MAP = new ParameterizedTypeReference<>() {
	};

	private JwtDecoderProviderConfigurationUtils() {
	}

	/**
	 * Reads an integer timeout from a system property.
	 * @param property the system property name
	 * @param defaultTimeout the value to use when the property is absent or not an integer
	 * @return the parsed timeout, or {@code defaultTimeout}
	 */
	private static int getTimeout(String property, int defaultTimeout) {
		String value = System.getProperty(property);
		if (value == null) {
			return defaultTimeout;
		}
		try {
			return Integer.parseInt(value);
		}
		catch (NumberFormatException ex) {
			// A malformed property should not prevent the class from loading
			return defaultTimeout;
		}
	}

	/**
	 * Fetches the OpenID Provider Configuration for the given issuer, trying only the
	 * OIDC discovery location ({@code issuer + /.well-known/openid-configuration}).
	 * @param oidcIssuerLocation the issuer URI
	 * @return the configuration metadata, which contains a non-null {@code jwks_uri}
	 * @throws IllegalArgumentException if the configuration cannot be resolved
	 */
	static Map<String, Object> getConfigurationForOidcIssuerLocation(String oidcIssuerLocation) {
		return getConfiguration(oidcIssuerLocation, rest, oidc(oidcIssuerLocation));
	}

	/**
	 * Fetches the provider configuration for the given issuer. The following locations are
	 * tried in order: OIDC discovery (path appended), OIDC discovery in RFC 8414 style
	 * (path inserted after the well-known segment), then OAuth 2.0 Authorization Server
	 * Metadata.
	 * @param issuer the issuer URI
	 * @param rest the client used to perform the HTTP requests
	 * @return the configuration metadata, which contains a non-null {@code jwks_uri}
	 * @throws IllegalArgumentException if the configuration cannot be resolved
	 */
	static Map<String, Object> getConfigurationForIssuerLocation(String issuer, RestOperations rest) {
		return getConfiguration(issuer, rest, oidc(issuer), oidcRfc8414(issuer), oauth(issuer));
	}

	/**
	 * Same as {@link #getConfigurationForIssuerLocation(String, RestOperations)}, using
	 * this class's shared {@link RestTemplate}.
	 * @param issuer the issuer URI
	 * @return the configuration metadata, which contains a non-null {@code jwks_uri}
	 * @throws IllegalArgumentException if the configuration cannot be resolved
	 */
	static Map<String, Object> getConfigurationForIssuerLocation(String issuer) {
		return getConfigurationForIssuerLocation(issuer, rest);
	}

	/**
	 * Verifies that the {@code issuer} claim of the metadata exactly equals the requested
	 * issuer.
	 * @param configuration the provider metadata
	 * @param issuer the issuer that was requested
	 * @throws IllegalStateException if the issuers do not match, including when the
	 * metadata has no (or a {@code null}) issuer
	 */
	static void validateIssuer(Map<String, Object> configuration, String issuer) {
		String metadataIssuer = getMetadataIssuer(configuration);
		Assert.state(issuer.equals(metadataIssuer), () -> "The Issuer \"" + metadataIssuer
				+ "\" provided in the configuration did not " + "match the requested issuer \"" + issuer + "\"");
	}

	/**
	 * Restricts the processor's key selector to the JWS algorithms advertised by its JWK
	 * source. The processor is left unchanged unless its selector is a
	 * {@link JWSVerificationKeySelector}.
	 * @param jwtProcessor the processor whose key selector may be replaced
	 * @param <C> the security context type
	 * @throws IllegalArgumentException if the JWK source yields no usable algorithms
	 * @throws IllegalStateException if the JWK source cannot be read
	 */
	static <C extends SecurityContext> void addJWSAlgorithms(ConfigurableJWTProcessor<C> jwtProcessor) {
		JWSKeySelector<C> selector = jwtProcessor.getJWSKeySelector();
		if (selector instanceof JWSVerificationKeySelector) {
			JWKSource<C> jwkSource = ((JWSVerificationKeySelector<C>) selector).getJWKSource();
			Set<JWSAlgorithm> algorithms = getJWSAlgorithms(jwkSource);
			selector = new JWSVerificationKeySelector<>(algorithms, jwkSource);
			jwtProcessor.setJWSKeySelector(selector);
		}
	}

	/**
	 * Collects the JWS algorithms usable with the public RSA/EC keys of a JWK source.
	 * Considered keys are those whose {@code use} is {@code sig} or unspecified. A key
	 * with an explicit {@code alg} contributes that algorithm. A key without one
	 * contributes the whole RSA or EC algorithm family, according to its key type.
	 * @param jwkSource the JWK source to inspect
	 * @param <C> the security context type
	 * @return a non-empty set of algorithms
	 * @throws IllegalStateException if the JWK source cannot be read
	 * @throws IllegalArgumentException if no algorithm could be determined
	 */
	static <C extends SecurityContext> Set<JWSAlgorithm> getJWSAlgorithms(JWKSource<C> jwkSource) {
		JWKMatcher jwkMatcher = new JWKMatcher.Builder().publicOnly(true)
			.keyUses(KeyUse.SIGNATURE, null)
			.keyTypes(KeyType.RSA, KeyType.EC)
			.build();
		Set<JWSAlgorithm> jwsAlgorithms = new HashSet<>();
		try {
			List<? extends JWK> jwks = jwkSource.get(new JWKSelector(jwkMatcher), null);
			for (JWK jwk : jwks) {
				if (jwk.getAlgorithm() != null) {
					JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(jwk.getAlgorithm().getName());
					jwsAlgorithms.add(jwsAlgorithm);
				}
				else if (KeyType.RSA.equals(jwk.getKeyType())) {
					jwsAlgorithms.addAll(JWSAlgorithm.Family.RSA);
				}
				else if (KeyType.EC.equals(jwk.getKeyType())) {
					jwsAlgorithms.addAll(JWSAlgorithm.Family.EC);
				}
			}
		}
		catch (KeySourceException ex) {
			throw new IllegalStateException(ex);
		}
		Assert.notEmpty(jwsAlgorithms, "Failed to find any algorithms from the JWK set");
		return jwsAlgorithms;
	}

	/**
	 * Maps the JWS algorithms of a JWK source to Spring Security {@link SignatureAlgorithm}s.
	 * Algorithms for which {@link SignatureAlgorithm#from(String)} returns {@code null}
	 * are dropped.
	 * @param jwkSource the JWK source to inspect
	 * @return the matching signature algorithms, possibly empty
	 * @throws IllegalStateException if the JWK source cannot be read
	 * @throws IllegalArgumentException if the JWK source yields no algorithms at all
	 */
	static Set<SignatureAlgorithm> getSignatureAlgorithms(JWKSource<SecurityContext> jwkSource) {
		Set<JWSAlgorithm> jwsAlgorithms = getJWSAlgorithms(jwkSource);
		Set<SignatureAlgorithm> signatureAlgorithms = new HashSet<>();
		for (JWSAlgorithm jwsAlgorithm : jwsAlgorithms) {
			SignatureAlgorithm signatureAlgorithm = SignatureAlgorithm.from(jwsAlgorithm.getName());
			if (signatureAlgorithm != null) {
				signatureAlgorithms.add(signatureAlgorithm);
			}
		}
		return signatureAlgorithms;
	}

	/**
	 * Returns the metadata's {@code issuer} as a string, or {@code "(unavailable)"} when it
	 * is absent or {@code null}. Checking the value rather than key presence avoids an NPE
	 * for {@code "issuer": null}.
	 */
	private static String getMetadataIssuer(Map<String, Object> configuration) {
		Object issuer = configuration.get("issuer");
		return (issuer != null) ? issuer.toString() : "(unavailable)";
	}

	/**
	 * Requests each metadata URI in order and returns the first successful response.
	 * <p>
	 * A {@link HttpClientErrorException} with a 4xx status moves on to the next URI. Any
	 * other failure aborts immediately. So does a response whose body lacks
	 * {@code jwks_uri}.
	 * @param issuer the issuer, used only in the error message
	 * @param rest the client used to perform the requests
	 * @param uris the candidate metadata locations, in priority order
	 * @return the first metadata map that contains a non-null {@code jwks_uri}
	 * @throws IllegalArgumentException if no location yields usable metadata
	 */
	private static Map<String, Object> getConfiguration(String issuer, RestOperations rest, UriComponents... uris) {
		String errorMessage = "Unable to resolve the Configuration with the provided Issuer of " + "\"" + issuer + "\"";
		for (UriComponents uri : uris) {
			try {
				RequestEntity<Void> request = RequestEntity.get(uri.toUriString()).build();
				ResponseEntity<Map<String, Object>> response = rest.exchange(request, STRING_OBJECT_MAP);
				Map<String, Object> configuration = response.getBody();
				if (configuration == null) {
					// An empty body cannot be a valid configuration; fail explicitly rather
					// than relying on a NullPointerException being translated below
					throw new IllegalArgumentException(errorMessage);
				}
				Assert.isTrue(configuration.get("jwks_uri") != null, "The public JWK set URI must not be null");
				return configuration;
			}
			catch (IllegalArgumentException ex) {
				throw ex;
			}
			catch (RuntimeException ex) {
				if (!(ex instanceof HttpClientErrorException
						&& ((HttpClientErrorException) ex).getStatusCode().is4xxClientError())) {
					throw new IllegalArgumentException(errorMessage, ex);
				}
				// else try another endpoint
			}
		}
		throw new IllegalArgumentException(errorMessage);
	}

	/**
	 * Builds the OIDC discovery URI by appending {@code /.well-known/openid-configuration}
	 * to the issuer's path.
	 * @param issuer the issuer URI
	 * @return the discovery URI
	 */
	static UriComponents oidc(String issuer) {
		UriComponents uri = UriComponentsBuilder.fromUriString(issuer).build();
		// @formatter:off
		return UriComponentsBuilder.newInstance().uriComponents(uri)
				.replacePath(uri.getPath() + OIDC_METADATA_PATH)
				.build();
		// @formatter:on
	}

	/**
	 * Builds the OIDC discovery URI in RFC 8414 style, inserting
	 * {@code /.well-known/openid-configuration} before the issuer's path.
	 * @param issuer the issuer URI
	 * @return the discovery URI
	 */
	static UriComponents oidcRfc8414(String issuer) {
		UriComponents uri = UriComponentsBuilder.fromUriString(issuer).build();
		// @formatter:off
		return UriComponentsBuilder.newInstance().uriComponents(uri)
				.replacePath(OIDC_METADATA_PATH + uri.getPath())
				.build();
		// @formatter:on
	}

	/**
	 * Builds the RFC 8414 Authorization Server Metadata URI, inserting
	 * {@code /.well-known/oauth-authorization-server} before the issuer's path.
	 * @param issuer the issuer URI
	 * @return the metadata URI
	 */
	static UriComponents oauth(String issuer) {
		UriComponents uri = UriComponentsBuilder.fromUriString(issuer).build();
		// @formatter:off
		return UriComponentsBuilder.newInstance().uriComponents(uri)
				.replacePath(OAUTH_METADATA_PATH + uri.getPath())
				.build();
		// @formatter:on
	}

}