DefaultSavedRequest.java

/*
 * Copyright 2004-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.web.savedrequest;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;

import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import jakarta.servlet.http.Cookie;
import jakarta.servlet.http.HttpServletRequest;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.jspecify.annotations.Nullable;

import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.web.util.UriComponentsBuilder;

/**
 * Represents central information from a {@code HttpServletRequest}.
 * <p>
 * This class is used by
 * {@link org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter}
 * and {@link org.springframework.security.web.savedrequest.SavedRequestAwareWrapper} to
 * reproduce the request after successful authentication. An instance of this class is
 * stored at the time of an authentication exception by
 * {@link org.springframework.security.web.access.ExceptionTranslationFilter}.
 * <p>
 * The conditional request headers {@code If-Modified-Since} and {@code If-None-Match}
 * are never saved (SEC-1412, SEC-1624). This applies both when an instance is captured
 * from a live request and when it is assembled through {@link Builder}.
 * <p>
 * <em>IMPLEMENTATION NOTE</em>: It is assumed that this object is accessed only from the
 * context of a single thread, so no synchronization around internal collection classes is
 * performed.
 * <p>
 * This class is based on code in Apache Tomcat.
 *
 * @author Craig McClanahan
 * @author Andrey Grebnev
 * @author Ben Alex
 * @author Luke Taylor
 */
public class DefaultSavedRequest implements SavedRequest {

	private static final long serialVersionUID = 620L;

	protected static final Log logger = LogFactory.getLog(DefaultSavedRequest.class);

	private static final String HEADER_IF_NONE_MATCH = "If-None-Match";

	private static final String HEADER_IF_MODIFIED_SINCE = "If-Modified-Since";

	private final ArrayList<SavedCookie> cookies = new ArrayList<>();

	private final ArrayList<Locale> locales = new ArrayList<>();

	// Header names are compared case-insensitively, mirroring HTTP semantics.
	private final Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

	// Parameters are kept sorted by name (natural String order).
	private final Map<String, String[]> parameters = new TreeMap<>();

	private final @Nullable String contextPath;

	private final String method;

	private final @Nullable String pathInfo;

	private final @Nullable String queryString;

	private final String requestURI;

	private final @Nullable String requestURL;

	private final String scheme;

	private final String serverName;

	private final @Nullable String servletPath;

	private final int serverPort;

	private final @Nullable String matchingRequestParameterName;

	/**
	 * Captures the state of the given request without a matching request parameter.
	 * @param request the request to save; must not be {@code null}
	 */
	public DefaultSavedRequest(HttpServletRequest request) {
		this(request, (String) null);
	}

	/**
	 * Captures the cookies, headers (except the conditional ones), locales, parameters
	 * and URL components of the given request.
	 * @param request the request to save; must not be {@code null}
	 * @param matchingRequestParameterName when non-{@code null}, the name of a value-less
	 * query parameter that {@link #getRedirectUrl()} adds to the saved query string,
	 * replacing any existing occurrence; when {@code null} the query string is used
	 * unchanged
	 */
	public DefaultSavedRequest(HttpServletRequest request, @Nullable String matchingRequestParameterName) {
		Assert.notNull(request, "Request required");
		// Cookies
		addCookies(request.getCookies());
		// Headers
		Enumeration<String> names = request.getHeaderNames();
		while (names.hasMoreElements()) {
			String name = names.nextElement();
			// Skip If-Modified-Since and If-None-Match header. SEC-1412, SEC-1624.
			if (isConditionalHeader(name)) {
				continue;
			}
			Enumeration<String> values = request.getHeaders(name);
			while (values.hasMoreElements()) {
				addHeader(name, values.nextElement());
			}
		}
		// Locales
		addLocales(request.getLocales());
		// Parameters
		addParameters(request.getParameterMap());
		// Primitives
		this.method = request.getMethod();
		this.pathInfo = request.getPathInfo();
		this.queryString = request.getQueryString();
		this.requestURI = request.getRequestURI();
		this.serverPort = request.getServerPort();
		this.requestURL = request.getRequestURL().toString();
		this.scheme = request.getScheme();
		this.serverName = request.getServerName();
		this.contextPath = request.getContextPath();
		this.servletPath = request.getServletPath();
		this.matchingRequestParameterName = matchingRequestParameterName;
	}

	/**
	 * Private constructor invoked through {@link Builder#build()}. Copies the scalar
	 * properties only; collections are added afterwards by the builder.
	 * @throws NullPointerException if the builder has no request URI, scheme or server
	 * name
	 */
	private DefaultSavedRequest(Builder builder) {
		this.contextPath = builder.contextPath;
		this.method = (builder.method != null) ? builder.method : "GET";
		this.pathInfo = builder.pathInfo;
		this.queryString = builder.queryString;
		this.requestURI = Objects.requireNonNull(builder.requestURI, "requestURI cannot be null");
		this.requestURL = builder.requestURL;
		this.scheme = Objects.requireNonNull(builder.scheme, "scheme cannot be null");
		this.serverName = Objects.requireNonNull(builder.serverName, "serverName cannot be null");
		this.servletPath = builder.servletPath;
		this.serverPort = builder.serverPort;
		this.matchingRequestParameterName = builder.matchingRequestParameterName;
	}

	/**
	 * Whether the given header must not be saved (SEC-1412, SEC-1624). The comparison is
	 * case-insensitive.
	 * @param name the header name
	 * @return {@code true} for {@code If-Modified-Since} and {@code If-None-Match}
	 */
	private static boolean isConditionalHeader(String name) {
		return HEADER_IF_MODIFIED_SINCE.equalsIgnoreCase(name) || HEADER_IF_NONE_MATCH.equalsIgnoreCase(name);
	}

	/**
	 * Saves each of the given cookies; a {@code null} array (no cookies) is ignored.
	 * @since 4.2
	 */
	private void addCookies(Cookie @Nullable [] cookies) {
		if (cookies == null) {
			return;
		}
		for (Cookie cookie : cookies) {
			addCookie(cookie);
		}
	}

	private void addCookie(Cookie cookie) {
		this.cookies.add(new SavedCookie(cookie));
	}

	/**
	 * Appends a value for the given header, creating the value list on first use.
	 */
	private void addHeader(String name, String value) {
		this.headers.computeIfAbsent(name, (key) -> new ArrayList<>()).add(value);
	}

	/**
	 * Saves the locales in enumeration order.
	 * @since 4.2
	 */
	private void addLocales(Enumeration<Locale> locales) {
		while (locales.hasMoreElements()) {
			addLocale(locales.nextElement());
		}
	}

	private void addLocale(Locale locale) {
		this.locales.add(locale);
	}

	/**
	 * Saves every parameter whose value is a {@code String[]}; other values are logged
	 * and skipped.
	 * @since 4.2
	 */
	private void addParameters(Map<String, String[]> parameters) {
		if (ObjectUtils.isEmpty(parameters)) {
			return;
		}
		for (Map.Entry<String, String[]> entry : parameters.entrySet()) {
			// Read the value as Object: a raw-typed map from a container (or from
			// deserialization) may hold non-array values despite the generic signature.
			Object paramValues = entry.getValue();
			if (paramValues instanceof String[]) {
				addParameter(entry.getKey(), (String[]) paramValues);
			}
			else {
				logger.warn("ServletRequest.getParameterMap() returned non-String array");
			}
		}
	}

	private void addParameter(String name, String[] values) {
		this.parameters.put(name, values);
	}

	/**
	 * @return the saved context path, possibly {@code null}
	 */
	public @Nullable String getContextPath() {
		return this.contextPath;
	}

	/**
	 * Returns a new list of {@link Cookie} instances built from the saved cookies. Adding
	 * or removing elements of the returned list does not affect this saved request.
	 */
	@Override
	public List<Cookie> getCookies() {
		List<Cookie> cookieList = new ArrayList<>(this.cookies.size());
		for (SavedCookie savedCookie : this.cookies) {
			cookieList.add(savedCookie.getCookie());
		}
		return cookieList;
	}

	/**
	 * Indicates the URL that the user agent used for this request. The URL is built from
	 * the saved scheme, server name, port, request URI and query string. If a matching
	 * request parameter name was configured, that parameter is added to the query string.
	 * @return the full URL of this request
	 */
	@Override
	public String getRedirectUrl() {
		String queryString = createQueryString(this.queryString, this.matchingRequestParameterName);
		return UrlUtils.buildFullRequestUrl(this.scheme, this.serverName, this.serverPort, this.requestURI,
				queryString);
	}

	/**
	 * @return the saved header names in case-insensitive order; this is a live view of
	 * the internal header map
	 */
	@Override
	public Collection<String> getHeaderNames() {
		return this.headers.keySet();
	}

	/**
	 * Looks up the values of a header, matching the name case-insensitively.
	 * @param name the header name
	 * @return the saved values, or an empty list if the header was not saved
	 */
	@Override
	public List<String> getHeaderValues(String name) {
		List<String> values = this.headers.get(name);
		return (values != null) ? values : Collections.emptyList();
	}

	/**
	 * @return the saved locales in the order they were added (the internal list)
	 */
	@Override
	public List<Locale> getLocales() {
		return this.locales;
	}

	@Override
	public String getMethod() {
		return this.method;
	}

	/**
	 * @return the saved parameters sorted by name (the internal map)
	 */
	@Override
	public Map<String, String[]> getParameterMap() {
		return this.parameters;
	}

	/**
	 * @return the saved parameter names, sorted
	 */
	public Collection<String> getParameterNames() {
		return this.parameters.keySet();
	}

	/**
	 * @param name the parameter name
	 * @return the saved values, or {@code null} if the parameter was not saved
	 */
	@Override
	public String @Nullable [] getParameterValues(String name) {
		return this.parameters.get(name);
	}

	public @Nullable String getPathInfo() {
		return this.pathInfo;
	}

	public @Nullable String getQueryString() {
		return this.queryString;
	}

	public @Nullable String getRequestURI() {
		return this.requestURI;
	}

	public @Nullable String getRequestURL() {
		return this.requestURL;
	}

	public @Nullable String getScheme() {
		return this.scheme;
	}

	public @Nullable String getServerName() {
		return this.serverName;
	}

	public int getServerPort() {
		return this.serverPort;
	}

	public @Nullable String getServletPath() {
		return this.servletPath;
	}

	@Override
	public String toString() {
		return "DefaultSavedRequest [" + getRedirectUrl() + "]";
	}

	/**
	 * Computes the query string used by {@link #getRedirectUrl()}.
	 * @param queryString the saved query string, possibly {@code null} or empty
	 * @param matchingRequestParameterName the parameter to add, or {@code null}
	 * @return {@code queryString} unchanged when no parameter name is given; the bare
	 * parameter name when there is no saved query; otherwise the saved query with any
	 * existing occurrence of the parameter replaced by a single value-less one
	 */
	private static @Nullable String createQueryString(@Nullable String queryString,
			@Nullable String matchingRequestParameterName) {
		if (matchingRequestParameterName == null) {
			return queryString;
		}
		if (queryString == null || queryString.isEmpty()) {
			return matchingRequestParameterName;
		}
		return UriComponentsBuilder.newInstance()
			.query(queryString)
			.replaceQueryParam(matchingRequestParameterName)
			.queryParam(matchingRequestParameterName)
			.build()
			.getQuery();
	}

	/**
	 * Assembles a {@link DefaultSavedRequest} from individual properties.
	 * <p>
	 * The request URI, scheme and server name are required; {@link #build()} throws
	 * {@link NullPointerException} if any of them is missing. The HTTP method defaults to
	 * {@code GET} and the server port to {@code 80}. Conditional request headers are
	 * dropped when building, as for requests captured directly.
	 *
	 * @since 4.2
	 */
	@JsonIgnoreProperties(ignoreUnknown = true)
	@com.fasterxml.jackson.databind.annotation.JsonPOJOBuilder(withPrefix = "set")
	@tools.jackson.databind.annotation.JsonPOJOBuilder(withPrefix = "set")
	public static class Builder {

		private @Nullable List<SavedCookie> cookies = null;

		private @Nullable List<Locale> locales = null;

		private Map<String, List<String>> headers = new TreeMap<>(String.CASE_INSENSITIVE_ORDER);

		private Map<String, String[]> parameters = new TreeMap<>();

		private @Nullable String contextPath;

		private @Nullable String method;

		private @Nullable String pathInfo;

		private @Nullable String queryString;

		private @Nullable String requestURI;

		private @Nullable String requestURL;

		private @Nullable String scheme;

		private @Nullable String serverName;

		private @Nullable String servletPath;

		private int serverPort = 80;

		private @Nullable String matchingRequestParameterName;

		/**
		 * Replaces the cookies to save.
		 */
		public Builder setCookies(List<SavedCookie> cookies) {
			this.cookies = cookies;
			return this;
		}

		/**
		 * Replaces the locales to save.
		 */
		public Builder setLocales(List<Locale> locales) {
			this.locales = locales;
			return this;
		}

		/**
		 * Adds all the given headers. Names are compared case-insensitively, so an entry
		 * whose name differs only in case replaces an earlier one.
		 */
		public Builder setHeaders(Map<String, List<String>> header) {
			this.headers.putAll(header);
			return this;
		}

		/**
		 * Replaces the parameters to save.
		 */
		public Builder setParameters(Map<String, String[]> parameters) {
			this.parameters = parameters;
			return this;
		}

		public Builder setContextPath(String contextPath) {
			this.contextPath = contextPath;
			return this;
		}

		/**
		 * Sets the HTTP method; {@code GET} is used if it is never set.
		 */
		public Builder setMethod(String method) {
			this.method = method;
			return this;
		}

		public Builder setPathInfo(String pathInfo) {
			this.pathInfo = pathInfo;
			return this;
		}

		public Builder setQueryString(@Nullable String queryString) {
			this.queryString = queryString;
			return this;
		}

		/**
		 * Sets the request URI (required by {@link #build()}).
		 */
		public Builder setRequestURI(@Nullable String requestURI) {
			this.requestURI = requestURI;
			return this;
		}

		public Builder setRequestURL(String requestURL) {
			this.requestURL = requestURL;
			return this;
		}

		/**
		 * Sets the scheme (required by {@link #build()}).
		 */
		public Builder setScheme(@Nullable String scheme) {
			this.scheme = scheme;
			return this;
		}

		/**
		 * Sets the server name (required by {@link #build()}).
		 */
		public Builder setServerName(@Nullable String serverName) {
			this.serverName = serverName;
			return this;
		}

		public Builder setServletPath(String servletPath) {
			this.servletPath = servletPath;
			return this;
		}

		/**
		 * Sets the server port; {@code 80} is used if it is never set.
		 */
		public Builder setServerPort(int serverPort) {
			this.serverPort = serverPort;
			return this;
		}

		public Builder setMatchingRequestParameterName(String matchingRequestParameterName) {
			this.matchingRequestParameterName = matchingRequestParameterName;
			return this;
		}

		/**
		 * Creates the saved request. This builder's own state is not modified.
		 * @return a new {@link DefaultSavedRequest}
		 * @throws NullPointerException if the request URI, scheme or server name is missing
		 */
		public DefaultSavedRequest build() {
			DefaultSavedRequest savedRequest = new DefaultSavedRequest(this);
			if (!ObjectUtils.isEmpty(this.cookies)) {
				for (SavedCookie cookie : this.cookies) {
					savedRequest.addCookie(cookie.getCookie());
				}
			}
			if (!ObjectUtils.isEmpty(this.locales)) {
				savedRequest.locales.addAll(this.locales);
			}
			savedRequest.addParameters(this.parameters);
			for (Map.Entry<String, List<String>> entry : this.headers.entrySet()) {
				String headerName = entry.getKey();
				// Same exclusion as the request constructor (SEC-1412, SEC-1624); filtering
				// here avoids removing entries from this builder's map.
				if (isConditionalHeader(headerName)) {
					continue;
				}
				for (String headerValue : entry.getValue()) {
					savedRequest.addHeader(headerName, headerValue);
				}
			}
			return savedRequest;
		}

	}

}