OAuth2ConfigurerUtils.java
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
import java.util.Map;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import org.springframework.beans.factory.BeanFactoryUtils;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
import org.springframework.context.ApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.oauth2.core.OAuth2Token;
import org.springframework.security.oauth2.jwt.JwtEncoder;
import org.springframework.security.oauth2.jwt.NimbusJwtEncoder;
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
import org.springframework.security.oauth2.server.authorization.token.DelegatingOAuth2TokenGenerator;
import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
import org.springframework.security.oauth2.server.authorization.token.JwtGenerator;
import org.springframework.security.oauth2.server.authorization.token.OAuth2AccessTokenGenerator;
import org.springframework.security.oauth2.server.authorization.token.OAuth2RefreshTokenGenerator;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenClaimsContext;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
 * Internal helpers shared by the OAuth 2.0 Authorization Server configurers.
 * <p>
 * The component accessors first consult the {@link HttpSecurity} shared objects and,
 * when nothing is registered there, fall back to beans from the {@link ApplicationContext}
 * shared object (and, for some components, a built-in default). A component resolved on
 * the fallback path is registered as a shared object so that later calls reuse it.
 *
 * @author Joe Grandja
 * @since 7.0
 */
final class OAuth2ConfigurerUtils {

	private OAuth2ConfigurerUtils() {
	}

	/**
	 * Prefixes the given endpoint URI with a {@code /**} wildcard so that the resulting
	 * pattern also matches the endpoint beneath an arbitrary leading path (presumably an
	 * issuer path, as the method name suggests). A {@code /} separator is inserted after
	 * the wildcard only when {@code endpointUri} does not already start with one.
	 * @param endpointUri the endpoint URI, which must contain text
	 * @return the wildcard-prefixed pattern
	 * @throws IllegalArgumentException if {@code endpointUri} has no text
	 */
	static String withMultipleIssuersPattern(String endpointUri) {
		Assert.hasText(endpointUri, "endpointUri cannot be empty");
		String wildcardPrefix = endpointUri.startsWith("/") ? "/**" : "/**/";
		return wildcardPrefix + endpointUri;
	}

	/**
	 * Returns the {@link RegisteredClientRepository} shared object, resolving it as a
	 * required bean (and registering it as a shared object) on first use.
	 * @param httpSecurity the security builder to resolve against
	 * @return the registered client repository
	 */
	static RegisteredClientRepository getRegisteredClientRepository(HttpSecurity httpSecurity) {
		return getSharedObjectOrRequiredBean(httpSecurity, RegisteredClientRepository.class);
	}

	/**
	 * Returns the {@link OAuth2AuthorizationService} to use. Resolution order is: the
	 * shared object, then an optional bean (at most one may be defined), then a new
	 * {@link InMemoryOAuth2AuthorizationService}. A fallback result is registered as a
	 * shared object.
	 * @param httpSecurity the security builder to resolve against
	 * @return the authorization service, never {@code null}
	 */
	static OAuth2AuthorizationService getAuthorizationService(HttpSecurity httpSecurity) {
		OAuth2AuthorizationService shared = httpSecurity.getSharedObject(OAuth2AuthorizationService.class);
		if (shared != null) {
			return shared;
		}
		OAuth2AuthorizationService bean = getOptionalBean(httpSecurity, OAuth2AuthorizationService.class);
		OAuth2AuthorizationService resolved = (bean != null) ? bean : new InMemoryOAuth2AuthorizationService();
		httpSecurity.setSharedObject(OAuth2AuthorizationService.class, resolved);
		return resolved;
	}

	/**
	 * Returns the {@link OAuth2AuthorizationConsentService} to use. Resolution order is:
	 * the shared object, then an optional bean (at most one may be defined), then a new
	 * {@link InMemoryOAuth2AuthorizationConsentService}. A fallback result is registered
	 * as a shared object.
	 * @param httpSecurity the security builder to resolve against
	 * @return the authorization consent service, never {@code null}
	 */
	static OAuth2AuthorizationConsentService getAuthorizationConsentService(HttpSecurity httpSecurity) {
		OAuth2AuthorizationConsentService shared = httpSecurity
			.getSharedObject(OAuth2AuthorizationConsentService.class);
		if (shared != null) {
			return shared;
		}
		OAuth2AuthorizationConsentService bean = getOptionalBean(httpSecurity,
				OAuth2AuthorizationConsentService.class);
		OAuth2AuthorizationConsentService resolved = (bean != null) ? bean
				: new InMemoryOAuth2AuthorizationConsentService();
		httpSecurity.setSharedObject(OAuth2AuthorizationConsentService.class, resolved);
		return resolved;
	}

	/**
	 * Returns the {@link OAuth2TokenGenerator} to use. Resolution order is: the shared
	 * object, then an optional bean (at most one may be defined), then the default
	 * composite built by {@link #createDefaultTokenGenerator(HttpSecurity)}. A fallback
	 * result is registered as a shared object.
	 * @param httpSecurity the security builder to resolve against
	 * @return the token generator, never {@code null}
	 */
	@SuppressWarnings("unchecked")
	static OAuth2TokenGenerator<? extends OAuth2Token> getTokenGenerator(HttpSecurity httpSecurity) {
		OAuth2TokenGenerator<? extends OAuth2Token> shared = httpSecurity
			.getSharedObject(OAuth2TokenGenerator.class);
		if (shared != null) {
			return shared;
		}
		OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator = getOptionalBean(httpSecurity,
				OAuth2TokenGenerator.class);
		if (tokenGenerator == null) {
			tokenGenerator = createDefaultTokenGenerator(httpSecurity);
		}
		httpSecurity.setSharedObject(OAuth2TokenGenerator.class, tokenGenerator);
		return tokenGenerator;
	}

	/**
	 * Creates the default token generator: a {@link DelegatingOAuth2TokenGenerator} over a
	 * {@link JwtGenerator} (included only when one can be obtained), an
	 * {@link OAuth2AccessTokenGenerator} carrying the composed access token customizer,
	 * and an {@link OAuth2RefreshTokenGenerator}, in that order.
	 * @param httpSecurity the security builder to resolve against
	 * @return the newly created composite generator
	 */
	private static OAuth2TokenGenerator<? extends OAuth2Token> createDefaultTokenGenerator(
			HttpSecurity httpSecurity) {
		// Resolved first: this may register JwtEncoder / JWKSource / JwtGenerator shared
		// objects before the access token customizer bean is looked up.
		JwtGenerator jwtGenerator = getJwtGenerator(httpSecurity);
		OAuth2AccessTokenGenerator accessTokenGenerator = new OAuth2AccessTokenGenerator();
		accessTokenGenerator.setAccessTokenCustomizer(getAccessTokenCustomizer(httpSecurity));
		OAuth2RefreshTokenGenerator refreshTokenGenerator = new OAuth2RefreshTokenGenerator();
		if (jwtGenerator == null) {
			return new DelegatingOAuth2TokenGenerator(accessTokenGenerator, refreshTokenGenerator);
		}
		return new DelegatingOAuth2TokenGenerator(jwtGenerator, accessTokenGenerator, refreshTokenGenerator);
	}

	/**
	 * Returns the {@link JwtGenerator} shared object, or creates one when a
	 * {@link JwtEncoder} is available. A newly created generator is given the composed
	 * JWT customizer and registered as a shared object.
	 * @param httpSecurity the security builder to resolve against
	 * @return the JWT generator, or {@code null} if no {@link JwtEncoder} is available
	 */
	private static JwtGenerator getJwtGenerator(HttpSecurity httpSecurity) {
		JwtGenerator shared = httpSecurity.getSharedObject(JwtGenerator.class);
		if (shared != null) {
			return shared;
		}
		JwtEncoder jwtEncoder = getJwtEncoder(httpSecurity);
		if (jwtEncoder == null) {
			return null;
		}
		JwtGenerator created = new JwtGenerator(jwtEncoder);
		created.setJwtCustomizer(getJwtCustomizer(httpSecurity));
		httpSecurity.setSharedObject(JwtGenerator.class, created);
		return created;
	}

	/**
	 * Returns the {@link JwtEncoder} to use. Resolution order is: the shared object, then
	 * an optional bean (at most one may be defined), then a {@link NimbusJwtEncoder}
	 * backed by the {@link JWKSource} from {@link #getJwkSource(HttpSecurity)}. A fallback
	 * result is registered as a shared object only when it is non-{@code null}.
	 * @param httpSecurity the security builder to resolve against
	 * @return the JWT encoder, or {@code null} if neither an encoder bean nor a JWK source
	 * is available
	 */
	private static JwtEncoder getJwtEncoder(HttpSecurity httpSecurity) {
		JwtEncoder shared = httpSecurity.getSharedObject(JwtEncoder.class);
		if (shared != null) {
			return shared;
		}
		JwtEncoder resolved = getOptionalBean(httpSecurity, JwtEncoder.class);
		if (resolved == null) {
			JWKSource<SecurityContext> jwkSource = getJwkSource(httpSecurity);
			if (jwkSource != null) {
				resolved = new NimbusJwtEncoder(jwkSource);
			}
		}
		if (resolved != null) {
			httpSecurity.setSharedObject(JwtEncoder.class, resolved);
		}
		return resolved;
	}

	/**
	 * Returns the {@link JWKSource} shared object, falling back to an optional bean of
	 * type {@code JWKSource<SecurityContext>}, which is registered as a shared object when
	 * found.
	 * @param httpSecurity the security builder to resolve against
	 * @return the JWK source, or {@code null} if none is available
	 */
	@SuppressWarnings("unchecked")
	static JWKSource<SecurityContext> getJwkSource(HttpSecurity httpSecurity) {
		JWKSource<SecurityContext> shared = httpSecurity.getSharedObject(JWKSource.class);
		if (shared != null) {
			return shared;
		}
		JWKSource<SecurityContext> bean = getOptionalBean(httpSecurity,
				ResolvableType.forClassWithGenerics(JWKSource.class, SecurityContext.class));
		if (bean != null) {
			httpSecurity.setSharedObject(JWKSource.class, bean);
		}
		return bean;
	}

	/**
	 * Builds the JWT customizer for the default {@link JwtGenerator}. The customizer from
	 * {@code DefaultOAuth2TokenCustomizers.jwtCustomizer()} always runs. When an
	 * {@code OAuth2TokenCustomizer<JwtEncodingContext>} bean is present, it runs
	 * afterwards on the same context.
	 * @param httpSecurity the security builder to resolve against
	 * @return the (possibly composed) JWT customizer
	 */
	private static OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(HttpSecurity httpSecurity) {
		OAuth2TokenCustomizer<JwtEncodingContext> builtIn = DefaultOAuth2TokenCustomizers.jwtCustomizer();
		OAuth2TokenCustomizer<JwtEncodingContext> userDefined = getOptionalBean(httpSecurity,
				ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class));
		if (userDefined == null) {
			return builtIn;
		}
		return (context) -> {
			builtIn.customize(context);
			userDefined.customize(context);
		};
	}

	/**
	 * Builds the access token customizer for the default
	 * {@link OAuth2AccessTokenGenerator}. The customizer from
	 * {@code DefaultOAuth2TokenCustomizers.accessTokenCustomizer()} always runs. When an
	 * {@code OAuth2TokenCustomizer<OAuth2TokenClaimsContext>} bean is present, it runs
	 * afterwards on the same context.
	 * @param httpSecurity the security builder to resolve against
	 * @return the (possibly composed) access token customizer
	 */
	private static OAuth2TokenCustomizer<OAuth2TokenClaimsContext> getAccessTokenCustomizer(HttpSecurity httpSecurity) {
		OAuth2TokenCustomizer<OAuth2TokenClaimsContext> builtIn = DefaultOAuth2TokenCustomizers
			.accessTokenCustomizer();
		OAuth2TokenCustomizer<OAuth2TokenClaimsContext> userDefined = getOptionalBean(httpSecurity,
				ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, OAuth2TokenClaimsContext.class));
		if (userDefined == null) {
			return builtIn;
		}
		return (context) -> {
			builtIn.customize(context);
			userDefined.customize(context);
		};
	}

	/**
	 * Returns the {@link AuthorizationServerSettings} shared object, resolving it as a
	 * required bean (and registering it as a shared object) on first use.
	 * @param httpSecurity the security builder to resolve against
	 * @return the authorization server settings
	 */
	static AuthorizationServerSettings getAuthorizationServerSettings(HttpSecurity httpSecurity) {
		return getSharedObjectOrRequiredBean(httpSecurity, AuthorizationServerSettings.class);
	}

	/**
	 * Returns the shared object of the given type. When it is absent, resolves it with
	 * {@link #getBean(HttpSecurity, Class)}, registers it as a shared object and returns
	 * it.
	 * @param httpSecurity the security builder to resolve against
	 * @param type the component type
	 * @param <T> the component type
	 * @return the shared or newly resolved component
	 */
	private static <T> T getSharedObjectOrRequiredBean(HttpSecurity httpSecurity, Class<T> type) {
		T shared = httpSecurity.getSharedObject(type);
		if (shared != null) {
			return shared;
		}
		T bean = getBean(httpSecurity, type);
		httpSecurity.setSharedObject(type, bean);
		return bean;
	}

	/**
	 * Returns the {@link ApplicationContext} registered as a shared object of the builder.
	 * NOTE(review): callers dereference the result without a null check, so it is assumed
	 * to always be present.
	 * @param httpSecurity the security builder to read from
	 * @return the application context shared object
	 */
	private static ApplicationContext getApplicationContext(HttpSecurity httpSecurity) {
		return httpSecurity.getSharedObject(ApplicationContext.class);
	}

	/**
	 * Looks up a bean of the given type by delegating to
	 * {@link ApplicationContext#getBean(Class)}. Any exception it throws propagates.
	 * @param httpSecurity the security builder whose application context is used
	 * @param type the bean type
	 * @param <T> the bean type
	 * @return the matching bean
	 */
	static <T> T getBean(HttpSecurity httpSecurity, Class<T> type) {
		ApplicationContext context = getApplicationContext(httpSecurity);
		return context.getBean(type);
	}

	/**
	 * Looks up the single bean matching the given (possibly generic) type.
	 * NOTE(review): only the context's own bean names are considered
	 * ({@code getBeanNamesForType} does not search parent contexts).
	 * @param httpSecurity the security builder whose application context is used
	 * @param type the bean type, including generics
	 * @param <T> the expected bean type
	 * @return the matching bean
	 * @throws NoSuchBeanDefinitionException if no bean name matches
	 * @throws NoUniqueBeanDefinitionException if more than one bean name matches
	 */
	@SuppressWarnings("unchecked")
	static <T> T getBean(HttpSecurity httpSecurity, ResolvableType type) {
		ApplicationContext context = getApplicationContext(httpSecurity);
		String[] candidateNames = context.getBeanNamesForType(type);
		if (candidateNames.length == 0) {
			throw new NoSuchBeanDefinitionException(type);
		}
		if (candidateNames.length > 1) {
			throw new NoUniqueBeanDefinitionException(type, candidateNames);
		}
		return (T) context.getBean(candidateNames[0]);
	}

	/**
	 * Looks up at most one bean of the given type, searching the application context and
	 * its ancestors via {@link BeanFactoryUtils#beansOfTypeIncludingAncestors}.
	 * @param httpSecurity the security builder whose application context is used
	 * @param type the bean type
	 * @param <T> the bean type
	 * @return the matching bean, or {@code null} if there is none
	 * @throws NoUniqueBeanDefinitionException if more than one bean matches
	 */
	static <T> T getOptionalBean(HttpSecurity httpSecurity, Class<T> type) {
		Map<String, T> candidates = BeanFactoryUtils.beansOfTypeIncludingAncestors(getApplicationContext(httpSecurity),
				type);
		int candidateCount = candidates.size();
		if (candidateCount > 1) {
			String message = "Expected single matching bean of type '" + type.getName() + "' but found "
					+ candidateCount + ": " + StringUtils.collectionToCommaDelimitedString(candidates.keySet());
			throw new NoUniqueBeanDefinitionException(type, candidateCount, message);
		}
		return (candidateCount == 1) ? candidates.values().iterator().next() : null;
	}

	/**
	 * Looks up at most one bean matching the given (possibly generic) type.
	 * NOTE(review): unlike {@link #getOptionalBean(HttpSecurity, Class)}, this only
	 * considers the context's own bean names ({@code getBeanNamesForType} does not search
	 * parent contexts). Confirm whether this asymmetry is intended.
	 * @param httpSecurity the security builder whose application context is used
	 * @param type the bean type, including generics
	 * @param <T> the expected bean type
	 * @return the matching bean, or {@code null} if no bean name matches
	 * @throws NoUniqueBeanDefinitionException if more than one bean name matches
	 */
	@SuppressWarnings("unchecked")
	static <T> T getOptionalBean(HttpSecurity httpSecurity, ResolvableType type) {
		ApplicationContext context = getApplicationContext(httpSecurity);
		String[] candidateNames = context.getBeanNamesForType(type);
		if (candidateNames.length == 0) {
			return null;
		}
		if (candidateNames.length > 1) {
			throw new NoUniqueBeanDefinitionException(type, candidateNames);
		}
		return (T) context.getBean(candidateNames[0]);
	}

}