JdbcAssertingPartyMetadataRepository.java
/*
* Copyright 2004-present the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.saml2.provider.service.registration;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.function.Function;
import org.springframework.core.serializer.DefaultDeserializer;
import org.springframework.core.serializer.DefaultSerializer;
import org.springframework.core.serializer.Deserializer;
import org.springframework.core.serializer.Serializer;
import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcOperations;
import org.springframework.jdbc.core.PreparedStatementSetter;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.jdbc.core.SqlParameterValue;
import org.springframework.security.saml2.core.Saml2X509Credential;
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
import org.springframework.util.Assert;
import org.springframework.util.function.ThrowingFunction;
/**
 * A JDBC implementation of {@link AssertingPartyMetadataRepository}.
 *
 * <p>
 * Each {@link AssertingPartyMetadata} is stored as a single row of the
 * {@code saml2_asserting_party_metadata} table and is looked up by its
 * {@code entity_id} column. Credentials are written with a {@link DefaultSerializer}
 * and read back with a {@link DefaultDeserializer}; signing algorithms are stored as
 * one comma-separated string.
 *
 * @author Cathy Wang
 * @since 7.0
 */
public final class JdbcAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository {

	private final JdbcOperations jdbcOperations;

	private final RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper();

	private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper();

	// The order of this array is significant: the row mapper reads columns by index
	// and the parameters mapper emits values in exactly this order.
	// @formatter:off
	static final String[] COLUMN_NAMES = { "entity_id",
			"single_sign_on_service_location",
			"single_sign_on_service_binding",
			"want_authn_requests_signed",
			"signing_algorithms",
			"verification_credentials",
			"encryption_credentials",
			"single_logout_service_location",
			"single_logout_service_response_location",
			"single_logout_service_binding" };
	// @formatter:on

	private static final String TABLE_NAME = "saml2_asserting_party_metadata";

	private static final String ENTITY_ID_FILTER = "entity_id = ?";

	// @formatter:off
	private static final String LOAD_BY_ID_SQL = "SELECT " + String.join(",", COLUMN_NAMES)
			+ " FROM " + TABLE_NAME
			+ " WHERE " + ENTITY_ID_FILTER;

	private static final String LOAD_ALL_SQL = "SELECT " + String.join(",", COLUMN_NAMES)
			+ " FROM " + TABLE_NAME;

	private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
			+ " (" + String.join(",", COLUMN_NAMES) + ") VALUES (" + String.join(",", Collections.nCopies(COLUMN_NAMES.length, "?")) + ")";

	// Every column except entity_id (index 0) is assigned; entity_id is only used in
	// the WHERE clause, so its placeholder comes last.
	private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME
			+ " SET " + String.join(" = ?,", Arrays.copyOfRange(COLUMN_NAMES, 1, COLUMN_NAMES.length))
			+ " = ?"
			+ " WHERE " + ENTITY_ID_FILTER;
	// @formatter:on

	/**
	 * Constructs a {@code JdbcAssertingPartyMetadataRepository} using the provided
	 * parameters.
	 * @param jdbcOperations the JDBC operations, cannot be {@code null}
	 */
	public JdbcAssertingPartyMetadataRepository(JdbcOperations jdbcOperations) {
		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
		this.jdbcOperations = jdbcOperations;
	}

	/**
	 * Look up the metadata stored for the given entity id.
	 * @param entityId the asserting party's entity id, cannot be empty
	 * @return the first matching {@link AssertingPartyMetadata}, or {@code null} when no
	 * row matches
	 */
	@Override
	public AssertingPartyMetadata findByEntityId(String entityId) {
		Assert.hasText(entityId, "entityId cannot be empty");
		SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) };
		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
		List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss,
				this.assertingPartyMetadataRowMapper);
		return !result.isEmpty() ? result.get(0) : null;
	}

	/**
	 * Load every stored row and iterate over the resulting metadata. The query is run
	 * once per call and its full result is held in memory.
	 * @return an iterator over all stored {@link AssertingPartyMetadata}
	 */
	@Override
	public Iterator<AssertingPartyMetadata> iterator() {
		List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_ALL_SQL,
				this.assertingPartyMetadataRowMapper);
		return result.iterator();
	}

	/**
	 * Persist this {@link AssertingPartyMetadata}.
	 *
	 * <p>
	 * An update keyed by entity id is attempted first; when it affects no rows, a new
	 * row is inserted instead.
	 * @param metadata the metadata to persist, cannot be {@code null}
	 */
	public void save(AssertingPartyMetadata metadata) {
		Assert.notNull(metadata, "metadata cannot be null");
		// NOTE(review): update-then-insert is two separate statements; two concurrent
		// saves of the same new entity id could both reach the insert. Confirm that
		// callers rely on a transaction or a unique constraint to handle this.
		int rows = updateCredentialRecord(metadata);
		if (rows == 0) {
			insertCredentialRecord(metadata);
		}
	}

	/**
	 * Insert a new row holding the given metadata.
	 * @param metadata the metadata to insert
	 */
	private void insertCredentialRecord(AssertingPartyMetadata metadata) {
		List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
		this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray());
	}

	/**
	 * Update the row whose entity id matches the given metadata.
	 * @param metadata the metadata to write
	 * @return the number of rows affected, as reported by {@link JdbcOperations}
	 */
	private int updateCredentialRecord(AssertingPartyMetadata metadata) {
		List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
		// The mapper emits entity_id first, but the UPDATE statement binds it last (in
		// the WHERE clause), so move it to the end of the argument list.
		SqlParameterValue entityId = parameters.remove(0);
		parameters.add(entityId);
		return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray());
	}

	/**
	 * The default {@link RowMapper} that maps the current row in
	 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
	 */
	private static final class AssertingPartyMetadataRowMapper implements RowMapper<AssertingPartyMetadata> {

		// NOTE(review): this performs Java-native deserialization of bytes read from
		// the database; confirm the table can only be written by trusted parties.
		private final Deserializer<Object> deserializer = new DefaultDeserializer();

		@Override
		public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException {
			String entityId = rs.getString(COLUMN_NAMES[0]);
			String singleSignOnUrl = rs.getString(COLUMN_NAMES[1]);
			Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString(COLUMN_NAMES[2]));
			boolean singleSignOnSignRequest = rs.getBoolean(COLUMN_NAMES[3]);
			List<String> algorithms = parseSigningAlgorithms(rs.getString(COLUMN_NAMES[4]));
			byte[] verificationCredentialsBytes = rs.getBytes(COLUMN_NAMES[5]);
			byte[] encryptionCredentialsBytes = rs.getBytes(COLUMN_NAMES[6]);
			String singleLogoutUrl = rs.getString(COLUMN_NAMES[7]);
			String singleLogoutResponseUrl = rs.getString(COLUMN_NAMES[8]);
			Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString(COLUMN_NAMES[9]));

			// The stored bytes are produced by AssertingPartyMetadataParametersMapper
			// from a Collection<Saml2X509Credential>, so the cast is expected to hold.
			@SuppressWarnings("unchecked")
			ThrowingFunction<byte[], Collection<Saml2X509Credential>> credentials = (
					bytes) -> (Collection<Saml2X509Credential>) this.deserializer.deserializeFromByteArray(bytes);
			Collection<Saml2X509Credential> verificationCredentials = credentials.apply(verificationCredentialsBytes);
			// Only the encryption column is treated as optional here.
			Collection<Saml2X509Credential> encryptionCredentials = (encryptionCredentialsBytes != null)
					? credentials.apply(encryptionCredentialsBytes) : List.of();

			AssertingPartyMetadata.Builder<?> builder = new AssertingPartyDetails.Builder();
			builder.entityId(entityId)
				.wantAuthnRequestsSigned(singleSignOnSignRequest)
				.singleSignOnServiceLocation(singleSignOnUrl)
				.singleSignOnServiceBinding(singleSignOnBinding)
				.singleLogoutServiceLocation(singleLogoutUrl)
				.singleLogoutServiceBinding(singleLogoutBinding)
				.singleLogoutServiceResponseLocation(singleLogoutResponseUrl)
				.signingAlgorithms((a) -> a.addAll(algorithms))
				.verificationX509Credentials((c) -> c.addAll(verificationCredentials))
				.encryptionX509Credentials((c) -> c.addAll(encryptionCredentials));
			return builder.build();
		}

		/**
		 * Split the comma-separated {@code signing_algorithms} column value.
		 * @param value the raw column value, possibly {@code null}
		 * @return the individual algorithm identifiers, or an empty list when the
		 * column is {@code NULL} or empty
		 */
		private static List<String> parseSigningAlgorithms(String value) {
			// "".split(",") yields a single empty string, which would register a blank
			// algorithm; a NULL column would throw a NullPointerException.
			if (value == null || value.isEmpty()) {
				return List.of();
			}
			return List.of(value.split(","));
		}

	}

	/**
	 * Maps an {@link AssertingPartyMetadata} to the list of {@link SqlParameterValue}s
	 * to bind, in the same order as {@link #COLUMN_NAMES}.
	 */
	private static final class AssertingPartyMetadataParametersMapper
			implements Function<AssertingPartyMetadata, List<SqlParameterValue>> {

		private final Serializer<Object> serializer = new DefaultSerializer();

		/**
		 * Build the bind parameters for the given metadata.
		 * @param record the metadata to convert
		 * @return a new, mutable list holding one parameter per column
		 */
		@Override
		public List<SqlParameterValue> apply(AssertingPartyMetadata record) {
			ThrowingFunction<Collection<Saml2X509Credential>, byte[]> credentials = this.serializer::serializeToByteArray;
			List<SqlParameterValue> parameters = new ArrayList<>();
			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId()));
			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
			parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
			// The value is a plain comma-separated String and is read back with
			// ResultSet#getString, so it is bound as character data rather than a BLOB.
			parameters.add(new SqlParameterValue(Types.VARCHAR, String.join(",", record.getSigningAlgorithms())));
			parameters
				.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials())));
			parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials())));
			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation()));
			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation()));
			parameters.add(new SqlParameterValue(Types.VARCHAR, urnOf(record.getSingleLogoutServiceBinding())));
			return parameters;
		}

		/**
		 * Return the URN of the given binding without dereferencing {@code null}.
		 * @param binding the binding, possibly {@code null}
		 * @return the binding's URN, or {@code null} when no binding is set
		 */
		private static String urnOf(Saml2MessageBinding binding) {
			// The logout locations above are passed through as possibly-null values;
			// treat the logout binding the same way instead of failing on getUrn().
			return (binding != null) ? binding.getUrn() : null;
		}

	}

}