ASN1Decoder.java
/*
* Copyright 2024 Red Hat, Inc. and/or its affiliates
* and other contributors as indicated by the @author tags.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/
package org.keycloak.authorization.client.util.crypto;
import java.io.ByteArrayInputStream;
import java.io.EOFException;
import java.io.IOException;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;
/**
 * Minimal decoder for the subset of ASN.1 DER needed by this package.
 * <p>
 * Reads tag, length and content octets from an in-memory byte array. Only
 * definite-length encodings are accepted by the element readers, as DER requires.
 *
 * @author rmartinc
 */
class ASN1Decoder {

    /** Source of the encoded bytes. */
    private final ByteArrayInputStream is;
    /** Total number of encoded bytes; used as an upper bound when validating long-form lengths. */
    private final int limit;
    /** Number of bytes consumed since the last {@link #mark()}. */
    private int count;

    ASN1Decoder(byte[] bytes) {
        is = new ByteArrayInputStream(bytes);
        count = 0;
        limit = bytes.length;
    }

    /**
     * Creates a decoder positioned at the first byte of {@code bytes}.
     *
     * @param bytes the encoded data
     * @return a new decoder
     */
    public static ASN1Decoder create(byte[] bytes) {
        return new ASN1Decoder(bytes);
    }

    /**
     * Reads a SEQUENCE and returns each contained element as its complete encoding
     * (tag, length and content octets), in order.
     *
     * @return the raw encodings of the sequence elements
     * @throws IOException if the next element is not a SEQUENCE, uses an indefinite
     *         length, is truncated, or its elements do not exactly fill the declared length
     */
    public List<byte[]> readSequence() throws IOException {
        int tag = readTag();
        int tagNo = readTagNumber(tag);
        if (tagNo != ASN1Encoder.SEQUENCE) {
            throw new IOException("Invalid Sequence tag " + tagNo);
        }
        int remaining = readDefiniteLength();
        List<byte[]> result = new ArrayList<>();
        while (remaining > 0) {
            byte[] element = readNext();
            result.add(element);
            remaining -= element.length;
        }
        if (remaining != 0) {
            // the last element extended past the end of the declared sequence content
            throw new IOException("corrupted stream - sequence elements exceed the declared length");
        }
        return result;
    }

    /**
     * Reads an INTEGER element and returns its two's-complement value.
     *
     * @return the decoded integer
     * @throws IOException if the next element is not an INTEGER, uses an indefinite
     *         or zero length, or is truncated
     */
    public BigInteger readInteger() throws IOException {
        int tag = readTag();
        int tagNo = readTagNumber(tag);
        if (tagNo != ASN1Encoder.INTEGER) {
            throw new IOException("Invalid Integer tag " + tagNo);
        }
        int length = readDefiniteLength();
        if (length == 0) {
            // X.690 8.3.1: the contents octets of an INTEGER are one or more octets;
            // BigInteger would otherwise throw an unchecked NumberFormatException
            throw new IOException("Invalid Integer length 0");
        }
        byte[] bytes = read(length);
        return new BigInteger(bytes);
    }

    /**
     * Reads the next element and returns its complete encoding, header included.
     *
     * @return tag, length and content octets of the next element
     * @throws IOException if the element uses an indefinite length or is truncated
     */
    byte[] readNext() throws IOException {
        mark();
        int tag = readTag();
        readTagNumber(tag);
        int contentLength = readDefiniteLength();
        // rewind to the start of the element so the header bytes are part of the result
        int headerLength = reset();
        return read(headerLength + contentLength);
    }

    /**
     * Reads the identifier (tag) octet.
     *
     * @return the tag octet, 0-255
     * @throws EOFException if the stream is exhausted
     */
    int readTag() throws IOException {
        int tag = read();
        if (tag < 0) {
            throw new EOFException("EOF found inside tag value.");
        }
        return tag;
    }

    /**
     * Extracts the tag number from an identifier octet, reading the subsequent
     * high-tag-number octets when the low five bits are all set.
     *
     * @param tag the identifier octet returned by {@link #readTag()}
     * @return the tag number
     * @throws IOException if the high-tag-number form is malformed, exceeds 31 bits, or is truncated
     */
    int readTagNumber(int tag) throws IOException {
        int tagNo = tag & 0x1f;
        //
        // with tagged object tag number is bottom 5 bits, or stored at the start of the content
        //
        if (tagNo == 0x1f) {
            tagNo = 0;
            int b = read();
            // X.690-0207 8.1.2.4.2
            // "c) bits 7 to 1 of the first subsequent octet shall not all be zero."
            if ((b & 0x7f) == 0) // Note: -1 will pass
            {
                throw new IOException("corrupted stream - invalid high tag number found");
            }
            while ((b >= 0) && ((b & 0x80) != 0)) {
                // another 7-bit shift would overflow a 31-bit tag number
                if ((tagNo >>> 24) != 0) {
                    throw new IOException("Tag number more than 31 bits");
                }
                tagNo |= (b & 0x7f);
                tagNo <<= 7;
                b = read();
            }
            if (b < 0) {
                throw new EOFException("EOF found inside tag value.");
            }
            tagNo |= (b & 0x7f);
        }
        return tagNo;
    }

    /**
     * Reads a length field in short or long form.
     *
     * @return the decoded length, or -1 for the indefinite form (0x80)
     * @throws IOException if the length is truncated, uses more than 4 octets,
     *         is negative, or is not smaller than the total input size
     */
    int readLength() throws IOException {
        int length = read();
        if (length < 0) {
            throw new EOFException("EOF found when length expected");
        }
        if (length == 0x80) {
            return -1; // indefinite-length encoding
        }
        if (length > 127) {
            int size = length & 0x7f;
            // Note: The invalid long form "0xff" (see X.690 8.1.3.5c) will be caught here
            if (size > 4) {
                throw new IOException("DER length more than 4 bytes: " + size);
            }
            length = 0;
            for (int i = 0; i < size; i++) {
                int next = read();
                if (next < 0) {
                    throw new EOFException("EOF found reading length");
                }
                length = (length << 8) + next;
            }
            if (length < 0) {
                throw new IOException("corrupted stream - negative length found");
            }
            if (length >= limit) // after all we must have read at least 1 byte
            {
                throw new IOException("corrupted stream - out of bounds length found");
            }
        }
        return length;
    }

    /**
     * Reads a length field and rejects the indefinite form, which DER forbids.
     *
     * @return a non-negative length
     * @throws IOException if the length is indefinite or otherwise invalid
     */
    private int readDefiniteLength() throws IOException {
        int length = readLength();
        if (length < 0) {
            throw new IOException("Indefinite-length encoding is not supported");
        }
        return length;
    }

    /**
     * Reads exactly {@code length} bytes.
     *
     * @param length number of bytes to read
     * @return the bytes read
     * @throws IOException if {@code length} is negative or the input ends early
     */
    byte[] read(int length) throws IOException {
        if (length < 0) {
            throw new IOException("corrupted stream - negative length " + length);
        }
        byte[] bytes = new byte[length];
        int totalBytesRead = 0;
        while (totalBytesRead < length) {
            int bytesRead = is.read(bytes, totalBytesRead, length - totalBytesRead);
            if (bytesRead == -1) {
                throw new IOException(String.format("EOF found reading %d bytes", length));
            }
            totalBytesRead += bytesRead;
        }
        count += length;
        return bytes;
    }

    /**
     * Marks the current position and restarts the consumed-byte counter.
     * The read-ahead limit passed to {@link ByteArrayInputStream#mark(int)} is ignored by that class.
     */
    void mark() {
        count = 0;
        is.mark(is.available());
    }

    /**
     * Rewinds to the last mark.
     *
     * @return the number of bytes consumed since the mark
     */
    int reset() {
        int tmp = count;
        is.reset();
        return tmp;
    }

    /**
     * Reads a single byte, counting it when one was available.
     *
     * @return the byte value 0-255, or -1 at end of input
     */
    int read() {
        int tmp = is.read();
        if (tmp >= 0) {
            count++;
        }
        return tmp;
    }
}