SAMLDecryptionKeysLocator.java

/*
 * Copyright 2023 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.protocol.saml;



import org.apache.xml.security.encryption.EncryptedData;
import org.apache.xml.security.encryption.EncryptedKey;
import org.apache.xml.security.encryption.EncryptionMethod;
import org.apache.xml.security.exceptions.XMLSecurityException;
import org.apache.xml.security.keys.KeyInfo;
import org.apache.xml.security.keys.content.KeyName;
import org.keycloak.common.util.DerUtils;
import org.keycloak.crypto.KeyUse;
import org.keycloak.crypto.KeyWrapper;
import org.keycloak.models.KeycloakSession;
import org.keycloak.models.RealmModel;
import org.keycloak.saml.processing.core.util.XMLEncryptionUtil;

import java.security.Key;
import java.security.PrivateKey;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Locates the decryption keys for SAML encrypted data among the realm keys.
 * Realm keys are filtered by status, use, optional requested algorithm, optional key names
 * and by the algorithm of the {@link EncryptedKey} embedded within {@link EncryptedData}.
 *
 * Example of encrypted data:
 * <pre>
 * {@code
 * <xenc:EncryptedData Type="http://www.w3.org/2001/04/xmlenc#Element">
 *     <xenc:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#aes128-cbc"/>
 *     <ds:KeyInfo>
 *         <xenc:EncryptedKey>
 *             <xenc:EncryptionMethod Algorithm="http://www.w3.org/2001/04/xmlenc#rsa-oaep-mgf1p"/>
 *             <xenc:CipherData>
 *                 <xenc:CipherValue>
 *                     .....
 *                 </xenc:CipherValue>
 *             </xenc:CipherData>
 *         </xenc:EncryptedKey>
 *     </ds:KeyInfo>
 *     <xenc:CipherData>
 *         <xenc:CipherValue>
 *             ...
 *         </xenc:CipherValue>
 *     </xenc:CipherData>
 * </xenc:EncryptedData>
 * }
 * </pre>
 *
 */
public class SAMLDecryptionKeysLocator implements XMLEncryptionUtil.DecryptionKeyLocator {

    private final KeycloakSession session;
    private final RealmModel realm;
    private final String requestedAlgorithm;

    /**
     * @param session            session whose key manager is queried for the realm keys
     * @param realm              realm whose keys are searched
     * @param requestedAlgorithm if not {@code null} and not empty after trimming, only keys whose
     *                           algorithm (or default algorithm) equals this value are returned
     */
    public SAMLDecryptionKeysLocator(KeycloakSession session, RealmModel realm, String requestedAlgorithm) {
        this.session = session;
        this.realm = realm;
        this.requestedAlgorithm = requestedAlgorithm;
    }

    /**
     * Collects the values of all non-null {@code KeyName} entries of the given {@link KeyInfo}.
     *
     * @throws IllegalStateException if the key names cannot be read from the document
     */
    private List<String> getKeyNames(KeyInfo keyInfo) {
        List<String> keyNames = new LinkedList<>();

        try {
            for (int i = 0; i < keyInfo.lengthKeyName(); i++) {
                KeyName keyName = keyInfo.itemKeyName(i);
                if (keyName != null) {
                    keyNames.add(keyName.getKeyName());
                }
            }
        } catch (XMLSecurityException e) {
            throw new IllegalStateException("Cannot load keyNames from document", e);
        }

        return keyNames;
    }

    /**
     * Builds a predicate accepting keys whose algorithm (or default algorithm) equals the Keycloak
     * identifier mapped from the given XML Encryption algorithm URI.
     *
     * @param algorithm XML Encryption algorithm identifier taken from an {@code EncryptionMethod}
     * @throws IllegalStateException if {@link SAMLEncryptionAlgorithms} has no mapping for {@code algorithm}
     */
    private Predicate<KeyWrapper> hasMatchingAlgorithm(String algorithm) {
        SAMLEncryptionAlgorithms usedAlgorithm = SAMLEncryptionAlgorithms.forXMLEncIdentifier(algorithm);

        if (usedAlgorithm == null) {
            throw new IllegalStateException("Keycloak does not support encryption keys for given algorithm: " + algorithm);
        }

        return keyWrapper -> Objects.equals(keyWrapper.getAlgorithmOrDefault(), usedAlgorithm.getKeycloakIdentifier());
    }

    /**
     * Returns the realm private keys that are candidates for decrypting {@code encryptedData}.
     * Only enabled keys with {@link KeyUse#ENC} that carry a private key are considered. These are then
     * narrowed by the requested algorithm (if any), by the {@code KeyName}s in the KeyInfo (if any) and
     * by the algorithm of the first embedded {@code EncryptedKey} (if any).
     *
     * @throws IllegalStateException    if {@code encryptedData} has no KeyInfo, the key names cannot be read,
     *                                  or the EncryptedKey algorithm is not supported
     * @throws IllegalArgumentException if the EncryptedKey cannot be loaded or lacks an encryption method/algorithm
     * @throws RuntimeException         if a private key cannot be re-decoded from its encoded form
     */
    @Override
    public List<PrivateKey> getKeys(EncryptedData encryptedData) {
        // KeyInfo is mandatory: key names and the EncryptedKey algorithm are read from it
        KeyInfo keyInfo = encryptedData.getKeyInfo();
        if (keyInfo == null) {
            throw new IllegalStateException("EncryptedData does not contain KeyInfo");
        }

        // Start from all enabled realm keys intended for encryption
        Stream<KeyWrapper> keysStream = session.keys().getKeysStream(realm)
                .filter(key -> key.getStatus().isEnabled() && KeyUse.ENC.equals(key.getUse()));

        // Optionally restrict to the algorithm explicitly requested by the caller
        if (requestedAlgorithm != null && !requestedAlgorithm.trim().isEmpty()) {
            keysStream = keysStream.filter(keyWrapper -> Objects.equals(keyWrapper.getAlgorithmOrDefault(), requestedAlgorithm));
        }

        // If KeyInfo contains KeyName elements, only keys whose kid matches one of those names are used
        if (keyInfo.containsKeyName()) {
            List<String> keyNames = getKeyNames(keyInfo);
            keysStream = keysStream.filter(keyWrapper -> keyNames.contains(keyWrapper.getKid()));
        }

        // If KeyInfo embeds an EncryptedKey, allow only keys generated for the algorithm it was encrypted with.
        // Only the first EncryptedKey is inspected.
        try {
            EncryptedKey encryptedKey = keyInfo.itemEncryptedKey(0);
            if (encryptedKey != null) {
                EncryptionMethod encryptionMethod = encryptedKey.getEncryptionMethod();

                if (encryptionMethod == null) {
                    throw new IllegalArgumentException("KeyInfo does not contain encryption method");
                }

                String algorithm = encryptionMethod.getAlgorithm();
                if (algorithm == null) {
                    throw new IllegalArgumentException("Not able to find algorithm for given encryption method");
                }
                keysStream = keysStream.filter(hasMatchingAlgorithm(algorithm));
            }
        } catch (XMLSecurityException e) {
            // KeyInfo presence was checked above; the failure here is in loading the EncryptedKey
            throw new IllegalArgumentException("Cannot load EncryptedKey from KeyInfo", e);
        }

        // Map keys to PrivateKey. Keys without a private key (getPrivateKey() == null) cannot be used
        // for decryption here and are skipped instead of failing with a NullPointerException in getEncoded().
        return keysStream
                .map(KeyWrapper::getPrivateKey)
                .filter(Objects::nonNull)
                .map(Key::getEncoded)
                .map(encoded -> {
                    try {
                        return DerUtils.decodePrivateKey(encoded);
                    } catch (Exception e) {
                        throw new RuntimeException("Could not decode private key.", e);
                    }
                })
                .collect(Collectors.toList());
    }
}