// X509SelfSignedCertificateVerifier.java
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.server.authorization.authentication;
import java.net.URI;
import java.net.URISyntaxException;
import java.security.PublicKey;
import java.security.cert.X509Certificate;
import java.text.ParseException;
import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.security.auth.x500.X500Principal;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSet;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
/**
* The default {@code X509Certificate} verifier for the
* {@code self_signed_tls_client_auth} authentication method.
*
* @author Joe Grandja
* @since 7.0
* @see X509ClientCertificateAuthenticationProvider#setCertificateVerifier(Consumer)
*/
final class X509SelfSignedCertificateVerifier implements Consumer<OAuth2ClientAuthenticationContext> {
private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
private static final JWKMatcher HAS_X509_CERT_CHAIN_MATCHER = new JWKMatcher.Builder().hasX509CertChain(true)
.build();
private final Function<RegisteredClient, JWKSet> jwkSetSupplier = new JwkSetSupplier();
@Override
public void accept(OAuth2ClientAuthenticationContext clientAuthenticationContext) {
OAuth2ClientAuthenticationToken clientAuthentication = clientAuthenticationContext.getAuthentication();
RegisteredClient registeredClient = clientAuthenticationContext.getRegisteredClient();
X509Certificate[] clientCertificateChain = (X509Certificate[]) clientAuthentication.getCredentials();
X509Certificate clientCertificate = clientCertificateChain[0];
X500Principal issuer = clientCertificate.getIssuerX500Principal();
X500Principal subject = clientCertificate.getSubjectX500Principal();
if (issuer == null || !issuer.equals(subject)) {
throwInvalidClient("x509_certificate_issuer");
}
JWKSet jwkSet = this.jwkSetSupplier.apply(registeredClient);
boolean publicKeyMatches = false;
for (JWK jwk : jwkSet.filter(HAS_X509_CERT_CHAIN_MATCHER).getKeys()) {
X509Certificate x509Certificate = jwk.getParsedX509CertChain().get(0);
PublicKey publicKey = x509Certificate.getPublicKey();
if (Arrays.equals(clientCertificate.getPublicKey().getEncoded(), publicKey.getEncoded())) {
publicKeyMatches = true;
break;
}
}
if (!publicKeyMatches) {
throwInvalidClient("x509_certificate");
}
}
private static void throwInvalidClient(String parameterName) {
throwInvalidClient(parameterName, null);
}
private static void throwInvalidClient(String parameterName, Throwable cause) {
OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
"Client authentication failed: " + parameterName, ERROR_URI);
throw new OAuth2AuthenticationException(error, error.toString(), cause);
}
private static final class JwkSetSupplier implements Function<RegisteredClient, JWKSet> {
private static final MediaType APPLICATION_JWK_SET_JSON = new MediaType("application", "jwk-set+json");
private final RestOperations restOperations;
private final Map<String, Supplier<JWKSet>> jwkSets = new ConcurrentHashMap<>();
private JwkSetSupplier() {
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
requestFactory.setConnectTimeout(15_000);
requestFactory.setReadTimeout(15_000);
this.restOperations = new RestTemplate(requestFactory);
}
@Override
public JWKSet apply(RegisteredClient registeredClient) {
Supplier<JWKSet> jwkSetSupplier = this.jwkSets.computeIfAbsent(registeredClient.getId(), (key) -> {
if (!StringUtils.hasText(registeredClient.getClientSettings().getJwkSetUrl())) {
throwInvalidClient("client_jwk_set_url");
}
return new JwkSetHolder(registeredClient.getClientSettings().getJwkSetUrl());
});
return jwkSetSupplier.get();
}
private JWKSet retrieve(String jwkSetUrl) {
URI jwkSetUri = null;
try {
jwkSetUri = new URI(jwkSetUrl);
}
catch (URISyntaxException ex) {
throwInvalidClient("jwk_set_uri", ex);
}
HttpHeaders headers = new HttpHeaders();
headers.setAccept(Arrays.asList(MediaType.APPLICATION_JSON, APPLICATION_JWK_SET_JSON));
RequestEntity<Void> request = new RequestEntity<>(headers, HttpMethod.GET, jwkSetUri);
ResponseEntity<String> response = null;
try {
response = this.restOperations.exchange(request, String.class);
}
catch (Exception ex) {
throwInvalidClient("jwk_set_response_error", ex);
}
if (response.getStatusCode().value() != 200) {
throwInvalidClient("jwk_set_response_status");
}
JWKSet jwkSet = null;
try {
jwkSet = JWKSet.parse(response.getBody());
}
catch (ParseException ex) {
throwInvalidClient("jwk_set_response_body", ex);
}
return jwkSet;
}
private final class JwkSetHolder implements Supplier<JWKSet> {
private final ReentrantReadWriteLock rwLock = new ReentrantReadWriteLock();
private final Clock clock = Clock.systemUTC();
private final String jwkSetUrl;
private JWKSet jwkSet;
private Instant lastUpdatedAt;
private JwkSetHolder(String jwkSetUrl) {
this.jwkSetUrl = jwkSetUrl;
}
@Override
public JWKSet get() {
this.rwLock.readLock().lock();
if (shouldRefresh()) {
this.rwLock.readLock().unlock();
this.rwLock.writeLock().lock();
try {
if (shouldRefresh()) {
this.jwkSet = retrieve(this.jwkSetUrl);
this.lastUpdatedAt = Instant.now();
}
this.rwLock.readLock().lock();
}
finally {
this.rwLock.writeLock().unlock();
}
}
try {
return this.jwkSet;
}
finally {
this.rwLock.readLock().unlock();
}
}
private boolean shouldRefresh() {
// Refresh every 5 minutes
return (this.jwkSet == null
|| this.clock.instant().isAfter(this.lastUpdatedAt.plus(5, ChronoUnit.MINUTES)));
}
}
}
}