Saml2LoginBeanDefinitionParser.java
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.config.http;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.w3c.dom.Element;
import org.springframework.beans.BeanMetadataElement;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.BeanReference;
import org.springframework.beans.factory.config.RuntimeBeanReference;
import org.springframework.beans.factory.parsing.BeanComponentDefinition;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.xml.BeanDefinitionParser;
import org.springframework.beans.factory.xml.ParserContext;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.ResolvableType;
import org.springframework.security.config.Elements;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationRequestFilter;
import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter;
import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
import org.springframework.util.StringUtils;
import org.springframework.util.xml.DomUtils;
/**
* SAML 2.0 Login {@link BeanDefinitionParser}
*
* @author Marcus da Coregio
* @since 5.7
*/
final class Saml2LoginBeanDefinitionParser implements BeanDefinitionParser {
private static final String DEFAULT_LOGIN_URI = DefaultLoginPageGeneratingFilter.DEFAULT_LOGIN_PAGE_URL;
private static final String DEFAULT_AUTHENTICATION_REQUEST_PROCESSING_URL = "/saml2/authenticate/{registrationId}";
private static final String ATT_LOGIN_PROCESSING_URL = "login-processing-url";
private static final String ATT_LOGIN_PAGE = "login-page";
private static final String ELT_RELYING_PARTY_REGISTRATION = "relying-party-registration";
private static final String ELT_REGISTRATION_ID = "registration-id";
private static final String ATT_AUTHENTICATION_FAILURE_HANDLER_REF = "authentication-failure-handler-ref";
private static final String ATT_AUTHENTICATION_SUCCESS_HANDLER_REF = "authentication-success-handler-ref";
private static final String ATT_AUTHENTICATION_MANAGER_REF = "authentication-manager-ref";
private final List<BeanDefinition> csrfIgnoreRequestMatchers;
private final BeanReference portMapper;
private final BeanReference requestCache;
private final boolean allowSessionCreation;
private final BeanReference authenticationManager;
private final BeanReference authenticationFilterSecurityContextRepositoryRef;
private final List<BeanReference> authenticationProviders;
private final Map<BeanDefinition, BeanMetadataElement> entryPoints;
private String loginProcessingUrl = Saml2WebSsoAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
private BeanDefinition saml2WebSsoAuthenticationRequestFilter;
private BeanDefinition saml2AuthenticationUrlToProviderName;
Saml2LoginBeanDefinitionParser(List<BeanDefinition> csrfIgnoreRequestMatchers, BeanReference portMapper,
BeanReference requestCache, boolean allowSessionCreation, BeanReference authenticationManager,
BeanReference authenticationFilterSecurityContextRepositoryRef, List<BeanReference> authenticationProviders,
Map<BeanDefinition, BeanMetadataElement> entryPoints) {
this.csrfIgnoreRequestMatchers = csrfIgnoreRequestMatchers;
this.portMapper = portMapper;
this.requestCache = requestCache;
this.allowSessionCreation = allowSessionCreation;
this.authenticationManager = authenticationManager;
this.authenticationFilterSecurityContextRepositoryRef = authenticationFilterSecurityContextRepositoryRef;
this.authenticationProviders = authenticationProviders;
this.entryPoints = entryPoints;
}
@Override
public BeanDefinition parse(Element element, ParserContext pc) {
String loginProcessingUrl = element.getAttribute(ATT_LOGIN_PROCESSING_URL);
if (StringUtils.hasText(loginProcessingUrl)) {
this.loginProcessingUrl = loginProcessingUrl;
}
BeanDefinition saml2LoginBeanConfig = BeanDefinitionBuilder.rootBeanDefinition(Saml2LoginBeanConfig.class)
.getBeanDefinition();
String saml2LoginBeanConfigId = pc.getReaderContext().generateBeanName(saml2LoginBeanConfig);
pc.registerBeanComponent(new BeanComponentDefinition(saml2LoginBeanConfig, saml2LoginBeanConfigId));
registerDefaultCsrfOverride();
BeanMetadataElement relyingPartyRegistrationRepository = Saml2LoginBeanDefinitionParserUtils
.getRelyingPartyRegistrationRepository(element);
BeanMetadataElement authenticationRequestRepository = Saml2LoginBeanDefinitionParserUtils
.getAuthenticationRequestRepository(element);
BeanMetadataElement authenticationRequestResolver = Saml2LoginBeanDefinitionParserUtils
.getAuthenticationRequestResolver(element);
if (authenticationRequestResolver == null) {
authenticationRequestResolver = Saml2LoginBeanDefinitionParserUtils
.createDefaultAuthenticationRequestResolver(relyingPartyRegistrationRepository);
}
BeanMetadataElement authenticationConverter = Saml2LoginBeanDefinitionParserUtils
.getAuthenticationConverter(element);
if (authenticationConverter == null) {
if (!this.loginProcessingUrl.contains("{registrationId}")) {
pc.getReaderContext().error("loginProcessingUrl must contain {registrationId} path variable", element);
}
authenticationConverter = Saml2LoginBeanDefinitionParserUtils
.createDefaultAuthenticationConverter(relyingPartyRegistrationRepository);
}
// Configure the Saml2WebSsoAuthenticationFilter
BeanDefinitionBuilder saml2WebSsoAuthenticationFilterBuilder = BeanDefinitionBuilder
.rootBeanDefinition(Saml2WebSsoAuthenticationFilter.class)
.addConstructorArgValue(authenticationConverter)
.addConstructorArgValue(this.loginProcessingUrl)
.addPropertyValue("authenticationRequestRepository", authenticationRequestRepository);
resolveLoginPage(element, pc);
resolveAuthenticationSuccessHandler(element, saml2WebSsoAuthenticationFilterBuilder);
resolveAuthenticationFailureHandler(element, saml2WebSsoAuthenticationFilterBuilder);
resolveAuthenticationManager(element, saml2WebSsoAuthenticationFilterBuilder);
resolveSecurityContextRepository(element, saml2WebSsoAuthenticationFilterBuilder);
// Configure the Saml2WebSsoAuthenticationRequestFilter
this.saml2WebSsoAuthenticationRequestFilter = BeanDefinitionBuilder
.rootBeanDefinition(Saml2WebSsoAuthenticationRequestFilter.class)
.addConstructorArgValue(authenticationRequestResolver)
.addPropertyValue("authenticationRequestRepository", authenticationRequestRepository)
.getBeanDefinition();
BeanDefinition saml2AuthenticationProvider = Saml2LoginBeanDefinitionParserUtils.createAuthenticationProvider();
this.authenticationProviders.add(
new RuntimeBeanReference(pc.getReaderContext().registerWithGeneratedName(saml2AuthenticationProvider)));
this.saml2AuthenticationUrlToProviderName = BeanDefinitionBuilder.rootBeanDefinition(Map.class)
.setFactoryMethodOnBean("getAuthenticationUrlToProviderName", saml2LoginBeanConfigId)
.getBeanDefinition();
return saml2WebSsoAuthenticationFilterBuilder.getBeanDefinition();
}
private void resolveAuthenticationManager(Element element,
BeanDefinitionBuilder saml2WebSsoAuthenticationFilterBuilder) {
String authenticationManagerRef = element.getAttribute(ATT_AUTHENTICATION_MANAGER_REF);
if (StringUtils.hasText(authenticationManagerRef)) {
saml2WebSsoAuthenticationFilterBuilder.addPropertyReference("authenticationManager",
authenticationManagerRef);
}
else {
saml2WebSsoAuthenticationFilterBuilder.addPropertyValue("authenticationManager",
this.authenticationManager);
}
}
private void resolveSecurityContextRepository(Element element,
BeanDefinitionBuilder saml2WebSsoAuthenticationFilterBuilder) {
if (this.authenticationFilterSecurityContextRepositoryRef != null) {
saml2WebSsoAuthenticationFilterBuilder.addPropertyValue("securityContextRepository",
this.authenticationFilterSecurityContextRepositoryRef);
}
}
private void resolveLoginPage(Element element, ParserContext parserContext) {
String loginPage = element.getAttribute(ATT_LOGIN_PAGE);
Object source = parserContext.extractSource(element);
BeanDefinition saml2LoginAuthenticationEntryPoint = null;
if (StringUtils.hasText(loginPage)) {
WebConfigUtils.validateHttpRedirect(loginPage, parserContext, source);
saml2LoginAuthenticationEntryPoint = BeanDefinitionBuilder
.rootBeanDefinition(LoginUrlAuthenticationEntryPoint.class)
.addConstructorArgValue(loginPage)
.addPropertyValue("portMapper", this.portMapper)
.getBeanDefinition();
}
else {
Map<String, String> identityProviderUrlMap = getIdentityProviderUrlMap(element);
if (identityProviderUrlMap.size() == 1) {
String loginUrl = identityProviderUrlMap.entrySet().iterator().next().getKey();
saml2LoginAuthenticationEntryPoint = BeanDefinitionBuilder
.rootBeanDefinition(LoginUrlAuthenticationEntryPoint.class)
.addConstructorArgValue(loginUrl)
.addPropertyValue("portMapper", this.portMapper)
.getBeanDefinition();
}
}
if (saml2LoginAuthenticationEntryPoint != null) {
BeanDefinitionBuilder requestMatcherBuilder = BeanDefinitionBuilder
.rootBeanDefinition(RequestMatcherFactoryBean.class);
requestMatcherBuilder.addConstructorArgValue(this.loginProcessingUrl);
BeanDefinition requestMatcher = requestMatcherBuilder.getBeanDefinition();
this.entryPoints.put(requestMatcher, saml2LoginAuthenticationEntryPoint);
}
}
private void resolveAuthenticationFailureHandler(Element element,
BeanDefinitionBuilder saml2WebSsoAuthenticationFilterBuilder) {
String authenticationFailureHandlerRef = element.getAttribute(ATT_AUTHENTICATION_FAILURE_HANDLER_REF);
if (StringUtils.hasText(authenticationFailureHandlerRef)) {
saml2WebSsoAuthenticationFilterBuilder.addPropertyReference("authenticationFailureHandler",
authenticationFailureHandlerRef);
}
else {
BeanDefinitionBuilder failureHandlerBuilder = BeanDefinitionBuilder.rootBeanDefinition(
"org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler");
failureHandlerBuilder.addConstructorArgValue(
DEFAULT_LOGIN_URI + "?" + DefaultLoginPageGeneratingFilter.ERROR_PARAMETER_NAME);
failureHandlerBuilder.addPropertyValue("allowSessionCreation", this.allowSessionCreation);
saml2WebSsoAuthenticationFilterBuilder.addPropertyValue("authenticationFailureHandler",
failureHandlerBuilder.getBeanDefinition());
}
}
private void resolveAuthenticationSuccessHandler(Element element,
BeanDefinitionBuilder saml2WebSsoAuthenticationFilterBuilder) {
String authenticationSuccessHandlerRef = element.getAttribute(ATT_AUTHENTICATION_SUCCESS_HANDLER_REF);
if (StringUtils.hasText(authenticationSuccessHandlerRef)) {
saml2WebSsoAuthenticationFilterBuilder.addPropertyReference("authenticationSuccessHandler",
authenticationSuccessHandlerRef);
}
else {
BeanDefinitionBuilder successHandlerBuilder = BeanDefinitionBuilder
.rootBeanDefinition(
"org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler")
.addPropertyValue("requestCache", this.requestCache);
saml2WebSsoAuthenticationFilterBuilder.addPropertyValue("authenticationSuccessHandler",
successHandlerBuilder.getBeanDefinition());
}
}
private void registerDefaultCsrfOverride() {
BeanDefinitionBuilder requestMatcherBuilder = BeanDefinitionBuilder
.rootBeanDefinition(RequestMatcherFactoryBean.class);
requestMatcherBuilder.addConstructorArgValue(this.loginProcessingUrl);
BeanDefinition requestMatcher = requestMatcherBuilder.getBeanDefinition();
this.csrfIgnoreRequestMatchers.add(requestMatcher);
}
private Map<String, String> getIdentityProviderUrlMap(Element element) {
Map<String, String> idps = new LinkedHashMap<>();
Element relyingPartyRegistrationsElt = DomUtils.getChildElementByTagName(
element.getOwnerDocument().getDocumentElement(), Elements.RELYING_PARTY_REGISTRATIONS);
String authenticationRequestProcessingUrl = DEFAULT_AUTHENTICATION_REQUEST_PROCESSING_URL;
if (relyingPartyRegistrationsElt != null) {
List<Element> relyingPartyRegList = DomUtils.getChildElementsByTagName(relyingPartyRegistrationsElt,
ELT_RELYING_PARTY_REGISTRATION);
for (Element relyingPartyReg : relyingPartyRegList) {
String registrationId = relyingPartyReg.getAttribute(ELT_REGISTRATION_ID);
idps.put(authenticationRequestProcessingUrl.replace("{registrationId}", registrationId),
registrationId);
}
}
return idps;
}
BeanDefinition getSaml2WebSsoAuthenticationRequestFilter() {
return this.saml2WebSsoAuthenticationRequestFilter;
}
BeanDefinition getSaml2AuthenticationUrlToProviderName() {
return this.saml2AuthenticationUrlToProviderName;
}
/**
* Wrapper bean class to provide configuration from applicationContext
*/
public static class Saml2LoginBeanConfig implements ApplicationContextAware {
private ApplicationContext context;
@SuppressWarnings({ "unchecked", "unused" })
Map<String, String> getAuthenticationUrlToProviderName() {
Iterable<RelyingPartyRegistration> relyingPartyRegistrations = null;
RelyingPartyRegistrationRepository relyingPartyRegistrationRepository = this.context
.getBean(RelyingPartyRegistrationRepository.class);
ResolvableType type = ResolvableType.forInstance(relyingPartyRegistrationRepository).as(Iterable.class);
if (type != ResolvableType.NONE
&& RelyingPartyRegistration.class.isAssignableFrom(type.resolveGenerics()[0])) {
relyingPartyRegistrations = (Iterable<RelyingPartyRegistration>) relyingPartyRegistrationRepository;
}
if (relyingPartyRegistrations == null) {
return Collections.emptyMap();
}
String authenticationRequestProcessingUrl = DEFAULT_AUTHENTICATION_REQUEST_PROCESSING_URL;
Map<String, String> saml2AuthenticationUrlToProviderName = new HashMap<>();
relyingPartyRegistrations.forEach((registration) -> saml2AuthenticationUrlToProviderName.put(
authenticationRequestProcessingUrl.replace("{registrationId}", registration.getRegistrationId()),
registration.getRegistrationId()));
return saml2AuthenticationUrlToProviderName;
}
@Override
public void setApplicationContext(ApplicationContext context) throws BeansException {
this.context = context;
}
}
}