WebXmlServletMappingExtractor.java

/*******************************************************************************
 * Copyright (c) 2025 Eclipse RDF4J contributors.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Distribution License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/org/documents/edl-v10.php.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 *******************************************************************************/
// Some portions generated by Codex
package org.eclipse.rdf4j.tools.serverboot;

import java.io.InputStream;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import javax.xml.XMLConstants;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;

import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.w3c.dom.Document;
import org.w3c.dom.Node;
import org.w3c.dom.NodeList;

/**
 * Reads {@code <servlet-mapping>} elements from an XML document on the classpath and expands the
 * {@code <url-pattern>} values declared for a single servlet into a list of URL patterns, optionally prefixed with a
 * context path.
 */
final class WebXmlServletMappingExtractor {

	private static final String WILDCARD_SUFFIX = "/*";

	private WebXmlServletMappingExtractor() {
	}

	/**
	 * Extracts and expands the URL patterns mapped to {@code servletName}.
	 *
	 * @param resourceLocation    classpath location of the XML document to read
	 * @param servletName         value of the {@code <servlet-name>} child a mapping must carry to be considered
	 * @param contextPrefix       prefix prepended to every path-style pattern; ignored when {@code null} or blank.
	 *                            Trailing slashes are dropped so the result never contains {@code //}.
	 * @param includeBasePatterns whether the patterns exactly as declared are part of the result as well
	 * @return the expanded patterns in encounter order, without duplicates
	 * @throws IllegalStateException if the resource is missing, cannot be parsed, or declares no URL pattern for
	 *                               {@code servletName}
	 */
	static List<String> extractMappings(String resourceLocation, String servletName, String contextPrefix,
			boolean includeBasePatterns) {
		List<String> basePatterns = readServletUrlPatterns(resourceLocation, servletName);
		return expandUrlPatterns(basePatterns, contextPrefix, includeBasePatterns);
	}

	/**
	 * Loads the document and returns the trimmed, non-blank {@code <url-pattern>} values of every
	 * {@code <servlet-mapping>} whose {@code <servlet-name>} equals {@code servletName}.
	 */
	private static List<String> readServletUrlPatterns(String resourceLocation, String servletName) {
		Resource resource = new ClassPathResource(resourceLocation);
		if (!resource.exists()) {
			throw new IllegalStateException("Missing resource " + resourceLocation);
		}
		List<String> patterns;
		try (InputStream inputStream = resource.getInputStream()) {
			Document document = newHardenedFactory().newDocumentBuilder().parse(inputStream);
			patterns = collectUrlPatterns(document, servletName);
		} catch (Exception e) {
			throw new IllegalStateException(
					"Failed to parse servlet mappings for " + servletName + " from " + resourceLocation, e);
		}
		// Checked outside the try block so this error is reported as-is instead of being wrapped as a parse failure.
		if (patterns.isEmpty()) {
			throw new IllegalStateException(
					"No servlet-mapping entries found for " + servletName + " in " + resourceLocation);
		}
		return patterns;
	}

	/**
	 * Creates a non-namespace-aware factory that never fetches external entities or external DTDs. DOCTYPE
	 * declarations themselves are still accepted, so documents that carry one continue to parse.
	 */
	private static DocumentBuilderFactory newHardenedFactory() throws ParserConfigurationException {
		DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
		factory.setNamespaceAware(false);
		factory.setFeature(XMLConstants.FEATURE_SECURE_PROCESSING, true);
		factory.setFeature("http://xml.org/sax/features/external-general-entities", false);
		factory.setFeature("http://xml.org/sax/features/external-parameter-entities", false);
		factory.setFeature("http://apache.org/xml/features/nonvalidating/load-external-dtd", false);
		return factory;
	}

	/**
	 * Gathers the {@code <url-pattern>} children of all {@code <servlet-mapping>} elements that name
	 * {@code servletName}.
	 */
	private static List<String> collectUrlPatterns(Document document, String servletName) {
		List<String> patterns = new ArrayList<>();
		NodeList mappings = document.getElementsByTagName("servlet-mapping");
		for (int i = 0; i < mappings.getLength(); i++) {
			Node mapping = mappings.item(i);
			String name = childText(mapping, "servlet-name");
			// Null-safe on both sides: a mapping without a name, or a null servletName, simply never matches.
			if (name == null || !name.equals(servletName)) {
				continue;
			}
			NodeList children = mapping.getChildNodes();
			for (int j = 0; j < children.getLength(); j++) {
				Node child = children.item(j);
				if (!"url-pattern".equals(child.getNodeName())) {
					continue;
				}
				String pattern = child.getTextContent();
				if (pattern != null && !pattern.isBlank()) {
					patterns.add(pattern.trim());
				}
			}
		}
		return patterns;
	}

	/**
	 * Expands the declared patterns. Every pattern that is added is accompanied by its form without a trailing
	 * {@code /*}, if it has one.
	 * <ul>
	 * <li>With {@code includeBasePatterns}, each pattern is added exactly as declared.</li>
	 * <li>Patterns starting with {@code *} are never prefixed or normalized.</li>
	 * <li>All other patterns get a leading {@code /} if they lack one. They are then added with the context prefix
	 * when one is given, or on their own when there is no prefix and base patterns are not included.</li>
	 * </ul>
	 */
	private static List<String> expandUrlPatterns(List<String> basePatterns, String contextPrefix,
			boolean includeBasePatterns) {
		boolean hasPrefix = contextPrefix != null && !contextPrefix.isBlank();
		// A prefix such as "/ctx/" would otherwise yield "/ctx//path" because patterns always start with "/".
		String prefix = hasPrefix ? stripTrailingSlashes(contextPrefix) : "";

		Set<String> expanded = new LinkedHashSet<>();
		for (String pattern : basePatterns) {
			if (pattern == null || pattern.isEmpty()) {
				continue;
			}
			if (includeBasePatterns) {
				addWithWildcardBase(expanded, pattern);
			}
			if (pattern.startsWith("*")) {
				continue;
			}
			String normalizedPattern = pattern.startsWith("/") ? pattern : "/" + pattern;
			if (hasPrefix) {
				addWithWildcardBase(expanded, prefix + normalizedPattern);
			} else if (!includeBasePatterns) {
				addWithWildcardBase(expanded, normalizedPattern);
			}
		}
		return new ArrayList<>(expanded);
	}

	/**
	 * Adds {@code pattern} to {@code target} and, when it ends with {@code /*}, also the pattern without that suffix.
	 */
	private static void addWithWildcardBase(Set<String> target, String pattern) {
		target.add(pattern);
		if (pattern.endsWith(WILDCARD_SUFFIX)) {
			target.add(pattern.substring(0, pattern.length() - WILDCARD_SUFFIX.length()));
		}
	}

	/**
	 * Returns {@code value} without any trailing {@code /} characters.
	 */
	private static String stripTrailingSlashes(String value) {
		int end = value.length();
		while (end > 0 && value.charAt(end - 1) == '/') {
			end--;
		}
		return value.substring(0, end);
	}

	/**
	 * Returns the trimmed text content of the first direct child of {@code parent} whose node name equals
	 * {@code childName}, or {@code null} when there is no such child or it has no text content.
	 */
	private static String childText(Node parent, String childName) {
		NodeList children = parent.getChildNodes();
		for (int i = 0; i < children.getLength(); i++) {
			Node child = children.item(i);
			if (childName.equals(child.getNodeName())) {
				String text = child.getTextContent();
				return text != null ? text.trim() : null;
			}
		}
		return null;
	}
}