AbstractSTSClientTest.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.cxf.ws.security.trust;

import javax.xml.XMLConstants;

import org.w3c.dom.Document;
import org.w3c.dom.Element;

import org.apache.cxf.Bus;
import org.apache.cxf.helpers.DOMUtils;
import org.apache.cxf.ws.mex.model._2004_09.MetadataSection;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * Tests for how {@link AbstractSTSClient} resolves XML Schema metadata sections.
 *
 * <p>Covers three cases:
 * <ul>
 *   <li>whether a schema referenced only by {@code location} is downloaded;</li>
 *   <li>whether an inline schema element is used as-is;</li>
 *   <li>whether the default {@code downloadSchema} implementation rejects an {@code ftp://} location.</li>
 * </ul>
 *
 * <p>Clients are constructed with a {@code null} {@link Bus}. This assumes the code paths under test
 * do not need a bus (TODO confirm if the constructor contract changes).
 */
public class AbstractSTSClientTest {

    /** Remote schema location used by the location-based tests. */
    private static final String SCHEMA_LOCATION = "http://example.org/schema.xsd";

    /**
     * Verifies the default configuration: a schema referenced only by location is neither
     * returned nor downloaded.
     */
    @Test
    public void testSchemaLocationDownloadDisabledByDefault() throws Exception {
        TestableAbstractSTSClient client = new TestableAbstractSTSClient(null);
        MetadataSection section = schemaSection(SCHEMA_LOCATION);

        assertFalse(client.isAllowMexMetadataSchemaLocation());
        assertNull(client.getSchemaElement((Element)section.getAny(), section.getLocation()));
        // Catch an implementation that downloads anyway and only discards the result afterwards.
        assertEquals(0, client.getDownloadSchemaInvocations());
    }

    /**
     * Verifies that once location-based retrieval is enabled, the schema location is passed to
     * {@code downloadSchema} exactly once and the downloaded element is returned.
     */
    @Test
    public void testSchemaLocationDownloadAllowedWhenEnabled() throws Exception {
        TestableAbstractSTSClient client = new TestableAbstractSTSClient(null);
        MetadataSection section = schemaSection(SCHEMA_LOCATION);

        client.setAllowMexMetadataSchemaLocation(true);

        Element schemaElement = client.getSchemaElement((Element)section.getAny(), section.getLocation());
        assertEquals(1, client.getDownloadSchemaInvocations());
        assertEquals(SCHEMA_LOCATION, client.getLastDownloadedLocation());
        // Report a missing schema as an assertion failure rather than an NPE on the lines below.
        assertNotNull(schemaElement);
        assertEquals(XMLConstants.W3C_XML_SCHEMA_NS_URI, schemaElement.getNamespaceURI());
        assertEquals("schema", schemaElement.getLocalName());
    }

    /**
     * Verifies that an inline schema element is returned as the same instance and that no
     * download is attempted.
     */
    @Test
    public void testInlineSchemaElementDoesNotDownload() throws Exception {
        TestableAbstractSTSClient client = new TestableAbstractSTSClient(null);
        MetadataSection section = schemaSection(null);

        Document document = DOMUtils.createDocument();
        Element inlineSchema = document.createElementNS(XMLConstants.W3C_XML_SCHEMA_NS_URI, "xsd:schema");
        section.setAny(inlineSchema);

        Element schemaElement = client.getSchemaElement((Element)section.getAny(), section.getLocation());
        assertSame(inlineSchema, schemaElement);
        assertEquals(0, client.getDownloadSchemaInvocations());
    }

    /**
     * Verifies that the real {@code downloadSchema} implementation refuses an {@code ftp://}
     * location with an exception whose message mentions the scheme.
     */
    @Test
    public void testFtpProtocolAttemptedDownload() throws Exception {
        DownloadingAbstractSTSClient client = new DownloadingAbstractSTSClient(null);
        try {
            client.downloadSchemaWithDefaultResolver("ftp://example.org/schema.xsd");
            fail("Expected an exception for disallowed ftp:// scheme");
        } catch (Exception ex) {
            // fail() throws AssertionError, which is not an Exception, so it is not swallowed here.
            String message = ex.getMessage();
            // Without this guard, a message-less exception would surface as an NPE, not a clear failure.
            assertNotNull("Rejection of the ftp:// scheme should carry a message", message);
            assertTrue(message.contains("ftp"));
            assertTrue(message.contains("not permitted"));
        }
    }

    /**
     * Builds a metadata section whose dialect is W3C XML Schema.
     *
     * @param location schema location to set, or {@code null} to leave the location unset
     * @return the new section, with no inline content
     */
    private static MetadataSection schemaSection(String location) {
        MetadataSection section = new MetadataSection();
        section.setDialect(XMLConstants.W3C_XML_SCHEMA_NS_URI);
        if (location != null) {
            section.setLocation(location);
        }
        return section;
    }

    /**
     * Client whose {@code downloadSchema} is replaced by a stub.
     *
     * <p>The stub does no I/O. It records each call and returns a fresh, empty
     * {@code xsd:schema} element.
     */
    private static final class TestableAbstractSTSClient extends AbstractSTSClient {
        private int downloadSchemaInvocations;
        private String lastDownloadedLocation;

        TestableAbstractSTSClient(Bus bus) {
            super(bus);
        }

        /** Records the requested location and returns a stub schema element. */
        @Override
        protected Element downloadSchema(String schemaLocation) {
            downloadSchemaInvocations++;
            lastDownloadedLocation = schemaLocation;
            Document document = DOMUtils.createDocument();
            return document.createElementNS(XMLConstants.W3C_XML_SCHEMA_NS_URI, "xsd:schema");
        }

        /** @return how many times {@link #downloadSchema(String)} has been called */
        int getDownloadSchemaInvocations() {
            return downloadSchemaInvocations;
        }

        /** @return the location passed to the most recent download, or {@code null} if none */
        String getLastDownloadedLocation() {
            return lastDownloadedLocation;
        }
    }

    /** Client that exposes the inherited, non-stubbed {@code downloadSchema} implementation. */
    private static final class DownloadingAbstractSTSClient extends AbstractSTSClient {
        DownloadingAbstractSTSClient(Bus bus) {
            super(bus);
        }

        /** Delegates directly to {@link AbstractSTSClient#downloadSchema(String)}. */
        Element downloadSchemaWithDefaultResolver(String schemaLocation) throws Exception {
            return super.downloadSchema(schemaLocation);
        }
    }
}