CombinedClaimsAttributeStatementProvider.java
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.cxf.sts.claims;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.cxf.sts.token.provider.AttributeStatementProvider;
import org.apache.cxf.sts.token.provider.TokenProviderParameters;
import org.apache.wss4j.common.WSS4JConstants;
import org.apache.wss4j.common.saml.bean.AttributeBean;
import org.apache.wss4j.common.saml.bean.AttributeStatementBean;
import org.apache.wss4j.common.saml.builder.SAML2Constants;
/**
 * An {@link AttributeStatementProvider} that turns the claims processed for a token request into a
 * single {@link AttributeStatementBean}.
 * <p>
 * This class differs from the ClaimsAttributeStatementProvider in that it combines claims that have
 * the same name: all claims that map to an equal attribute key (qualified name, simple name and
 * name format) are emitted as one {@link AttributeBean} carrying the values of every such claim, in
 * the order the claims were processed. Attributes appear in the order their first claim was seen.
 */
public class CombinedClaimsAttributeStatementProvider implements AttributeStatementProvider {

    /** Name format applied to attributes of SAML 2.0 tokens; defaults to "unspecified". */
    private String nameFormat = SAML2Constants.ATTRNAME_FORMAT_UNSPECIFIED;

    /**
     * Build the combined attribute statement for the claims attached to the given request.
     *
     * @param providerParameters the token provider parameters; claims are obtained through
     *        {@code ClaimsUtils.processClaims} and the token type from the token requirements
     * @return the attribute statement, or {@code null} if there are no processed claims
     */
    public AttributeStatementBean getStatement(TokenProviderParameters providerParameters) {
        // Handle Claims
        ProcessedClaimCollection retrievedClaims = ClaimsUtils.processClaims(providerParameters);
        if (retrievedClaims == null) {
            return null;
        }

        Iterator<ProcessedClaim> claimIterator = retrievedClaims.iterator();
        if (!claimIterator.hasNext()) {
            return null;
        }

        // LinkedHashMap keeps attributes in the order their first claim was encountered.
        Map<AttributeKey, AttributeBean> attributeMap = new LinkedHashMap<>();

        String tokenType = providerParameters.getTokenRequirements().getTokenType();
        boolean saml2 = WSS4JConstants.WSS_SAML2_TOKEN_TYPE.equals(tokenType)
            || WSS4JConstants.SAML2_NS.equals(tokenType);

        while (claimIterator.hasNext()) {
            ProcessedClaim claim = claimIterator.next();
            AttributeKey attributeKey = createAttributeKey(claim, saml2);

            List<Object> claimValues = claim.getValues();
            if (claimValues == null) {
                claimValues = Collections.emptyList();
            }

            AttributeBean attributeBean = attributeMap.get(attributeKey);
            if (attributeBean == null) {
                // Give the bean its own mutable copy: later claims with the same key append to
                // it, which must neither modify the claim's own value list nor fail when the
                // claim's list happens to be unmodifiable.
                attributeMap.put(attributeKey,
                    createAttributeBean(attributeKey, new ArrayList<>(claimValues)));
            } else {
                // Handle multiple claims with the same name
                attributeBean.getAttributeValues().addAll(claimValues);
            }
        }

        AttributeStatementBean attrBean = new AttributeStatementBean();
        attrBean.setSamlAttributes(new ArrayList<>(attributeMap.values()));
        return attrBean;
    }

    /**
     * Create an attribute bean for the given key.
     *
     * @param attributeKey supplies the simple name, qualified name and name format
     * @param claimValues the initial attribute values; the bean keeps this list by reference
     * @return the new attribute bean
     */
    private AttributeBean createAttributeBean(AttributeKey attributeKey, List<Object> claimValues) {
        AttributeBean attributeBean =
            new AttributeBean(attributeKey.getSimpleName(), attributeKey.getQualifiedName(), claimValues);
        attributeBean.setNameFormat(attributeKey.getNameFormat());
        return attributeBean;
    }

    /**
     * Derive the attribute key a claim is grouped under.
     * <p>
     * For SAML 2.0 the whole claim type is the qualified name and the configured name format is
     * used. Otherwise the claim type is split at its last '/' (ignoring one trailing '/') into a
     * namespace, stored as the qualified name, and a simple name.
     *
     * @param claim the processed claim
     * @param saml2 whether the requested token is a SAML 2.0 token
     * @return the grouping key for the claim
     */
    private AttributeKey createAttributeKey(ProcessedClaim claim, boolean saml2) {
        String claimType = claim.getClaimType();
        if (saml2) {
            return new AttributeKey(claimType, nameFormat, null);
        }

        String uri = claimType;
        if (uri.endsWith("/")) {
            uri = uri.substring(0, uri.length() - 1);
        }
        int lastSlash = uri.lastIndexOf('/');
        // A claim type containing no '/' (e.g. a URN) has no namespace part; use the whole value
        // as the name instead of failing on substring(0, -1).
        String namespace = lastSlash < 0 ? "" : uri.substring(0, lastSlash);
        String name = uri.substring(lastSlash + 1);
        return new AttributeKey(namespace, null, name);
    }

    /**
     * @return the name format applied to attributes of SAML 2.0 tokens
     */
    public String getNameFormat() {
        return nameFormat;
    }

    /**
     * @param nameFormat the name format to apply to attributes of SAML 2.0 tokens
     */
    public void setNameFormat(String nameFormat) {
        this.nameFormat = nameFormat;
    }

    /**
     * Value-based identity of an attribute: claims whose keys are equal are combined into the
     * same {@link AttributeBean}. Any component may be {@code null}.
     */
    private static final class AttributeKey {
        private final String qualifiedName;
        private final String simpleName;
        private final String nameFormat;

        // Used for both token types: SAML 2.0 passes (claimType, nameFormat, null) and
        // SAML 1.1 passes (namespace, null, simpleName).
        AttributeKey(String qualifiedName, String nameFormat, String simpleName) {
            this.qualifiedName = qualifiedName;
            this.nameFormat = nameFormat;
            this.simpleName = simpleName;
        }

        public String getQualifiedName() {
            return qualifiedName;
        }

        public String getSimpleName() {
            return simpleName;
        }

        public String getNameFormat() {
            return nameFormat;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (!(o instanceof AttributeKey)) {
                return false;
            }
            AttributeKey that = (AttributeKey) o;
            return Objects.equals(qualifiedName, that.qualifiedName)
                && Objects.equals(simpleName, that.simpleName)
                && Objects.equals(nameFormat, that.nameFormat);
        }

        @Override
        public int hashCode() {
            return Objects.hash(qualifiedName, simpleName, nameFormat);
        }
    }
}