// SamlMetadataKeyLocator.java
/*
* Copyright 2023 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.protocol.saml;
import java.security.Key;
import java.security.KeyManagementException;
import java.security.MessageDigest;
import java.security.cert.CertificateException;
import java.util.Iterator;
import java.util.function.Predicate;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.keys.PublicKeyLoader;
import org.keycloak.keys.PublicKeyStorageProvider;
import org.keycloak.rotation.KeyLocator;
/**
 * <p>{@link KeyLocator} backed by a {@link PublicKeyStorageProvider} cache.</p>
 *
 * <p>Every lookup is delegated to the storage provider under the configured
 * model key, handing it the {@link PublicKeyLoader}. A key is only returned
 * when its use is compatible with the configured one and its certificate is
 * present and passes {@code checkValidity()}.</p>
 *
 * @author rmartinc
 */
public class SamlMetadataKeyLocator implements KeyLocator {

    private final String modelKey;
    private final PublicKeyLoader loader;
    private final PublicKeyStorageProvider keyStorage;
    private final KeyUse use;

    /**
     * Creates the locator.
     *
     * @param modelKey   identifier passed to the key storage on every call
     * @param loader     loader passed to the key storage on every call
     * @param use        key use that returned keys must be compatible with
     * @param keyStorage storage provider that is queried for the keys
     */
    public SamlMetadataKeyLocator(String modelKey, PublicKeyLoader loader, KeyUse use, PublicKeyStorageProvider keyStorage) {
        this.modelKey = modelKey;
        this.loader = loader;
        this.keyStorage = keyStorage;
        this.use = use;
    }

    /**
     * Looks up a key by its key id.
     *
     * @param kid the key id to search for; {@code null} yields {@code null}
     * @return the public key of the first matching entry, or {@code null} if
     *         the storage returns no match
     * @throws KeyManagementException declared by the {@link KeyLocator} contract
     */
    @Override
    public Key getKey(String kid) throws KeyManagementException {
        if (kid == null) {
            return null;
        }
        // the storage searches by kid and reloads if expired or null
        Predicate<KeyWrapper> byKid = candidate -> isSameKid(candidate, kid);
        return toPublicKey(keyStorage.getFirstPublicKey(modelKey, byKid, loader));
    }

    /**
     * Looks up the stored key that is equal to the given one.
     *
     * @param key the key to search for; {@code null} yields {@code null}
     * @return the public key of the first matching entry, or {@code null} if
     *         the storage returns no match
     * @throws KeyManagementException declared by the {@link KeyLocator} contract
     */
    @Override
    public Key getKey(Key key) throws KeyManagementException {
        if (key == null) {
            return null;
        }
        // the storage searches for the key and reloads if expired or null
        Predicate<KeyWrapper> byKey = candidate -> isSameKey(candidate, key);
        return toPublicKey(keyStorage.getFirstPublicKey(modelKey, byKey, loader));
    }

    /**
     * Asks the key storage to reload the keys kept under the model key.
     */
    @Override
    public void refreshKeyCache() {
        keyStorage.reloadKeys(modelKey, loader);
    }

    /**
     * Iterates over the public keys of all stored entries that have a
     * compatible use and a valid certificate.
     *
     * @return iterator over the accepted public keys
     */
    @Override
    public Iterator<Key> iterator() {
        // force a refresh if a certificate is expired?
        return keyStorage.getKeys(modelKey, loader)
                .stream()
                .filter(this::isSameUse)
                .filter(this::isValidCertificate)
                .map(KeyWrapper::getPublicKey)
                .iterator();
    }

    /** Unwraps the public key, mapping a missing wrapper to {@code null}. */
    private static Key toPublicKey(KeyWrapper wrapper) {
        if (wrapper == null) {
            return null;
        }
        return wrapper.getPublicKey();
    }

    /** Accepts a wrapper with the same kid, compatible use and valid certificate. */
    private boolean isSameKid(KeyWrapper wrapper, String kid) {
        String wrapperKid = wrapper.getKid();
        if (wrapperKid == null || !wrapperKid.equals(kid)) {
            return false;
        }
        return isSameUse(wrapper) && isValidCertificate(wrapper);
    }

    /** Accepts a wrapper holding the same key, with compatible use and valid certificate. */
    private boolean isSameKey(KeyWrapper wrapper, Key key) {
        Key stored = wrapper.getPublicKey();
        if (stored == null) {
            return false;
        }
        if (!isSameUse(wrapper)) {
            return false;
        }
        if (!key.getAlgorithm().equals(stored.getAlgorithm())) {
            return false;
        }
        // encoded forms are compared with MessageDigest.isEqual
        if (!MessageDigest.isEqual(stored.getEncoded(), key.getEncoded())) {
            return false;
        }
        return isValidCertificate(wrapper);
    }

    /** Checks that the wrapper's use is compatible with the configured one. */
    private boolean isSameUse(KeyWrapper wrapper) {
        if (wrapper == null) {
            return false;
        }
        // a key without an explicit use is valid for both uses
        KeyUse wrapperUse = wrapper.getUse();
        return wrapperUse == null || wrapperUse.equals(this.use);
    }

    /** Checks that the wrapper has a certificate and that it passes the validity check. */
    private boolean isValidCertificate(KeyWrapper wrapper) {
        if (wrapper == null) {
            return false;
        }
        if (wrapper.getCertificate() == null) {
            return false;
        }
        try {
            wrapper.getCertificate().checkValidity();
        } catch (CertificateException invalid) {
            // a certificate failing the validity check makes the key unusable
            return false;
        }
        return true;
    }
}