// Webauthn4JRelyingPartyOperations.java

/*
 * Copyright 2004-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.security.web.webauthn.management;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.webauthn4j.WebAuthnManager;
import com.webauthn4j.converter.util.CborConverter;
import com.webauthn4j.converter.util.ObjectConverter;
import com.webauthn4j.credential.CredentialRecordImpl;
import com.webauthn4j.data.AuthenticationData;
import com.webauthn4j.data.AuthenticationParameters;
import com.webauthn4j.data.RegistrationData;
import com.webauthn4j.data.RegistrationParameters;
import com.webauthn4j.data.RegistrationRequest;
import com.webauthn4j.data.attestation.AttestationObject;
import com.webauthn4j.data.attestation.authenticator.AttestedCredentialData;
import com.webauthn4j.data.attestation.authenticator.AuthenticatorData;
import com.webauthn4j.data.attestation.authenticator.COSEKey;
import com.webauthn4j.data.client.Origin;
import com.webauthn4j.data.client.challenge.Challenge;
import com.webauthn4j.data.client.challenge.DefaultChallenge;
import com.webauthn4j.data.extension.authenticator.AuthenticationExtensionAuthenticatorOutput;
import com.webauthn4j.data.extension.authenticator.RegistrationExtensionAuthenticatorOutput;
import com.webauthn4j.server.ServerProperty;
import org.jspecify.annotations.NullUnmarked;
import org.jspecify.annotations.Nullable;

import org.springframework.security.authentication.AuthenticationTrustResolver;
import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
import org.springframework.security.core.Authentication;
import org.springframework.security.web.webauthn.api.AttestationConveyancePreference;
import org.springframework.security.web.webauthn.api.AuthenticatorAssertionResponse;
import org.springframework.security.web.webauthn.api.AuthenticatorAttestationResponse;
import org.springframework.security.web.webauthn.api.AuthenticatorSelectionCriteria;
import org.springframework.security.web.webauthn.api.AuthenticatorTransport;
import org.springframework.security.web.webauthn.api.Bytes;
import org.springframework.security.web.webauthn.api.CredentialRecord;
import org.springframework.security.web.webauthn.api.ImmutableAuthenticationExtensionsClientInput;
import org.springframework.security.web.webauthn.api.ImmutableAuthenticationExtensionsClientInputs;
import org.springframework.security.web.webauthn.api.ImmutableCredentialRecord;
import org.springframework.security.web.webauthn.api.ImmutablePublicKeyCose;
import org.springframework.security.web.webauthn.api.ImmutablePublicKeyCredentialUserEntity;
import org.springframework.security.web.webauthn.api.PublicKeyCredential;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialCreationOptions;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialCreationOptions.PublicKeyCredentialCreationOptionsBuilder;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialDescriptor;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialParameters;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialRequestOptions;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialRequestOptions.PublicKeyCredentialRequestOptionsBuilder;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialRpEntity;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialType;
import org.springframework.security.web.webauthn.api.PublicKeyCredentialUserEntity;
import org.springframework.security.web.webauthn.api.ResidentKeyRequirement;
import org.springframework.security.web.webauthn.api.UserVerificationRequirement;
import org.springframework.util.Assert;

/**
 * A <a href="https://webauthn4j.github.io/webauthn4j/en/">WebAuthn4j</a> implementation
 * of {@link WebAuthnRelyingPartyOperations}.
 *
 * @author Rob Winch
 * @since 6.4
 */
public class Webauthn4JRelyingPartyOperations implements WebAuthnRelyingPartyOperations {

	// Repository used to look up and persist WebAuthn user entities.
	private final PublicKeyCredentialUserEntityRepository userEntities;

	// Repository used to look up and persist credential records.
	private final UserCredentialRepository userCredentials;

	// Origin strings that are converted to webauthn4j Origins for each verification.
	private final Set<String> allowedOrigins;

	// Relying party entity placed in creation options; its id is the rpId of request
	// options.
	private final PublicKeyCredentialRpEntity rp;

	// Supplies the CborConverter used to write COSE keys and to read stored attestation
	// objects.
	private ObjectConverter objectConverter = new ObjectConverter();

	// Decides whether an Authentication is treated as authenticated.
	private final AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();

	// Verifies registration and authentication requests; replaceable through
	// setWebAuthnManager.
	private WebAuthnManager webAuthnManager = WebAuthnManager.createNonStrictWebAuthnManager();

	// Hook applied to the creation options builder after the defaults are set; no-op by
	// default.
	private Consumer<PublicKeyCredentialCreationOptionsBuilder> customizeCreationOptions = (options) -> {
	};

	// Hook applied to the request options builder after the defaults are set; no-op by
	// default.
	private Consumer<PublicKeyCredentialRequestOptionsBuilder> customizeRequestOptions = (options) -> {
	};

	/**
	 * Creates an instance backed by the given repositories and relying party settings.
	 * Every argument is required.
	 * @param userEntities the {@link PublicKeyCredentialUserEntityRepository} used to
	 * find and store user entities.
	 * @param userCredentials the {@link UserCredentialRepository} used to find and store
	 * credential records.
	 * @param rpEntity the {@link PublicKeyCredentialRpEntity} describing this relying
	 * party.
	 * @param allowedOrigins the origins accepted when verifying client data.
	 * @throws IllegalArgumentException if any argument is null
	 */
	public Webauthn4JRelyingPartyOperations(PublicKeyCredentialUserEntityRepository userEntities,
			UserCredentialRepository userCredentials, PublicKeyCredentialRpEntity rpEntity,
			Set<String> allowedOrigins) {
		// Each argument is validated right before it is stored; the order of the checks
		// is the declaration order of the parameters.
		Assert.notNull(userEntities, "userEntities cannot be null");
		this.userEntities = userEntities;

		Assert.notNull(userCredentials, "userCredentials cannot be null");
		this.userCredentials = userCredentials;

		Assert.notNull(rpEntity, "rpEntity cannot be null");
		this.rp = rpEntity;

		Assert.notNull(allowedOrigins, "allowedOrigins cannot be null");
		this.allowedOrigins = allowedOrigins;
	}

	/**
	 * Sets the {@link WebAuthnManager} used to verify registration and authentication
	 * requests. The default is {@link WebAuthnManager#createNonStrictWebAuthnManager()}.
	 * @param webAuthnManager the {@link WebAuthnManager} to use. Cannot be null.
	 * @throws IllegalArgumentException if {@code webAuthnManager} is null
	 */
	public void setWebAuthnManager(WebAuthnManager webAuthnManager) {
		Assert.notNull(webAuthnManager, "webAuthnManager cannot be null");
		this.webAuthnManager = webAuthnManager;
	}

	/**
	 * Sets the {@link ObjectConverter} to use. Its {@link CborConverter} serializes the
	 * COSE public key during registration and reads the stored attestation object during
	 * authentication.
	 * @param objectConverter the {@link ObjectConverter} to use. Cannot be null.
	 * @throws IllegalArgumentException if {@code objectConverter} is null
	 */
	void setObjectConverter(ObjectConverter objectConverter) {
		Assert.notNull(objectConverter, "objectConverter cannot be null");
		this.objectConverter = objectConverter;
	}

	/**
	 * Sets a {@link Consumer} used to customize the
	 * {@link PublicKeyCredentialCreationOptionsBuilder} for
	 * {@link #createPublicKeyCredentialCreationOptions(PublicKeyCredentialCreationOptionsRequest)}.
	 * The default values are always populated first, and the {@link Consumer} is applied
	 * afterwards so that it can override them.
	 * @param customizeCreationOptions the {@link Consumer} to customize the
	 * {@link PublicKeyCredentialCreationOptionsBuilder}. Cannot be null.
	 * @throws IllegalArgumentException if {@code customizeCreationOptions} is null
	 */
	public void setCustomizeCreationOptions(
			Consumer<PublicKeyCredentialCreationOptionsBuilder> customizeCreationOptions) {
		Assert.notNull(customizeCreationOptions, "customizeCreationOptions must not be null");
		this.customizeCreationOptions = customizeCreationOptions;
	}

	/**
	 * Sets a {@link Consumer} used to customize the
	 * {@link PublicKeyCredentialRequestOptionsBuilder} for
	 * {@link #createCredentialRequestOptions(PublicKeyCredentialRequestOptionsRequest)}.
	 * The default values are always populated first, and the {@link Consumer} is applied
	 * afterwards so that it can override them.
	 * @param customizeRequestOptions the {@link Consumer} to customize the
	 * {@link PublicKeyCredentialRequestOptionsBuilder}. Cannot be null.
	 * @throws IllegalArgumentException if {@code customizeRequestOptions} is null
	 */
	public void setCustomizeRequestOptions(Consumer<PublicKeyCredentialRequestOptionsBuilder> customizeRequestOptions) {
		Assert.notNull(customizeRequestOptions, "customizeRequestOptions cannot be null");
		this.customizeRequestOptions = customizeRequestOptions;
	}

	/**
	 * Creates the {@link PublicKeyCredentialCreationOptions} used to register a new
	 * credential for the authenticated user of the request. The user entity is looked up
	 * by the authentication name and created if it does not exist yet; credentials that
	 * are already stored for that user are listed as excluded credentials. The configured
	 * creation options customizer is applied last.
	 * @param request the request holding the current {@link Authentication}
	 * @return the creation options, including a fresh random challenge
	 * @throws IllegalArgumentException if the request is null or its authentication is
	 * not authenticated
	 */
	@Override
	public PublicKeyCredentialCreationOptions createPublicKeyCredentialCreationOptions(
			PublicKeyCredentialCreationOptionsRequest request) {
		Assert.notNull(request, "request cannot be null");
		Authentication authentication = request.getAuthentication();
		if (!this.trustResolver.isAuthenticated(authentication)) {
			throw new IllegalArgumentException("Authentication must be authenticated");
		}

		PublicKeyCredentialUserEntity userEntity = findUserEntityOrCreateAndSave(authentication.getName());
		List<CredentialRecord> credentialRecords = this.userCredentials.findByUserId(userEntity.getId());

		return PublicKeyCredentialCreationOptions.builder()
			.attestation(AttestationConveyancePreference.NONE)
			.pubKeyCredParams(PublicKeyCredentialParameters.EdDSA, PublicKeyCredentialParameters.ES256,
					PublicKeyCredentialParameters.RS256)
			.authenticatorSelection(AuthenticatorSelectionCriteria.builder()
				.userVerification(UserVerificationRequirement.PREFERRED)
				.residentKey(ResidentKeyRequirement.REQUIRED)
				.build())
			.challenge(Bytes.random())
			.extensions(new ImmutableAuthenticationExtensionsClientInputs(
					ImmutableAuthenticationExtensionsClientInput.credProps))
			.timeout(Duration.ofMinutes(5))
			.user(userEntity)
			.rp(this.rp)
			// Already registered credentials must not be registered a second time.
			.excludeCredentials(credentialDescriptors(credentialRecords))
			// Applied after the defaults so that the customizer can override them.
			.customize(this.customizeCreationOptions)
			.build();
	}

	/**
	 * Creates a {@link PublicKeyCredentialDescriptor} for each of the given records.
	 * @param credentialRecords the records to describe
	 * @return a new mutable list with one descriptor per record, in iteration order, each
	 * carrying the record's credential id and transports
	 */
	private static List<PublicKeyCredentialDescriptor> credentialDescriptors(List<CredentialRecord> credentialRecords) {
		List<PublicKeyCredentialDescriptor> result = new ArrayList<>(credentialRecords.size());
		for (CredentialRecord credentialRecord : credentialRecords) {
			// The credential id is already a Bytes value, so it is passed through as is
			// instead of being encoded to base64url and decoded again.
			PublicKeyCredentialDescriptor credentialDescriptor = PublicKeyCredentialDescriptor.builder()
				.id(credentialRecord.getCredentialId())
				.transports(credentialRecord.getTransports())
				.build();
			result.add(credentialDescriptor);
		}
		return result;
	}

	/**
	 * Returns the user entity stored for the username, creating and saving one when none
	 * exists.
	 * @param username the username to look up; also used as the name and display name of
	 * a newly created entity
	 * @return the existing entity, or the newly saved one with a random id
	 */
	private PublicKeyCredentialUserEntity findUserEntityOrCreateAndSave(String username) {
		PublicKeyCredentialUserEntity existing = this.userEntities.findByUsername(username);
		if (existing == null) {
			PublicKeyCredentialUserEntity created = ImmutablePublicKeyCredentialUserEntity.builder()
				.displayName(username)
				.id(Bytes.random())
				.name(username)
				.build();
			this.userEntities.save(created);
			return created;
		}
		return existing;
	}

	/**
	 * Verifies a registration response with the {@link WebAuthnManager} and stores the
	 * resulting credential.
	 * <p>
	 * The credential id must not be registered yet. The attestation is verified against
	 * the challenge, relying party id and public key parameters of the creation options
	 * and against the allowed origins. User verification is only required when the
	 * creation options explicitly ask for
	 * {@link UserVerificationRequirement#REQUIRED}; creation options without an
	 * authenticator selection are treated as not requiring it.
	 * @param rpRegistrationRequest the creation options together with the credential
	 * returned by the client
	 * @return the saved {@link CredentialRecord}
	 * @throws IllegalArgumentException if the request is null, the credential id is
	 * already registered, or the verified data has no attestation object or attested
	 * credential data
	 */
	@Override
	public CredentialRecord registerCredential(RelyingPartyRegistrationRequest rpRegistrationRequest) {
		Assert.notNull(rpRegistrationRequest, "rpRegistrationRequest cannot be null");
		RelyingPartyPublicKey publicKey = rpRegistrationRequest.getPublicKey();
		PublicKeyCredential<AuthenticatorAttestationResponse> credential = publicKey.getCredential();
		Bytes credentialId = credential.getRawId();
		CredentialRecord existingCredential = this.userCredentials.findByCredentialId(credentialId);
		if (existingCredential != null) {
			throw new IllegalArgumentException("Credential with id " + credentialId + " already exists");
		}
		PublicKeyCredentialCreationOptions creationOptions = rpRegistrationRequest.getCreationOptions();
		String rpId = creationOptions.getRp().getId();
		AuthenticatorAttestationResponse response = credential.getResponse();
		// Server properties
		Set<Origin> origins = toOrigins();
		byte[] base64Challenge = creationOptions.getChallenge().getBytes();
		byte[] attestationObject = response.getAttestationObject().getBytes();
		byte[] clientDataJSON = response.getClientDataJSON().getBytes();
		Challenge challenge = new DefaultChallenge(base64Challenge);
		ServerProperty serverProperty = new ServerProperty(origins, rpId, challenge);
		// The authenticator selection is optional; when it is absent user verification
		// is not required instead of failing with a NullPointerException.
		AuthenticatorSelectionCriteria authenticatorSelection = creationOptions.getAuthenticatorSelection();
		boolean userVerificationRequired = authenticatorSelection != null
				&& authenticatorSelection.getUserVerification() == UserVerificationRequirement.REQUIRED;
		// requireUserPresence The constant Boolean value true
		// https://www.w3.org/TR/webauthn-3/#sctn-op-make-cred
		boolean userPresenceRequired = true;
		List<com.webauthn4j.data.PublicKeyCredentialParameters> pubKeyCredParams = convertCredentialParamsToWebauthn4j(
				creationOptions.getPubKeyCredParams());
		Set<String> transports = convertTransportsToString(response);
		RegistrationRequest webauthn4jRegistrationRequest = new RegistrationRequest(attestationObject, clientDataJSON,
				transports);
		RegistrationParameters registrationParameters = new RegistrationParameters(serverProperty, pubKeyCredParams,
				userVerificationRequired, userPresenceRequired);
		RegistrationData wa4jRegistrationData = this.webAuthnManager.verify(webauthn4jRegistrationRequest,
				registrationParameters);
		AttestationObject wa4jAttestationObject = wa4jRegistrationData.getAttestationObject();
		Assert.notNull(wa4jAttestationObject, "attestationObject cannot be null");
		AuthenticatorData<RegistrationExtensionAuthenticatorOutput> wa4jAuthData = wa4jAttestationObject
			.getAuthenticatorData();

		CborConverter cborConverter = this.objectConverter.getCborConverter();
		AttestedCredentialData wa4jCredData = wa4jAuthData.getAttestedCredentialData();
		Assert.notNull(wa4jCredData, "attestedCredentialData cannot be null");
		// The public key is stored in its CBOR encoded COSE form.
		COSEKey coseKey = wa4jCredData.getCOSEKey();
		byte[] rawCoseKey = cborConverter.writeValueAsBytes(coseKey);
		ImmutableCredentialRecord userCredential = ImmutableCredentialRecord.builder()
			.userEntityUserId(creationOptions.getUser().getId())
			.credentialType(credential.getType())
			.credentialId(credentialId)
			.publicKey(new ImmutablePublicKeyCose(rawCoseKey))
			.signatureCount(wa4jAuthData.getSignCount())
			.uvInitialized(wa4jAuthData.isFlagUV())
			.transports(convertTransports(wa4jRegistrationData.getTransports()))
			.backupEligible(wa4jAuthData.isFlagBE())
			.backupState(wa4jAuthData.isFlagBS())
			.label(publicKey.getLabel())
			.attestationClientDataJSON(response.getClientDataJSON())
			.attestationObject(response.getAttestationObject())
			.build();
		this.userCredentials.save(userCredential);
		return userCredential;
	}

	/**
	 * Collects the string values of the transports reported in the attestation response.
	 * @param response the attestation response
	 * @return a new mutable set of transport values, or {@code null} when the response
	 * has no transports
	 */
	private static @Nullable Set<String> convertTransportsToString(AuthenticatorAttestationResponse response) {
		if (response.getTransports() == null) {
			return null;
		}
		return response.getTransports()
			.stream()
			.map(AuthenticatorTransport::getValue)
			.collect(Collectors.toCollection(HashSet::new));
	}

	/**
	 * Converts each of the given parameters to its webauthn4j counterpart.
	 * @param parameters the parameters to convert
	 * @return an unmodifiable list of the converted parameters, in the same order
	 */
	private List<com.webauthn4j.data.PublicKeyCredentialParameters> convertCredentialParamsToWebauthn4j(
			List<PublicKeyCredentialParameters> parameters) {
		List<com.webauthn4j.data.PublicKeyCredentialParameters> converted = new ArrayList<>(parameters.size());
		for (PublicKeyCredentialParameters parameter : parameters) {
			converted.add(convertParamToWebauthn4j(parameter));
		}
		return Collections.unmodifiableList(converted);
	}

	/**
	 * Converts a single parameter to the webauthn4j type, keeping its algorithm value.
	 * @param parameter the parameter to convert
	 * @return the webauthn4j parameter of type public key
	 * @throws IllegalArgumentException if the parameter type is not
	 * {@link PublicKeyCredentialType#PUBLIC_KEY}
	 */
	private com.webauthn4j.data.PublicKeyCredentialParameters convertParamToWebauthn4j(
			PublicKeyCredentialParameters parameter) {
		if (parameter.getType() == PublicKeyCredentialType.PUBLIC_KEY) {
			long algValue = parameter.getAlg().getValue();
			return new com.webauthn4j.data.PublicKeyCredentialParameters(
					com.webauthn4j.data.PublicKeyCredentialType.PUBLIC_KEY,
					com.webauthn4j.data.attestation.statement.COSEAlgorithmIdentifier.create(algValue));
		}
		// Only public key credentials have a webauthn4j counterpart.
		throw new IllegalArgumentException(
				"Cannot convert unknown credential type " + parameter.getType() + " to webauthn4j");
	}

	/**
	 * Converts the configured allowed origin strings to webauthn4j {@link Origin}s.
	 * @return a new set with one {@link Origin} per allowed origin
	 */
	private Set<Origin> toOrigins() {
		Set<Origin> origins = new HashSet<>();
		for (String allowedOrigin : this.allowedOrigins) {
			origins.add(new Origin(allowedOrigin));
		}
		return origins;
	}

	/**
	 * Converts webauthn4j transports to {@link AuthenticatorTransport}s by their string
	 * value.
	 * @param transports the webauthn4j transports, may be {@code null}
	 * @return an unmodifiable set of the converted transports; empty when
	 * {@code transports} is {@code null}
	 */
	private static Set<AuthenticatorTransport> convertTransports(
			@Nullable Set<com.webauthn4j.data.AuthenticatorTransport> transports) {
		if (transports != null) {
			return transports.stream()
				.map(com.webauthn4j.data.AuthenticatorTransport::getValue)
				.map(AuthenticatorTransport::valueOf)
				.collect(Collectors.toUnmodifiableSet());
		}
		return Collections.emptySet();
	}

	/**
	 * Creates the {@link PublicKeyCredentialRequestOptions} used to start an
	 * authentication. When the request carries an authenticated {@link Authentication}
	 * with a known user entity, the credentials stored for that user are listed as
	 * allowed credentials; otherwise the list is empty. The configured request options
	 * customizer is applied last.
	 * @param request the request holding the optional current {@link Authentication}
	 * @return the request options, including a fresh random challenge
	 * @throws IllegalArgumentException if the request is null
	 */
	@Override
	public PublicKeyCredentialRequestOptions createCredentialRequestOptions(
			PublicKeyCredentialRequestOptionsRequest request) {
		// Reject a missing request the same way the creation options method does.
		Assert.notNull(request, "request cannot be null");
		Authentication authentication = request.getAuthentication();
		List<CredentialRecord> credentialRecords = findCredentialRecords(authentication);
		return PublicKeyCredentialRequestOptions.builder()
			.allowCredentials(credentialDescriptors(credentialRecords))
			.challenge(Bytes.random())
			.rpId(this.rp.getId())
			.timeout(Duration.ofMinutes(5))
			.userVerification(UserVerificationRequirement.PREFERRED)
			// Applied after the defaults so that the customizer can override them.
			.customize(this.customizeRequestOptions)
			.build();
	}

	/**
	 * Finds the credential records of the user behind the given authentication.
	 * @param authentication the current authentication, may be {@code null}
	 * @return the records stored for the user's entity, or an empty list when the
	 * authentication is not authenticated or no user entity exists for its name
	 */
	@NullUnmarked
	private List<CredentialRecord> findCredentialRecords(@Nullable Authentication authentication) {
		if (this.trustResolver.isAuthenticated(authentication)) {
			PublicKeyCredentialUserEntity userEntity = this.userEntities.findByUsername(authentication.getName());
			if (userEntity != null) {
				return this.userCredentials.findByUserId(userEntity.getId());
			}
		}
		return Collections.emptyList();
	}

	/**
	 * Verifies an authentication assertion with the {@link WebAuthnManager} and returns
	 * the user entity that owns the credential.
	 * <p>
	 * The credential record is looked up by the raw id of the presented credential. The
	 * webauthn4j credential record is rebuilt from the attestation object stored at
	 * registration and seeded with the persisted signature count. After a successful
	 * verification the stored record is saved with the new signature count, the backup
	 * state reported by the authenticator and the current time as last used.
	 * @param request the request options together with the assertion returned by the
	 * client
	 * @return the {@link PublicKeyCredentialUserEntity} the credential belongs to
	 * @throws IllegalArgumentException if the request is null, no credential record or
	 * user entity is found, or the stored attestation data or the rpId is missing
	 */
	@Override
	public PublicKeyCredentialUserEntity authenticate(RelyingPartyAuthenticationRequest request) {
		Assert.notNull(request, "request cannot be null");
		PublicKeyCredentialRequestOptions requestOptions = request.getRequestOptions();
		AuthenticatorAssertionResponse assertionResponse = request.getPublicKey().getResponse();
		Bytes keyId = request.getPublicKey().getRawId();
		CredentialRecord credentialRecord = this.userCredentials.findByCredentialId(keyId);
		if (credentialRecord == null) {
			throw new IllegalArgumentException("Unable to find CredentialRecord with id " + keyId);
		}
		CborConverter cborConverter = this.objectConverter.getCborConverter();
		Bytes attestationObject = credentialRecord.getAttestationObject();
		Assert.notNull(attestationObject, "attestationObject cannot be null");
		AttestationObject wa4jAttestationObject = cborConverter.readValue(attestationObject.getBytes(),
				AttestationObject.class);
		Assert.notNull(wa4jAttestationObject, "attestationObject cannot be null");
		AuthenticatorData<RegistrationExtensionAuthenticatorOutput> wa4jAuthData = wa4jAttestationObject
			.getAuthenticatorData();
		AttestedCredentialData wa4jCredData = wa4jAuthData.getAttestedCredentialData();
		Assert.notNull(wa4jCredData, "attestedCredentialData cannot be null");

		Set<Origin> origins = toOrigins();
		Challenge challenge = new DefaultChallenge(requestOptions.getChallenge().getBytes());
		String rpId = requestOptions.getRpId();
		Assert.notNull(rpId, "rpId cannot be null");
		ServerProperty serverProperty = new ServerProperty(origins, rpId, challenge);
		boolean userVerificationRequired = requestOptions
			.getUserVerification() == UserVerificationRequirement.REQUIRED;

		com.webauthn4j.data.AuthenticationRequest authenticationRequest = new com.webauthn4j.data.AuthenticationRequest(
				keyId.getBytes(), assertionResponse.getAuthenticatorData().getBytes(),
				assertionResponse.getClientDataJSON().getBytes(), assertionResponse.getSignature().getBytes());

		// CollectedClientData and ExtensionsClientOutputs is registration data, and can
		// be null at authentication time.
		CredentialRecordImpl storedCredentialState = new CredentialRecordImpl(wa4jAttestationObject, null, null,
				convertTransportsToWebauthn4j(credentialRecord.getTransports()));
		// The attestation object only holds the counter seen at registration. Seed the
		// record with the persisted signature count so that the counter check compares
		// the assertion against the value stored by the previous authentication.
		storedCredentialState.setCounter(credentialRecord.getSignatureCount());
		com.webauthn4j.credential.CredentialRecord wa4jCredentialRecord = storedCredentialState;
		List<byte[]> allowCredentials = convertAllowedCredentialsToWebauthn4j(requestOptions.getAllowCredentials());
		// An empty allow list means that no restriction was requested, which is passed
		// to webauthn4j as null.
		AuthenticationParameters authenticationParameters = new AuthenticationParameters(serverProperty,
				wa4jCredentialRecord, allowCredentials.isEmpty() ? null : allowCredentials, userVerificationRequired);

		AuthenticationData wa4jAuthenticationData = this.webAuthnManager.verify(authenticationRequest,
				authenticationParameters);

		AuthenticatorData<AuthenticationExtensionAuthenticatorOutput> wa4jValidatedAuthData = wa4jAuthenticationData
			.getAuthenticatorData();
		Assert.notNull(wa4jValidatedAuthData, "authenticatorData cannot be null");
		long updatedSignCount = wa4jValidatedAuthData.getSignCount();
		// Persist the state reported by this assertion: the signature count and the
		// current backup state of the credential.
		ImmutableCredentialRecord updatedRecord = ImmutableCredentialRecord.fromCredentialRecord(credentialRecord)
			.lastUsed(Instant.now())
			.signatureCount(updatedSignCount)
			.backupState(wa4jValidatedAuthData.isFlagBS())
			.build();
		this.userCredentials.save(updatedRecord);

		PublicKeyCredentialUserEntity userEntity = this.userEntities.findById(credentialRecord.getUserEntityUserId());
		if (userEntity == null) {
			throw new IllegalArgumentException(
					"Unable to find UserEntity with id " + credentialRecord.getUserEntityUserId() + " for " + request);
		}
		return userEntity;
	}

	/**
	 * Converts {@link AuthenticatorTransport}s to webauthn4j transports by their string
	 * value.
	 * @param transports the transports to convert
	 * @return a new set of the corresponding webauthn4j transports
	 */
	private static Set<com.webauthn4j.data.AuthenticatorTransport> convertTransportsToWebauthn4j(
			Set<AuthenticatorTransport> transports) {
		return transports.stream()
			.map((transport) -> com.webauthn4j.data.AuthenticatorTransport.create(transport.getValue()))
			.collect(Collectors.toSet());
	}

	/**
	 * Extracts the raw credential ids of the given descriptors.
	 * @param allowedCredentials the descriptors of the allowed credentials
	 * @return a new list with the id bytes of every descriptor that has an id, in
	 * iteration order
	 */
	private static List<byte[]> convertAllowedCredentialsToWebauthn4j(
			List<PublicKeyCredentialDescriptor> allowedCredentials) {
		List<byte[]> credentialIds = new ArrayList<>();
		for (PublicKeyCredentialDescriptor descriptor : allowedCredentials) {
			Bytes id = descriptor.getId();
			// Descriptors without an id are skipped.
			if (id != null) {
				credentialIds.add(id.getBytes());
			}
		}
		return credentialIds;
	}

}