SpkiPinningClientTlsStrategyTest.java
/*
* ====================================================================
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
* ====================================================================
*
* This software consists of voluntary contributions made by many
* individuals on behalf of the Apache Software Foundation. For more
* information on the Apache Software Foundation, please see
* <http://www.apache.org/>.
*
*/
package org.apache.hc.client5.http.ssl;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.math.BigInteger;
import java.net.IDN;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.security.PublicKey;
import java.security.cert.CertificateEncodingException;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Date;
import java.util.Locale;
import java.util.Set;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSession;
import javax.security.auth.x500.X500Principal;

import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link SpkiPinningClientTlsStrategy}: builder validation of pin
 * specifications and host patterns, exact / wildcard / IDN host matching, and
 * enforcement of {@code sha256/} SPKI pins against the peer certificate of a session.
 */
class SpkiPinningClientTlsStrategyTest {

    /**
     * Unicode host used by the IDN test. It is written with a Unicode escape (U+00FC,
     * "u with diaeresis") so that the source-file encoding cannot corrupt the literal.
     */
    private static final String IDN_HOST = "b\u00FCcher.example";

    /** A well-formed pin: the "sha256/" prefix plus Base64 of a 32-byte all-zero digest. */
    private static final String ZERO_PIN = "sha256/" + Base64.getEncoder().encodeToString(new byte[32]);

    /**
     * Builds a pin string for the given encoded public key.
     *
     * @param spki the bytes the stub certificate reports as its encoded public key
     * @return {@code "sha256/"} followed by the Base64 of the SHA-256 digest of {@code spki}
     * @throws Exception if SHA-256 is unavailable
     */
    private static String sha256Pin(final byte[] spki) throws Exception {
        final byte[] digest = MessageDigest.getInstance("SHA-256").digest(spki);
        return "sha256/" + Base64.getEncoder().encodeToString(digest);
    }

    @Test
    void exactHostMatch() throws Exception {
        final byte[] spki = new byte[]{1, 2, 3, 4, 5};
        final String pin = sha256Pin(spki);
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com", pin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(spki));
        assertDoesNotThrow(() -> strategy.enforcePins("api.example.com", session));
    }

    @Test
    void wildcardMatch() throws Exception {
        final byte[] spki = new byte[]{9, 9, 9, 9};
        final String pin = sha256Pin(spki);
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("*.example.com", pin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(spki));
        assertDoesNotThrow(() -> strategy.enforcePins("svc.example.com", session));
    }

    @Test
    void pinningFailure() throws Exception {
        final byte[] spki = new byte[]{7, 7, 7};
        final String wrongPin = sha256Pin(new byte[]{8, 8, 8});
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com", wrongPin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(spki));
        assertThrows(SSLException.class, () -> strategy.enforcePins("api.example.com", session));
    }

    @Test
    void wildcardDoesNotMatchMultiLabel() throws Exception {
        final byte[] spki = new byte[]{4, 2, 4, 2};
        final String pin = sha256Pin(spki);
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("*.example.com", pin)
                .build();
        // a.b.example.com should NOT match the single-label wildcard, so no pinning is
        // enforced and the (non-matching) certificate key must not cause a failure.
        final SSLSession session = new FakeSession(new X509WithKey(new byte[]{1, 2, 3}));
        assertDoesNotThrow(() -> strategy.enforcePins("a.b.example.com", session));
    }

    @Test
    void backupPinSucceedsWhenFirstPinDoesNotMatch() throws Exception {
        final byte[] spkiGood = new byte[]{10, 11, 12, 13};
        final byte[] spkiBad = new byte[]{99, 99, 99, 99};
        final String wrongPin = sha256Pin(spkiBad);
        final String goodPin = sha256Pin(spkiGood);
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                // wrong pin first, correct pin second
                .add("api.example.com", wrongPin, goodPin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(spkiGood));
        assertDoesNotThrow(() -> strategy.enforcePins("api.example.com", session));
    }

    @Test
    void idnExactHostMatch() throws Exception {
        // IDN_HOST ("b" + U+00FC + "cher.example") has the ASCII form xn--bcher-kva.example.
        final byte[] spki = new byte[]{42, 42, 42, 42};
        final String pin = sha256Pin(spki);
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add(IDN_HOST, pin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(spki));
        // enforcePins expects IDNA ASCII (like verifySession would pass)
        final String ascii = IDN.toASCII(IDN_HOST).toLowerCase(Locale.ROOT);
        // Guards against the Unicode literal being mangled (e.g. by a wrong source encoding).
        assertEquals("xn--bcher-kva.example", ascii);
        assertDoesNotThrow(() -> strategy.enforcePins(ascii, session));
    }

    @Test
    void invalidBase64PinRejected() {
        assertThrows(IllegalArgumentException.class, () -> SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com", "sha256/###not_base64###"));
    }

    @Test
    void wrongLengthPinRejected() {
        // Base64 of 1 byte -> decoded length != 32
        final String shortPin = "sha256/" + Base64.getEncoder().encodeToString(new byte[]{1});
        assertThrows(IllegalArgumentException.class, () -> SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com", shortPin));
    }

    @Test
    void emptyPinsRejected() {
        assertThrows(IllegalArgumentException.class, () -> SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com"));
    }

    @Test
    void invalidWildcardPatternRejected() {
        // "*.": not a valid single-label wildcard (the pin itself is well-formed)
        assertThrows(IllegalArgumentException.class, () -> SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("*.", ZERO_PIN));
    }

    @Test
    void wildcardConfiguredButWrongPinFails() throws Exception {
        final byte[] spki = new byte[]{5, 5, 5, 5};
        final String wrongPin = sha256Pin(new byte[]{6, 6, 6, 6});
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("*.example.com", wrongPin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(spki));
        assertThrows(SSLException.class, () -> strategy.enforcePins("svc.example.com", session));
    }

    @Test
    void noConfiguredPinsForHostShortCircuits() throws Exception {
        // Pins are configured for another domain, not for foo.bar
        final byte[] spki = new byte[]{1, 1, 1, 1};
        final String pin = sha256Pin(spki);
        final SpkiPinningClientTlsStrategy strategy = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com", pin)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(new byte[]{9, 9, 9}));
        // No rules match -> pinning not enforced -> no throw
        assertDoesNotThrow(() -> strategy.enforcePins("foo.bar", session));
    }

    @Test
    void verifySessionInvalidIdnThrowsSslException() throws NoSuchAlgorithmException {
        final SpkiPinningClientTlsStrategy s = SpkiPinningClientTlsStrategy
                .newBuilder(SSLContext.getDefault())
                .add("api.example.com", ZERO_PIN)
                .build();
        final SSLSession session = new FakeSession(new X509WithKey(new byte[]{1}));
        // The host starts with an unpaired low surrogate, which is not a valid IDN;
        // the failure must surface as an SSLException.
        assertThrows(SSLException.class, () -> s.verifySession("\uDC00bad", session));
    }

    /**
     * Minimal {@link X509Certificate} stub whose only meaningful property is its public
     * key, whose encoded form is the caller-supplied byte array. Validity and signature
     * checks are no-ops; {@code getEncoded()} and {@code getTBSCertificate()} are unsupported.
     */
    private static final class X509WithKey extends X509Certificate {

        private final PublicKey key;

        X509WithKey(final byte[] spki) {
            this.key = new PublicKey() {
                @Override
                public String getAlgorithm() {
                    return "RSA";
                }

                @Override
                public String getFormat() {
                    return "X.509";
                }

                @Override
                public byte[] getEncoded() {
                    // Defensive copy, mirroring real keys, so callers cannot mutate the fixture.
                    return spki.clone();
                }
            };
        }

        @Override
        public PublicKey getPublicKey() {
            return key;
        }

        @Override
        public void checkValidity() {
        }

        @Override
        public void checkValidity(final Date date) {
        }

        @Override
        public int getVersion() {
            return 3;
        }

        @Override
        public BigInteger getSerialNumber() {
            return BigInteger.ONE;
        }

        @Override
        public Principal getIssuerDN() {
            return new X500Principal("CN=issuer");
        }

        @Override
        public Principal getSubjectDN() {
            return new X500Principal("CN=subject");
        }

        @Override
        public Date getNotBefore() {
            return new Date(0L);
        }

        @Override
        public Date getNotAfter() {
            // 4102444800000 ms since the epoch = 2100-01-01T00:00:00Z
            return new Date(4102444800000L);
        }

        @Override
        public byte[] getTBSCertificate() throws CertificateEncodingException {
            throw new UnsupportedOperationException();
        }

        @Override
        public byte[] getSignature() {
            return new byte[0];
        }

        @Override
        public String getSigAlgName() {
            return "NONE";
        }

        @Override
        public String getSigAlgOID() {
            return "0.0";
        }

        @Override
        public byte[] getSigAlgParams() {
            return null;
        }

        @Override
        public boolean[] getIssuerUniqueID() {
            return null;
        }

        @Override
        public boolean[] getSubjectUniqueID() {
            return null;
        }

        @Override
        public boolean[] getKeyUsage() {
            return null;
        }

        @Override
        public int getBasicConstraints() {
            return -1;
        }

        @Override
        public byte[] getEncoded() throws CertificateEncodingException {
            throw new UnsupportedOperationException();
        }

        @Override
        public void verify(final PublicKey key) {
        }

        @Override
        public void verify(final PublicKey key, final String sigProvider) {
        }

        @Override
        public String toString() {
            return "X509WithKey";
        }

        @Override
        public X500Principal getIssuerX500Principal() {
            return new X500Principal("CN=issuer");
        }

        @Override
        public X500Principal getSubjectX500Principal() {
            return new X500Principal("CN=subject");
        }

        @Override
        public Set<String> getCriticalExtensionOIDs() {
            return null;
        }

        @Override
        public Set<String> getNonCriticalExtensionOIDs() {
            return null;
        }

        @Override
        public byte[] getExtensionValue(final String oid) {
            return null;
        }

        @Override
        public boolean hasUnsupportedCriticalExtension() {
            return false;
        }
    }

    /**
     * Minimal {@link SSLSession} stub whose peer certificate chain consists of exactly one
     * certificate. All other accessors return fixed placeholder values.
     */
    private static final class FakeSession implements SSLSession {

        private final X509Certificate[] chain;

        FakeSession(final X509Certificate cert) {
            this.chain = new X509Certificate[]{cert};
        }

        @Override
        public java.security.cert.Certificate[] getPeerCertificates() {
            return chain;
        }

        @Override
        public javax.security.cert.X509Certificate[] getPeerCertificateChain() {
            return new javax.security.cert.X509Certificate[0];
        }

        @Override
        public String getProtocol() {
            return "TLSv1.3";
        }

        @Override
        public String getCipherSuite() {
            return "TLS_AES_128_GCM_SHA256";
        }

        @Override
        public Principal getPeerPrincipal() {
            return null;
        }

        @Override
        public Principal getLocalPrincipal() {
            return null;
        }

        @Override
        public java.security.cert.Certificate[] getLocalCertificates() {
            return new java.security.cert.Certificate[0];
        }

        @Override
        public String getPeerHost() {
            return "api.example.com";
        }

        @Override
        public int getPeerPort() {
            return 443;
        }

        @Override
        public int getPacketBufferSize() {
            return 0;
        }

        @Override
        public int getApplicationBufferSize() {
            return 0;
        }

        @Override
        public long getCreationTime() {
            return 0;
        }

        @Override
        public long getLastAccessedTime() {
            return 0;
        }

        @Override
        public void invalidate() {
        }

        @Override
        public boolean isValid() {
            return true;
        }

        @Override
        public Object getValue(final String s) {
            return null;
        }

        @Override
        public String[] getValueNames() {
            return new String[0];
        }

        @Override
        public void putValue(final String s, final Object o) {
        }

        @Override
        public void removeValue(final String s) {
        }

        @Override
        public javax.net.ssl.SSLSessionContext getSessionContext() {
            return null;
        }

        @Override
        public byte[] getId() {
            return new byte[0];
        }
    }
}