SdJwtVP.java
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.keycloak.sdjwt.vp;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import org.keycloak.common.VerificationException;
import org.keycloak.common.util.Base64Url;
import org.keycloak.crypto.SignatureSignerContext;
import org.keycloak.crypto.SignatureVerifierContext;
import org.keycloak.sdjwt.IssuerSignedJWT;
import org.keycloak.sdjwt.IssuerSignedJwtVerificationOpts;
import org.keycloak.sdjwt.SdJwt;
import org.keycloak.sdjwt.SdJwtUtils;
import org.keycloak.sdjwt.SdJwtVerificationContext;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
/**
* @author <a href="mailto:francis.pouatcha@adorsys.com">Francis Pouatcha</a>
*/
public class SdJwtVP {
private final String sdJwtVpString;
private final IssuerSignedJWT issuerSignedJWT;
private final Map<String, ArrayNode> claims;
private final Map<String, String> disclosures;
private final Map<String, String> recursiveDigests;
private final List<String> ghostDigests;
private final String hashAlgorithm;
private final Optional<KeyBindingJWT> keyBindingJWT;
private final SdJwtVerificationContext sdJwtVerificationContext;
public Map<String, ArrayNode> getClaims() {
return claims;
}
public IssuerSignedJWT getIssuerSignedJWT() {
return issuerSignedJWT;
}
public Map<String, String> getDisclosures() {
return disclosures;
}
public Collection<String> getDisclosuresString() {
return disclosures.values();
}
public Map<String, String> getRecursiveDigests() {
return recursiveDigests;
}
public Collection<String> getGhostDigests() {
return ghostDigests;
}
public String getHashAlgorithm() {
return hashAlgorithm;
}
public Optional<KeyBindingJWT> getKeyBindingJWT() {
return keyBindingJWT;
}
private SdJwtVP(String sdJwtVpString, String hashAlgorithm, IssuerSignedJWT issuerSignedJWT,
Map<String, ArrayNode> claims, Map<String, String> disclosures, Map<String, String> recursiveDigests,
List<String> ghostDigests, Optional<KeyBindingJWT> keyBindingJWT) {
this.sdJwtVpString = sdJwtVpString;
this.hashAlgorithm = hashAlgorithm;
this.issuerSignedJWT = issuerSignedJWT;
this.claims = Collections.unmodifiableMap(claims);
this.disclosures = Collections.unmodifiableMap(disclosures);
this.recursiveDigests = Collections.unmodifiableMap(recursiveDigests);
this.ghostDigests = Collections.unmodifiableList(ghostDigests);
this.keyBindingJWT = keyBindingJWT;
// Instantiate context for verification
this.sdJwtVerificationContext = new SdJwtVerificationContext(
this.sdJwtVpString,
this.issuerSignedJWT,
this.disclosures,
this.keyBindingJWT.orElse(null)
);
}
public static SdJwtVP of(String sdJwtString) {
int disclosureStart = sdJwtString.indexOf(SdJwt.DELIMITER);
int disclosureEnd = sdJwtString.lastIndexOf(SdJwt.DELIMITER);
if (disclosureStart == -1) {
throw new IllegalArgumentException("SD-JWT is malformed, expected to contain a '" + SdJwt.DELIMITER + "'");
}
String issuerSignedJWTString = sdJwtString.substring(0, disclosureStart);
String disclosuresString = "";
if (disclosureEnd > disclosureStart) {
disclosuresString = sdJwtString.substring(disclosureStart + 1, disclosureEnd);
}
IssuerSignedJWT issuerSignedJWT = IssuerSignedJWT.fromJws(issuerSignedJWTString);
ObjectNode issuerPayload = (ObjectNode) issuerSignedJWT.getPayload();
String hashAlgorithm = issuerPayload.get(IssuerSignedJWT.CLAIM_NAME_SD_HASH_ALGORITHM).asText();
Map<String, ArrayNode> claims = new HashMap<>();
Map<String, String> disclosures = new HashMap<>();
String[] split = disclosuresString.split(SdJwt.DELIMITER);
for (String disclosure : split) {
String disclosureDigest = SdJwtUtils.hashAndBase64EncodeNoPad(disclosure.getBytes(), hashAlgorithm);
if (disclosures.containsKey(disclosureDigest)) {
throw new IllegalArgumentException("Duplicate disclosure digest");
}
disclosures.put(disclosureDigest, disclosure);
ArrayNode disclosureData;
try {
disclosureData = (ArrayNode) SdJwtUtils.mapper.readTree(Base64Url.decode(disclosure));
claims.put(disclosureDigest, disclosureData);
} catch (IOException e) {
throw new IllegalArgumentException("Invalid disclosure data");
}
}
Set<String> allDigests = claims.keySet();
Map<String, String> recursiveDigests = new HashMap<>();
List<String> ghostDigests = new ArrayList<>();
allDigests.stream()
.forEach(disclosureDigest -> {
JsonNode node = findNode(issuerPayload, disclosureDigest);
node = processDisclosureDigest(node, disclosureDigest, claims, recursiveDigests, ghostDigests);
});
Optional<KeyBindingJWT> keyBindingJWT = Optional.empty();
if (sdJwtString.length() > disclosureEnd + 1) {
String keyBindingJWTString = sdJwtString.substring(disclosureEnd + 1);
keyBindingJWT = Optional.of(KeyBindingJWT.of(keyBindingJWTString));
}
// Drop the key binding String if any. As it is held by the keyBindingJwtObject
String sdJWtVPString = sdJwtString.substring(0, disclosureEnd + 1);
return new SdJwtVP(sdJWtVPString, hashAlgorithm, issuerSignedJWT, claims, disclosures, recursiveDigests,
ghostDigests, keyBindingJWT);
}
private static JsonNode processDisclosureDigest(JsonNode node, String disclosureDigest,
Map<String, ArrayNode> claims,
Map<String, String> recursiveDigests,
List<String> ghostDigests) {
if (node == null) { // digest is nested in another disclosure
Set<Entry<String, ArrayNode>> entrySet = claims.entrySet();
for (Entry<String, ArrayNode> entry : entrySet) {
if (entry.getKey().equals(disclosureDigest)) {
continue;
}
node = findNode(entry.getValue(), disclosureDigest);
if (node != null) {
recursiveDigests.put(disclosureDigest, entry.getKey());
break;
}
}
}
if (node == null) { // No digest found for disclosure.
ghostDigests.add(disclosureDigest);
}
return node;
}
public JsonNode getCnfClaim() {
return issuerSignedJWT.getCnfClaim().orElse(null);
}
public String present(List<String> disclosureDigests, JsonNode keyBindingClaims,
SignatureSignerContext holdSignatureSignerContext, String jwsType) {
StringBuilder sb = new StringBuilder();
if (disclosureDigests == null || disclosureDigests.isEmpty()) {
// disclose everything
sb.append(sdJwtVpString);
} else {
sb.append(issuerSignedJWT.toJws());
sb.append(SdJwt.DELIMITER);
for (String disclosureDigest : disclosureDigests) {
sb.append(disclosures.get(disclosureDigest));
sb.append(SdJwt.DELIMITER);
}
}
String unboundPresentation = sb.toString();
if (keyBindingClaims == null || holdSignatureSignerContext == null) {
return unboundPresentation;
}
String sd_hash = SdJwtUtils.hashAndBase64EncodeNoPad(unboundPresentation.getBytes(), getHashAlgorithm());
keyBindingClaims = ((ObjectNode) keyBindingClaims).put("sd_hash", sd_hash);
KeyBindingJWT keyBindingJWT = KeyBindingJWT.from(keyBindingClaims, holdSignatureSignerContext, jwsType);
sb.append(keyBindingJWT.toJws());
return sb.toString();
}
/**
* Verifies SD-JWT presentation.
*
* @param issuerVerifyingKeys Verifying keys for validating the Issuer-signed JWT. The caller
* is responsible for establishing trust in that the keys belong
* to the intended issuer.
* @param issuerSignedJwtVerificationOpts Options to parameterize the Issuer-Signed JWT verification.
* @param keyBindingJwtVerificationOpts Options to parameterize the Key Binding JWT verification.
* Must, among others, specify the Verifier's policy whether
* to check Key Binding.
* @throws VerificationException if verification failed
*/
public void verify(
List<SignatureVerifierContext> issuerVerifyingKeys,
IssuerSignedJwtVerificationOpts issuerSignedJwtVerificationOpts,
KeyBindingJwtVerificationOpts keyBindingJwtVerificationOpts
) throws VerificationException {
sdJwtVerificationContext.verifyPresentation(
issuerVerifyingKeys,
issuerSignedJwtVerificationOpts,
keyBindingJwtVerificationOpts,
null
);
}
/**
* Retrieve verification context for advanced scenarios.
*/
public SdJwtVerificationContext getSdJwtVerificationContext() {
return sdJwtVerificationContext;
}
// Recursively searches the node with the given value.
// Returns the node if found, null otherwise.
private static JsonNode findNode(JsonNode node, String value) {
if (node == null) {
return null;
}
if (node.isValueNode()) {
if (node.asText().equals(value)) {
return node;
} else {
return null;
}
}
if (node.isArray() || node.isObject()) {
for (JsonNode child : node) {
JsonNode found = findNode(child, value);
if (found != null) {
return found;
}
}
}
return null;
}
@Override
public String toString() {
return sdJwtVpString;
}
public String verbose() {
StringBuilder sb = new StringBuilder();
sb.append("Issuer Signed JWT: ");
sb.append(issuerSignedJWT.getPayload());
sb.append("\n");
disclosures.forEach((digest, disclosure) -> {
sb.append("\n");
sb.append("Digest: ");
sb.append(digest);
sb.append("\n");
sb.append("Disclosure: ");
sb.append(disclosure);
sb.append("\n");
sb.append("Content: ");
sb.append(claims.get(digest));
sb.append("\n");
});
sb.append("\n");
sb.append("Recursive Digests: ");
sb.append(recursiveDigests);
sb.append("\n");
sb.append("\n");
sb.append("Ghost Digests: ");
sb.append(ghostDigests);
sb.append("\n");
sb.append("\n");
if (keyBindingJWT.isPresent()) {
sb.append("Key Binding JWT: ");
sb.append("\n");
sb.append(keyBindingJWT.get().getPayload().toString());
sb.append("\n");
}
return sb.toString();
}
}