// HttpSessionLogoutRequestRepository.java

/*
 * Copyright 2004-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.saml2.provider.service.web.authentication.logout;

import java.security.MessageDigest;

import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;

import org.springframework.security.crypto.codec.Utf8;
import org.springframework.security.saml2.core.Saml2ParameterNames;
import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
import org.springframework.util.Assert;

/**
 * An implementation of an {@link Saml2LogoutRequestRepository} that stores
 * {@link Saml2LogoutRequest} in the {@code HttpSession}.
 *
 * @author Josh Cummings
 * @since 5.6
 * @see Saml2LogoutRequestRepository
 * @see Saml2LogoutRequest
 */
public final class HttpSessionLogoutRequestRepository implements Saml2LogoutRequestRepository {

	/**
	 * Name of the {@code HttpSession} attribute under which the pending
	 * {@link Saml2LogoutRequest} is stored.
	 */
	private static final String DEFAULT_LOGOUT_REQUEST_ATTR_NAME = HttpSessionLogoutRequestRepository.class.getName()
			+ ".LOGOUT_REQUEST";

	/**
	 * {@inheritDoc}
	 *
	 * <p>
	 * Never creates a session. The stored {@link Saml2LogoutRequest} is returned only
	 * when the request carries a {@link Saml2ParameterNames#RELAY_STATE} parameter equal
	 * to the stored request's relay state; otherwise {@code null} is returned.
	 */
	@Override
	public Saml2LogoutRequest loadLogoutRequest(HttpServletRequest request) {
		Assert.notNull(request, "request cannot be null");
		HttpSession session = request.getSession(false);
		if (session == null) {
			return null;
		}
		Saml2LogoutRequest logoutRequest = (Saml2LogoutRequest) session.getAttribute(DEFAULT_LOGOUT_REQUEST_ATTR_NAME);
		if (stateParameterEquals(request, logoutRequest)) {
			return logoutRequest;
		}
		return null;
	}

	/**
	 * {@inheritDoc}
	 *
	 * <p>
	 * Passing a {@code null} {@code logoutRequest} clears any stored request; in that
	 * case no session is created if none exists. A non-{@code null}
	 * {@code logoutRequest} must have a non-empty relay state and is stored in the
	 * session, creating the session if necessary.
	 */
	@Override
	public void saveLogoutRequest(Saml2LogoutRequest logoutRequest, HttpServletRequest request,
			HttpServletResponse response) {
		Assert.notNull(request, "request cannot be null");
		Assert.notNull(response, "response cannot be null");
		if (logoutRequest == null) {
			// Clearing must not create a session: there would be nothing to remove from a
			// brand-new one, and session creation can fail once the response is committed.
			HttpSession session = request.getSession(false);
			if (session != null) {
				session.removeAttribute(DEFAULT_LOGOUT_REQUEST_ATTR_NAME);
			}
			return;
		}
		String state = logoutRequest.getRelayState();
		Assert.hasText(state, "logoutRequest.state cannot be empty");
		request.getSession().setAttribute(DEFAULT_LOGOUT_REQUEST_ATTR_NAME, logoutRequest);
	}

	/**
	 * {@inheritDoc}
	 *
	 * <p>
	 * The stored request is removed and returned only when
	 * {@link #loadLogoutRequest(HttpServletRequest)} finds one matching this request;
	 * otherwise the session is left untouched and {@code null} is returned.
	 */
	@Override
	public Saml2LogoutRequest removeLogoutRequest(HttpServletRequest request, HttpServletResponse response) {
		Assert.notNull(request, "request cannot be null");
		Assert.notNull(response, "response cannot be null");
		Saml2LogoutRequest logoutRequest = loadLogoutRequest(request);
		if (logoutRequest == null) {
			return null;
		}
		// A non-null result means loadLogoutRequest found a session; look it up again
		// without creating one in case it has been invalidated in the meantime.
		HttpSession session = request.getSession(false);
		if (session != null) {
			session.removeAttribute(DEFAULT_LOGOUT_REQUEST_ATTR_NAME);
		}
		return logoutRequest;
	}

	/**
	 * Reads the {@link Saml2ParameterNames#RELAY_STATE} parameter from the request.
	 * @param request the current request
	 * @return the parameter value, or {@code null} if the parameter is absent
	 */
	private String getStateParameter(HttpServletRequest request) {
		return request.getParameter(Saml2ParameterNames.RELAY_STATE);
	}

	/**
	 * Checks whether the request's relay state parameter matches the relay state of the
	 * given logout request.
	 * @param request the current request
	 * @param logoutRequest the stored logout request, possibly {@code null}
	 * @return {@code true} only if both the parameter and the logout request are present
	 * and their relay states are byte-for-byte equal in UTF-8
	 */
	private boolean stateParameterEquals(HttpServletRequest request, Saml2LogoutRequest logoutRequest) {
		String stateParameter = getStateParameter(request);
		if (stateParameter == null || logoutRequest == null) {
			return false;
		}
		String relayState = logoutRequest.getRelayState();
		// MessageDigest.isEqual is used instead of String.equals so the comparison of the
		// caller-supplied value does not short-circuit on the first differing byte.
		return MessageDigest.isEqual(Utf8.encode(stateParameter), Utf8.encode(relayState));
	}

}