BaseOpenSamlAuthenticationRequestResolver.java
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.web.authentication;
import java.time.Clock;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;
import jakarta.servlet.http.HttpServletRequest;
import org.opensaml.core.config.ConfigurationService;
import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.Issuer;
import org.opensaml.saml.saml2.core.NameID;
import org.opensaml.saml.saml2.core.NameIDPolicy;
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.saml2.core.OpenSamlInitializationService;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
import org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher;
import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.ParameterRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatchers;
import org.springframework.util.Assert;
import static org.springframework.security.web.servlet.util.matcher.PathPatternRequestMatcher.pathPattern;
/**
 * For internal use only. Consolidates the behavior shared by the OpenSAML-based
 * {@link Saml2AuthenticationRequestResolver} implementations for minting a SAML 2.0
 * {@code AuthnRequest}.
 *
 * <p>
 * When an {@link HttpServletRequest} matches the configured {@link RequestMatcher}, this
 * resolver looks up the {@link RelyingPartyRegistration}, builds an {@link AuthnRequest},
 * lets the configured parameters consumer customize it, optionally signs it, and returns
 * either a {@link Saml2PostAuthenticationRequest} or a
 * {@link Saml2RedirectAuthenticationRequest} depending on the asserting party's single
 * sign-on binding.
 */
class BaseOpenSamlAuthenticationRequestResolver implements Saml2AuthenticationRequestResolver {

	static {
		OpenSamlInitializationService.initialize();
	}

	private final OpenSamlOperations saml;

	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;

	private final AuthnRequestBuilder authnRequestBuilder;

	private final AuthnRequestMarshaller marshaller;

	private final IssuerBuilder issuerBuilder;

	private final NameIDBuilder nameIdBuilder;

	private final NameIDPolicyBuilder nameIdPolicyBuilder;

	/**
	 * Decides which requests produce an AuthnRequest. By default this matches either
	 * {@link Saml2AuthenticationRequestResolver#DEFAULT_AUTHENTICATION_REQUEST_URI} or
	 * {@code /saml2/authenticate} carrying a {@code registrationId} query parameter. Both
	 * variants expose the {@code registrationId} match variable.
	 */
	private RequestMatcher requestMatcher = RequestMatchers.anyOf(
			PathPatternRequestMatcher.withDefaults()
				.matcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI),
			new PathPatternQueryRequestMatcher("/saml2/authenticate", "registrationId={registrationId}"));

	/** Source of the AuthnRequest {@code IssueInstant}. */
	private Clock clock = Clock.systemUTC();

	/** Produces the RelayState; defaults to a random UUID per request. */
	private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();

	/** Hook for customizing the AuthnRequest before it is signed and serialized. */
	private Consumer<AuthnRequestParameters> parametersConsumer = (parameters) -> {
	};

	/**
	 * Construct a {@link BaseOpenSamlAuthenticationRequestResolver} using the provided
	 * parameters.
	 * @param relyingPartyRegistrationResolver a strategy for resolving the
	 * {@link RelyingPartyRegistration} from the {@link HttpServletRequest}
	 * @param saml the OpenSAML operations used to sign and serialize the AuthnRequest
	 * @throws IllegalArgumentException if either argument is {@code null}, or if OpenSAML
	 * has no marshaller/builder registered for one of the required SAML elements
	 */
	BaseOpenSamlAuthenticationRequestResolver(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
			OpenSamlOperations saml) {
		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
		Assert.notNull(saml, "saml cannot be null");
		this.saml = saml;
		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
		XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
		this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory()
			.getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
		Assert.notNull(this.marshaller, "authnRequestMarshaller must be configured in OpenSAML");
		this.authnRequestBuilder = (AuthnRequestBuilder) XMLObjectProviderRegistrySupport.getBuilderFactory()
			.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
		Assert.notNull(this.authnRequestBuilder, "authnRequestBuilder must be configured in OpenSAML");
		this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
		Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML");
		this.nameIdBuilder = (NameIDBuilder) registry.getBuilderFactory().getBuilder(NameID.DEFAULT_ELEMENT_NAME);
		Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
		this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory()
			.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME);
		Assert.notNull(this.nameIdPolicyBuilder, "nameIdPolicyBuilder must be configured in OpenSAML");
	}

	/**
	 * Use this {@link Clock} when stamping the AuthnRequest {@code IssueInstant}.
	 * @param clock the clock to use; must not be {@code null}
	 */
	void setClock(Clock clock) {
		Assert.notNull(clock, "clock cannot be null");
		this.clock = clock;
	}

	/**
	 * Use this strategy to compute the RelayState for each AuthnRequest.
	 * @param relayStateResolver the strategy to use; must not be {@code null}. It may
	 * itself return {@code null}, in which case no RelayState is included in the
	 * redirect-binding signature.
	 */
	void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
		Assert.notNull(relayStateResolver, "relayStateResolver cannot be null");
		this.relayStateResolver = relayStateResolver;
	}

	/**
	 * Use this {@link RequestMatcher} to decide which requests produce an AuthnRequest.
	 * A {@code registrationId} match variable, if present, is passed to the
	 * {@link RelyingPartyRegistrationResolver}.
	 * @param requestMatcher the matcher to use; must not be {@code null}
	 */
	void setRequestMatcher(RequestMatcher requestMatcher) {
		Assert.notNull(requestMatcher, "requestMatcher cannot be null");
		this.requestMatcher = requestMatcher;
	}

	/**
	 * Use this {@link Consumer} to customize the AuthnRequest after the defaults are
	 * applied and before it is signed and serialized.
	 * @param parametersConsumer the consumer to use; must not be {@code null}
	 */
	void setParametersConsumer(Consumer<AuthnRequestParameters> parametersConsumer) {
		Assert.notNull(parametersConsumer, "parametersConsumer cannot be null");
		this.parametersConsumer = parametersConsumer;
	}

	/**
	 * Build an AuthnRequest for the given request, if it matches.
	 * @param request the current request
	 * @return a {@link Saml2PostAuthenticationRequest} when the asserting party's SSO
	 * binding is POST, otherwise a {@link Saml2RedirectAuthenticationRequest}; or
	 * {@code null} when the request does not match or no registration is resolved
	 */
	@Override
	@SuppressWarnings("unchecked")
	public <T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest request) {
		RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
		if (!result.isMatch()) {
			return null;
		}
		// May be null if a custom matcher does not capture it; the registration
		// resolver then decides how to proceed.
		String registrationId = result.getVariables().get("registrationId");
		RelyingPartyRegistration registration = this.relyingPartyRegistrationResolver.resolve(request, registrationId);
		if (registration == null) {
			return null;
		}
		AuthnRequest authnRequest = createAuthnRequest(request, registration);
		String relayState = this.relayStateResolver.convert(request);
		Saml2MessageBinding binding = registration.getAssertingPartyMetadata().getSingleSignOnServiceBinding();
		if (binding == Saml2MessageBinding.POST) {
			// The caller chooses T; the concrete type depends on the binding.
			return (T) postAuthenticationRequest(registration, authnRequest, relayState);
		}
		return (T) redirectAuthenticationRequest(registration, authnRequest, relayState);
	}

	/**
	 * Populate an AuthnRequest from the registration, run the customization hook, and
	 * make sure it carries an ID.
	 */
	private AuthnRequest createAuthnRequest(HttpServletRequest request, RelyingPartyRegistration registration) {
		UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
		String entityId = uriResolver.resolve(registration.getEntityId());
		String assertionConsumerServiceLocation = uriResolver
			.resolve(registration.getAssertionConsumerServiceLocation());
		AuthnRequest authnRequest = this.authnRequestBuilder.buildObject();
		authnRequest.setForceAuthn(Boolean.FALSE);
		authnRequest.setIsPassive(Boolean.FALSE);
		authnRequest.setProtocolBinding(registration.getAssertionConsumerServiceBinding().getUrn());
		Issuer issuer = this.issuerBuilder.buildObject();
		issuer.setValue(entityId);
		authnRequest.setIssuer(issuer);
		authnRequest.setDestination(registration.getAssertingPartyMetadata().getSingleSignOnServiceLocation());
		authnRequest.setAssertionConsumerServiceURL(assertionConsumerServiceLocation);
		if (registration.getNameIdFormat() != null) {
			NameIDPolicy nameIdPolicy = this.nameIdPolicyBuilder.buildObject();
			nameIdPolicy.setFormat(registration.getNameIdFormat());
			authnRequest.setNameIDPolicy(nameIdPolicy);
		}
		authnRequest.setIssueInstant(Instant.now(this.clock));
		// The consumer may overwrite any of the values above, including the ID.
		this.parametersConsumer.accept(new AuthnRequestParameters(request, registration, authnRequest));
		if (authnRequest.getID() == null) {
			// Letter prefix presumably keeps the ID from starting with a digit (XML IDs
			// must be NCNames); the first UUID character is dropped.
			authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
		}
		return authnRequest;
	}

	/**
	 * Render the AuthnRequest for the POST binding. When signing is required, the
	 * signature is applied to the AuthnRequest XML itself before serialization.
	 */
	private AbstractSaml2AuthenticationRequest postAuthenticationRequest(RelyingPartyRegistration registration,
			AuthnRequest authnRequest, String relayState) {
		if (requiresSignature(registration)) {
			this.saml.withSigningKeys(registration.getSigningX509Credentials())
				.algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms())
				.sign(authnRequest);
		}
		String xml = serialize(authnRequest);
		String encoded = Saml2Utils.withDecoded(xml).encode();
		return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration)
			.samlRequest(encoded)
			.relayState(relayState)
			.id(authnRequest.getID())
			.build();
	}

	/**
	 * Render the AuthnRequest for the Redirect binding. The XML is deflated and encoded
	 * via {@link Saml2Utils}. When signing is required, the signature is computed over
	 * the {@code SAMLRequest} (and, if present, {@code RelayState}) parameters rather than
	 * embedded in the XML.
	 */
	private AbstractSaml2AuthenticationRequest redirectAuthenticationRequest(RelyingPartyRegistration registration,
			AuthnRequest authnRequest, String relayState) {
		String xml = serialize(authnRequest);
		String deflatedAndEncoded = Saml2Utils.withDecoded(xml).deflate(true).encode();
		Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
			.withRelyingPartyRegistration(registration)
			.samlRequest(deflatedAndEncoded)
			.relayState(relayState)
			.id(authnRequest.getID());
		if (requiresSignature(registration)) {
			Map<String, String> signingParameters = new HashMap<>();
			signingParameters.put(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded);
			if (relayState != null) {
				signingParameters.put(Saml2ParameterNames.RELAY_STATE, relayState);
			}
			Map<String, String> query = this.saml.withSigningKeys(registration.getSigningX509Credentials())
				.algorithms(registration.getAssertingPartyMetadata().getSigningAlgorithms())
				.sign(signingParameters);
			builder.sigAlg(query.get(Saml2ParameterNames.SIG_ALG))
				.signature(query.get(Saml2ParameterNames.SIGNATURE));
		}
		return builder.build();
	}

	/**
	 * Whether the AuthnRequest must be signed. This is the case when either the asserting
	 * party asks for signed requests or the relying party opts in.
	 */
	private static boolean requiresSignature(RelyingPartyRegistration registration) {
		return registration.getAssertingPartyMetadata().getWantAuthnRequestsSigned()
				|| registration.isAuthnRequestsSigned();
	}

	private String serialize(AuthnRequest authnRequest) {
		return this.saml.serialize(authnRequest).serialize();
	}

	/**
	 * Matches a path pattern AND a set of query parameters. Each parameter spec is either
	 * {@code name} (the parameter must be present) or {@code name=value}, where the value
	 * is passed to {@link ParameterRequestMatcher} (e.g. {@code registrationId={registrationId}}).
	 */
	private static final class PathPatternQueryRequestMatcher implements RequestMatcher {

		private final RequestMatcher matcher;

		PathPatternQueryRequestMatcher(String path, String... params) {
			List<RequestMatcher> matchers = new ArrayList<>();
			matchers.add(pathPattern(path));
			for (String param : params) {
				// Split on the first '=' only so that values containing '=' are kept intact.
				String[] parts = param.split("=", 2);
				if (parts.length == 1) {
					matchers.add(new ParameterRequestMatcher(parts[0]));
				}
				else {
					matchers.add(new ParameterRequestMatcher(parts[0], parts[1]));
				}
			}
			this.matcher = new AndRequestMatcher(matchers);
		}

		@Override
		public boolean matches(HttpServletRequest request) {
			return matcher(request).isMatch();
		}

		@Override
		public MatchResult matcher(HttpServletRequest request) {
			return this.matcher.matcher(request);
		}

	}

	/**
	 * The values handed to the parameters consumer: the current request, the resolved
	 * registration, and the mutable AuthnRequest being built.
	 */
	static final class AuthnRequestParameters {

		private final HttpServletRequest request;

		private final RelyingPartyRegistration registration;

		private final AuthnRequest authnRequest;

		AuthnRequestParameters(HttpServletRequest request, RelyingPartyRegistration registration,
				AuthnRequest authnRequest) {
			this.request = request;
			this.registration = registration;
			this.authnRequest = authnRequest;
		}

		HttpServletRequest getRequest() {
			return this.request;
		}

		RelyingPartyRegistration getRelyingPartyRegistration() {
			return this.registration;
		}

		AuthnRequest getAuthnRequest() {
			return this.authnRequest;
		}

	}

}