SAMLAttributeValueParser.java

/*
 * Copyright 2016 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.keycloak.saml.processing.core.parsers.saml.assertion;

import java.io.StringWriter;
import java.util.Deque;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;
import java.util.Objects;
import javax.xml.XMLConstants;
import javax.xml.namespace.QName;
import javax.xml.stream.XMLEventFactory;
import javax.xml.stream.XMLEventReader;
import javax.xml.stream.XMLEventWriter;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.events.Attribute;
import javax.xml.stream.events.EndElement;
import javax.xml.stream.events.Namespace;
import javax.xml.stream.events.StartElement;
import javax.xml.stream.events.XMLEvent;
import org.keycloak.saml.common.PicketLinkLogger;
import org.keycloak.saml.common.PicketLinkLoggerFactory;
import org.keycloak.saml.common.constants.JBossSAMLURIConstants;
import org.keycloak.saml.common.exceptions.ParsingException;
import org.keycloak.saml.common.parsers.StaxParser;
import org.keycloak.saml.common.util.StaxParserUtil;
import org.keycloak.saml.processing.core.parsers.util.SAMLParserUtil;
import org.keycloak.saml.processing.core.saml.v2.util.XMLTimeUtil;

/**
 * Parser for the content of a SAML 2.0 {@code AttributeValue} element.
 * <p>
 * Depending on the element, {@link #parse(XMLEventReader)} returns:
 * <ul>
 *   <li>{@code null} when {@code xsi:nil} is {@code "true"} (any case) or {@code "1"} and the element has no text;</li>
 *   <li>the result of {@link SAMLParserUtil#parseNameIDType(XMLEventReader)} when there is no {@code xsi:type},
 *       no leading text, and the first child is a {@code NameID} element;</li>
 *   <li>the result of {@link XMLTimeUtil#parse(String)} when {@code xsi:type} contains {@code ":date"}
 *       (which also matches {@code ":dateTime"});</li>
 *   <li>otherwise a {@code String}: either the element text, or the child elements serialized back to XML.</li>
 * </ul>
 * The class has no instance fields; the StAX factories it uses are cached per thread.
 */
public class SAMLAttributeValueParser implements StaxParser {

    private static final PicketLinkLogger logger = PicketLinkLoggerFactory.getLogger();

    private static final SAMLAttributeValueParser INSTANCE = new SAMLAttributeValueParser();
    private static final QName NIL = new QName(JBossSAMLURIConstants.XSI_NSURI.get(), "nil", JBossSAMLURIConstants.XSI_PREFIX.get());
    private static final QName XSI_TYPE = new QName(JBossSAMLURIConstants.XSI_NSURI.get(), "type", JBossSAMLURIConstants.XSI_PREFIX.get());

    // StAX factories are obtained via newInstance(), which performs a service lookup on every call, and their
    // thread-safety is implementation-specific; one instance of each is therefore cached per thread.
    private static final ThreadLocal<XMLEventFactory> XML_EVENT_FACTORY = ThreadLocal.withInitial(XMLEventFactory::newInstance);
    private static final ThreadLocal<XMLOutputFactory> XML_OUTPUT_FACTORY = ThreadLocal.withInitial(XMLOutputFactory::newInstance);

    /**
     * Returns the shared parser instance.
     *
     * @return the singleton {@code SAMLAttributeValueParser}
     */
    public static SAMLAttributeValueParser getInstance() {
        return INSTANCE;
    }

    /**
     * Consumes the next start element, which must be {@code AttributeValue}, and parses its content.
     *
     * @param xmlEventReader reader from which the {@code AttributeValue} element is read
     * @return the parsed value; see the class documentation for the possible result types
     * @throws ParsingException if the element is not an {@code AttributeValue}, if {@code xsi:nil} is used
     *         incorrectly, or if the content cannot be read
     */
    @Override
    public Object parse(XMLEventReader xmlEventReader) throws ParsingException {
        StartElement element = StaxParserUtil.getNextStartElement(xmlEventReader);
        StaxParserUtil.validate(element, SAMLAssertionQNames.ATTRIBUTE_VALUE.getQName());

        Attribute nil = element.getAttributeByName(NIL);
        if (nil != null) {
            return parseNilValue(xmlEventReader, StaxParserUtil.getAttributeValue(nil));
        }

        Attribute type = element.getAttributeByName(XSI_TYPE);
        if (type == null) {
            return parseUntypedValue(xmlEventReader);
        }
        return parseTypedValue(xmlEventReader, StaxParserUtil.getAttributeValue(type));
    }

    /**
     * Handles an element carrying {@code xsi:nil}. Only "true" (any case) or "1" with empty text is accepted, and it
     * yields {@code null}; any other nil value, or a nil element with text, is rejected.
     */
    private static Object parseNilValue(XMLEventReader xmlEventReader, String nilValue) throws ParsingException {
        boolean isNil = nilValue != null && (nilValue.equalsIgnoreCase("true") || nilValue.equals("1"));
        if (!isNil) {
            throw logger.parserRequiredAttribute(JBossSAMLURIConstants.XSI_PREFIX.get() + ":nil");
        }
        String elementText = StaxParserUtil.getElementText(xmlEventReader);
        if (elementText != null && !elementText.isEmpty()) {
            throw logger.nullValueError("nil attribute is not in SAML20 format");
        }
        return null;
    }

    /**
     * Handles an element without {@code xsi:type}. The content can be text, a {@code NameID} child, nothing at all
     * (yielding ""), or arbitrary XML, which is treated as {@code xs:anyType}.
     */
    private static Object parseUntypedValue(XMLEventReader xmlEventReader) throws ParsingException {
        if (StaxParserUtil.hasTextAhead(xmlEventReader)) {
            return StaxParserUtil.getElementText(xmlEventReader);
        }
        // No text ahead: the value may be a child element, or the element may be empty.
        XMLEvent next = StaxParserUtil.peek(xmlEventReader);
        if (next instanceof StartElement) {
            QName childName = ((StartElement) next).getName();
            if (Objects.equals(childName, SAMLAssertionQNames.NAMEID.getQName())) {
                return SAMLParserUtil.parseNameIDType(xmlEventReader);
            }
        } else if (next instanceof EndElement) {
            return "";
        }
        // Without an xsi:type attribute the value is treated as xs:anyType.
        return parseAsString(xmlEventReader);
    }

    /**
     * Handles an element with {@code xsi:type}. The type is matched by substring, so a prefix is required: for
     * example, ":string" matches "xs:string" and "xsd:string". Unknown types, and a missing type value, fall back
     * to {@link #parseAsString(XMLEventReader)}.
     */
    private static Object parseTypedValue(XMLEventReader xmlEventReader, String typeValue) throws ParsingException {
        if (typeValue == null) {
            // Null-tolerant like the xsi:nil handling, rather than failing with a NullPointerException.
            return parseAsString(xmlEventReader);
        }
        if (typeValue.contains(":string")) {
            return StaxParserUtil.getElementText(xmlEventReader);
        } else if (typeValue.contains(":anyType")) {
            return parseAsString(xmlEventReader);
        } else if (typeValue.contains(":base64Binary")) {
            // RK Added an additional type check for base64Binary type as calheers is passing this type
            return StaxParserUtil.getElementText(xmlEventReader);
        } else if (typeValue.contains(":date")) {
            // Also matches ":dateTime".
            return XMLTimeUtil.parse(StaxParserUtil.getElementText(xmlEventReader));
        } else if (typeValue.contains(":boolean")) {
            return StaxParserUtil.getElementText(xmlEventReader);
        }
        return parseAsString(xmlEventReader);
    }

    /**
     * Returns the remaining content of the current element as a string.
     * <p>
     * If the next event is a start element, the child elements (and any events between them) are serialized as
     * XML up to the end tag that closes the current element; that end tag itself is left unread. Otherwise, the
     * element text is read.
     */
    private static String parseAsString(XMLEventReader xmlEventReader) throws ParsingException {
        boolean childElementAhead;
        try {
            childElementAhead = xmlEventReader.peek().isStartElement();
        } catch (Exception e) {
            throw logger.parserError(e);
        }
        if (!childElementAhead) {
            // getElementText already reports failures as ParsingException; do not wrap them a second time.
            return StaxParserUtil.getElementText(xmlEventReader);
        }
        try {
            return serializeElements(xmlEventReader);
        } catch (Exception e) {
            throw logger.parserError(e);
        }
    }

    /**
     * Copies events from {@code xmlEventReader} into an XML string until the end tag closing the current element is
     * next. Namespace declarations are added for bindings that were made outside the copied fragment.
     */
    private static String serializeElements(XMLEventReader xmlEventReader) throws XMLStreamException {
        StringWriter out = new StringWriter();
        XMLEventWriter writer = XML_OUTPUT_FACTORY.get().createXMLEventWriter(out);
        // One entry per currently open copied element: the namespace bindings added for that element.
        Deque<Map<String, String>> addedNamespaces = new LinkedList<>();
        int depth = 0;
        while (xmlEventReader.hasNext() && (depth > 0 || !xmlEventReader.peek().isEndElement())) {
            XMLEvent event = xmlEventReader.nextEvent();
            writer.add(event);
            if (event.isStartElement()) {
                addedNamespaces.push(addNamespaceWhenMissing(addedNamespaces, writer, event.asStartElement()));
                depth++;
            } else if (event.isEndElement()) {
                addedNamespaces.pop();
                depth--;
            }
        }
        writer.close();
        return out.toString();
    }

    /**
     * Writes, right after {@code startElement}, a declaration for every namespace used by the element name or its
     * attributes that is neither declared on the element itself nor already added for an enclosing copied element.
     *
     * @param definedNamespaces bindings previously added for the enclosing copied elements
     * @param writer writer that has just received {@code startElement}
     * @param startElement the element being copied
     * @return the bindings added for this element (possibly empty)
     */
    private static Map<String, String> addNamespaceWhenMissing(Deque<Map<String, String>> definedNamespaces, XMLEventWriter writer,
            StartElement startElement) throws XMLStreamException {

        final Map<String, String> necessaryNamespaces = new HashMap<>();

        // Namespace of the element itself. An unprefixed element can still be in a default namespace declared on an
        // ancestor outside the copied fragment; without a declaration the serialized element would lose it.
        final QName elementName = startElement.getName();
        final String elementPrefix = elementName.getPrefix();
        final String elementUri = elementName.getNamespaceURI();
        if (!isNullOrEmpty(elementPrefix)) {
            necessaryNamespaces.put(elementPrefix, elementUri);
        } else if (!isNullOrEmpty(elementUri)) {
            necessaryNamespaces.put(XMLConstants.DEFAULT_NS_PREFIX, elementUri);
        }

        // Namespaces of prefixed attributes (unprefixed attributes are in no namespace).
        final Iterator<Attribute> attributes = startElement.getAttributes();
        while (attributes.hasNext()) {
            final QName attributeName = attributes.next().getName();
            if (!isNullOrEmpty(attributeName.getPrefix())) {
                necessaryNamespaces.put(attributeName.getPrefix(), attributeName.getNamespaceURI());
            }
        }

        // The "xml" prefix is bound by definition and never needs to be declared.
        necessaryNamespaces.remove(XMLConstants.XML_NS_PREFIX);
        // Already added for an enclosing copied element.
        necessaryNamespaces.keySet().removeIf(prefix -> definedNamespaces.stream().anyMatch(added -> added.containsKey(prefix)));
        // Declared on the element itself, and therefore already written together with it.
        final Iterator<Namespace> declared = startElement.getNamespaces();
        while (declared.hasNext() && !necessaryNamespaces.isEmpty()) {
            final String prefix = declared.next().getPrefix();
            // Defensive: StAX reports "" for a default namespace declaration, but treat null the same way.
            necessaryNamespaces.remove(prefix == null ? XMLConstants.DEFAULT_NS_PREFIX : prefix);
        }

        if (!necessaryNamespaces.isEmpty()) {
            final XMLEventFactory eventFactory = XML_EVENT_FACTORY.get();
            for (Map.Entry<String, String> binding : necessaryNamespaces.entrySet()) {
                final String prefix = binding.getKey();
                writer.add(prefix.isEmpty()
                        ? eventFactory.createNamespace(binding.getValue())
                        : eventFactory.createNamespace(prefix, binding.getValue()));
            }
        }
        return necessaryNamespaces;
    }

    /** Returns true when {@code value} is null or has length zero. */
    private static boolean isNullOrEmpty(String value) {
        return value == null || value.isEmpty();
    }
}