// CharSoupModel.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
/**
 * INT8-quantized multinomial logistic regression model for
 * language detection.
 * <p>
 * Binary format (big-endian, magic "LDM1" for every version):
 * <pre>
 * v1 layout:
 *   Offset  Field
 *   0       4B magic: 0x4C444D31
 *   4       4B version: 1
 *   8       4B numBuckets (B)
 *   12      4B numClasses (C)
 *   16+     Labels:  C entries of [2B length + UTF-8 bytes]
 *           Scales:  C x 4B float (per-class dequantization)
 *           Biases:  C x 4B float (per-class bias term)
 *           Weights: B x C bytes (bucket-major, INT8 signed)
 *
 * v2 layout (adds feature flags after numClasses):
 *   Offset  Field
 *   0       4B magic: 0x4C444D31
 *   4       4B version: 2
 *   8       4B numBuckets (B)
 *   12      4B numClasses (C)
 *   16      4B featureFlags (bitmask of FLAG_* constants)
 *   20+     Labels, Scales, Biases, Weights (same as v1)
 * </pre>
 * <p>
 * Weights are stored in bucket-major order:
 * {@code weights[bucket * numClasses + class]}. This layout
 * suits the sparse dot-product in {@link #predictLogits}:
 * each non-zero bucket reads a contiguous run of
 * {@code numClasses} bytes.
 * <p>
 * The feature extractor matching a model is obtained from
 * {@link #createExtractor()}, which selects between
 * {@link ScriptAwareFeatureExtractor} and
 * {@link ShortTextFeatureExtractor} based on the stored feature flags.
 * According to the original documentation of this class, the
 * script-aware extractor produces character bigrams (with sentinels
 * for non-CJK), whole-word unigrams, CJK character unigrams, and
 * CJK space bridging.
 */
public class CharSoupModel {

    static final int MAGIC = 0x4C444D31; // "LDM1"
    static final int VERSION_V1 = 1;
    static final int VERSION_V2 = 2;

    /** Feature flag: enable character trigrams. */
    public static final int FLAG_TRIGRAMS = 1 << 0;
    /** Feature flag: enable skip bigrams. */
    public static final int FLAG_SKIP_BIGRAMS = 1 << 1;
    /** Feature flag: enable 3-char word suffixes. */
    public static final int FLAG_SUFFIXES = 1 << 2;
    /** Feature flag: enable 4-char word suffixes. */
    public static final int FLAG_SUFFIX4 = 1 << 3;
    /** Feature flag: enable 3-char word prefixes. */
    public static final int FLAG_PREFIX = 1 << 4;
    /** Feature flag: enable whole-word unigrams. */
    public static final int FLAG_WORD_UNIGRAMS = 1 << 5;
    /** Feature flag: enable non-CJK character unigrams. */
    public static final int FLAG_CHAR_UNIGRAMS = 1 << 6;
    /** Feature flag: enable character 4-grams. */
    public static final int FLAG_4GRAMS = 1 << 7;
    /** Feature flag: enable character 5-grams. */
    public static final int FLAG_5GRAMS = 1 << 8;

    /** Default flags for v1 models (word unigrams only). */
    public static final int V1_DEFAULT_FLAGS = FLAG_WORD_UNIGRAMS;

    /** Largest label length (in UTF-8 bytes) representable by the 2-byte length prefix. */
    private static final int MAX_LABEL_BYTES = 0xFFFF;

    /** Natural logarithm of 2, used to convert nats to bits in {@link #entropy}. */
    private static final double LN_2 = Math.log(2.0);

    private final int numBuckets;
    private final int numClasses;
    private final String[] labels;
    private final float[] scales;
    private final float[] biases;

    /**
     * Flat INT8 weight array in bucket-major order:
     * {@code [bucket * numClasses + class]}.
     */
    private final byte[] flatWeights;

    /**
     * Bitmask of feature flags that were active during training.
     * See {@code FLAG_*} constants. Used by {@link #createExtractor()} to
     * reconstruct the exact same feature extractor at inference time.
     */
    private final int featureFlags;

    /**
     * Construct from class-major {@code byte[][]} weights with the default feature
     * configuration ({@link #V1_DEFAULT_FLAGS}, i.e. backward compatible with v1).
     *
     * @param numBuckets number of feature buckets (B)
     * @param numClasses number of classes (C)
     * @param labels     class labels, at least {@code numClasses} entries
     * @param scales     per-class dequantization scales, at least {@code numClasses} entries
     * @param biases     per-class bias terms, at least {@code numClasses} entries
     * @param weights    class-major weights, {@code weights[class][bucket]}
     * @throws IllegalArgumentException if the arguments are inconsistent
     */
    public CharSoupModel(int numBuckets, int numClasses,
                         String[] labels, float[] scales,
                         float[] biases, byte[][] weights) {
        this(numBuckets, numClasses, labels, scales, biases, weights, V1_DEFAULT_FLAGS);
    }

    /**
     * Construct from class-major {@code byte[][]} weights with explicit feature flags.
     * <p>
     * The {@code labels}, {@code scales} and {@code biases} arrays are stored by
     * reference (not copied); the weights are transposed into a new flat array.
     *
     * @param numBuckets   number of feature buckets (B)
     * @param numClasses   number of classes (C)
     * @param labels       class labels, at least {@code numClasses} entries
     * @param scales       per-class dequantization scales, at least {@code numClasses} entries
     * @param biases       per-class bias terms, at least {@code numClasses} entries
     * @param weights      class-major weights, {@code weights[class][bucket]}
     * @param featureFlags bitmask of {@code FLAG_*} constants
     * @throws IllegalArgumentException if an array is {@code null} or too short for the
     *                                  given dimensions, or if the dimensions are negative
     *                                  or their product does not fit in an {@code int}
     */
    public CharSoupModel(int numBuckets, int numClasses,
                         String[] labels, float[] scales,
                         float[] biases, byte[][] weights,
                         int featureFlags) {
        if (labels == null || scales == null || biases == null || weights == null) {
            throw new IllegalArgumentException(
                    "labels, scales, biases and weights must not be null");
        }
        // Validates the dimensions and the weight matrix before anything is stored.
        byte[] flat = transposeToBucketMajor(weights, numBuckets, numClasses);
        requirePerClassLength("labels", labels.length, numClasses);
        requirePerClassLength("scales", scales.length, numClasses);
        requirePerClassLength("biases", biases.length, numClasses);

        this.numBuckets = numBuckets;
        this.numClasses = numClasses;
        this.labels = labels;
        this.scales = scales;
        this.biases = biases;
        this.flatWeights = flat;
        this.featureFlags = featureFlags;
    }

    /**
     * Internal constructor taking weights that are already in bucket-major
     * flat layout. No validation or copying is performed; callers in this
     * class supply consistent arguments.
     */
    private CharSoupModel(int numBuckets, int numClasses,
                          String[] labels, float[] scales,
                          float[] biases, byte[] flatWeights,
                          int featureFlags) {
        this.numBuckets = numBuckets;
        this.numClasses = numClasses;
        this.labels = labels;
        this.scales = scales;
        this.biases = biases;
        this.flatWeights = flatWeights;
        this.featureFlags = featureFlags;
    }

    /**
     * Returns {@code numBuckets * numClasses} as an {@code int}, rejecting
     * negative dimensions and products that overflow.
     *
     * @throws IllegalArgumentException if a dimension is negative or the
     *                                  product exceeds {@link Integer#MAX_VALUE}
     */
    private static int checkedFlatLength(int numBuckets, int numClasses) {
        if (numBuckets < 0 || numClasses < 0) {
            throw new IllegalArgumentException(
                    "Model dimensions must be non-negative: numBuckets="
                            + numBuckets + ", numClasses=" + numClasses);
        }
        // Multiply as long so that an int overflow is detected instead of wrapping.
        long total = (long) numBuckets * numClasses;
        if (total > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Model too large: numBuckets=" + numBuckets
                            + " x numClasses=" + numClasses
                            + " exceeds the maximum array size");
        }
        return (int) total;
    }

    /** Rejects a per-class array that holds fewer than {@code numClasses} entries. */
    private static void requirePerClassLength(String name, int actualLength,
                                              int numClasses) {
        if (actualLength < numClasses) {
            throw new IllegalArgumentException(
                    name + " has " + actualLength + " entries but numClasses is "
                            + numClasses);
        }
    }

    /**
     * Transposes class-major {@code [class][bucket]} weights into a flat
     * bucket-major array indexed by {@code bucket * numClasses + class}.
     *
     * @throws IllegalArgumentException if the dimensions are invalid or a row
     *                                  is missing or shorter than {@code numBuckets}
     */
    private static byte[] transposeToBucketMajor(
            byte[][] classMajor, int numBuckets,
            int numClasses) {
        byte[] flat = new byte[checkedFlatLength(numBuckets, numClasses)];
        if (classMajor.length < numClasses) {
            throw new IllegalArgumentException(
                    "weights has " + classMajor.length
                            + " rows but numClasses is " + numClasses);
        }
        for (int c = 0; c < numClasses; c++) {
            byte[] row = classMajor[c];
            if (row == null || row.length < numBuckets) {
                throw new IllegalArgumentException(
                        "weights[" + c + "] must have at least "
                                + numBuckets + " entries");
            }
            for (int b = 0; b < numBuckets; b++) {
                flat[b * numClasses + c] = row[b];
            }
        }
        return flat;
    }

    // ================================================================
    // Loading
    // ================================================================

    /**
     * Load a model from the classpath.
     *
     * @param resourcePath resource path, resolved relative to this class
     * @return the loaded model
     * @throws IOException if the resource does not exist or cannot be parsed
     */
    public static CharSoupModel loadFromClasspath(
            String resourcePath) throws IOException {
        try (InputStream is =
                     CharSoupModel.class.getResourceAsStream(
                             resourcePath)) {
            if (is == null) {
                throw new IOException(
                        "Model resource not found: "
                                + resourcePath);
            }
            return load(is);
        }
    }

    /**
     * Load a model from an input stream.
     * Supports both the version 1 and version 2 layouts (both use the
     * "LDM1" magic). Version 1 files carry no feature flags and are
     * assigned {@link #V1_DEFAULT_FLAGS}.
     * <p>
     * The stream is not closed by this method.
     *
     * @param is stream positioned at the start of the model
     * @return the loaded model
     * @throws IOException if the magic or version is wrong, the header
     *                     dimensions are invalid, or the stream ends early
     */
    public static CharSoupModel load(InputStream is)
            throws IOException {
        DataInputStream dis = new DataInputStream(is);

        int magic = dis.readInt();
        if (magic != MAGIC) {
            throw new IOException(String.format(Locale.ROOT,
                    "Invalid magic: expected 0x%08X, got 0x%08X",
                    MAGIC, magic));
        }

        int version = dis.readInt();
        if (version != VERSION_V1 && version != VERSION_V2) {
            throw new IOException(
                    "Unsupported version: " + version
                            + " (expected " + VERSION_V1
                            + " or " + VERSION_V2 + ")");
        }

        int numBuckets = dis.readInt();
        int numClasses = dis.readInt();
        // A corrupt header must surface as an IOException, not as a
        // NegativeArraySizeException or a silently wrapped array size.
        final int weightCount;
        try {
            weightCount = checkedFlatLength(numBuckets, numClasses);
        } catch (IllegalArgumentException e) {
            throw new IOException("Corrupt model header: " + e.getMessage(), e);
        }

        int featureFlags = V1_DEFAULT_FLAGS;
        if (version == VERSION_V2) {
            featureFlags = dis.readInt();
        }

        String[] labels = readLabels(dis, numClasses);
        float[] scales = readFloats(dis, numClasses);
        float[] biases = readFloats(dis, numClasses);

        byte[] flat = new byte[weightCount];
        dis.readFully(flat);

        return new CharSoupModel(numBuckets, numClasses,
                labels, scales, biases, flat, featureFlags);
    }

    // ================================================================
    // Saving
    // ================================================================

    /**
     * Write the model in the version 2 binary layout (includes feature flags).
     * <p>
     * The stream is flushed but not closed.
     *
     * @param os destination stream
     * @throws IOException if writing fails or a label is too long for the
     *                     2-byte length prefix
     */
    public void save(OutputStream os) throws IOException {
        DataOutputStream dos = new DataOutputStream(os);
        dos.writeInt(MAGIC);
        dos.writeInt(VERSION_V2);
        dos.writeInt(numBuckets);
        dos.writeInt(numClasses);
        dos.writeInt(featureFlags);
        writeLabels(dos);
        // Exactly numClasses entries each, so the output always matches the header.
        writeFloats(dos, scales, numClasses);
        writeFloats(dos, biases, numClasses);
        dos.write(flatWeights);
        dos.flush();
    }

    // ================================================================
    // Inference
    // ================================================================

    /**
     * Compute softmax probabilities for the given feature
     * vector. Only non-zero buckets contribute to the inner loop.
     *
     * @param features int array of size {@code numBuckets}
     * @return float array of size {@code numClasses}
     *         (softmax probabilities, summing to approximately 1.0)
     * @throws IllegalArgumentException if {@code features} has fewer than
     *                                  {@code numBuckets} entries
     */
    public float[] predict(int[] features) {
        return softmax(predictLogits(features));
    }

    /**
     * Compute raw logits (pre-softmax scores) for the given
     * feature vector. Higher logits indicate stronger match.
     * Unlike {@link #predict}, this preserves the full dynamic
     * range of the model's output, which is useful when
     * comparing confidence across different input texts.
     *
     * @param features int array of size {@code numBuckets}
     * @return float array of size {@code numClasses}
     *         (raw logits, not normalized)
     * @throws IllegalArgumentException if {@code features} has fewer than
     *                                  {@code numBuckets} entries
     */
    public float[] predictLogits(int[] features) {
        if (features.length < numBuckets) {
            throw new IllegalArgumentException(
                    "features has " + features.length
                            + " entries but the model expects " + numBuckets);
        }

        // Integer dot products are accumulated in long so they are exact;
        // the per-class scale is applied once at the end.
        long[] dots = new long[numClasses];
        for (int b = 0; b < numBuckets; b++) {
            int fv = features[b];
            if (fv == 0) {
                continue; // sparse: skip buckets that contribute nothing
            }
            int off = b * numClasses;
            for (int c = 0; c < numClasses; c++) {
                dots[c] += (long) flatWeights[off + c] * fv;
            }
        }

        float[] logits = new float[numClasses];
        for (int c = 0; c < numClasses; c++) {
            logits[c] = biases[c] + scales[c] * dots[c];
        }
        return logits;
    }

    /**
     * In-place softmax with numerical stability (the maximum logit is
     * subtracted before exponentiation).
     * <p>
     * The input array is overwritten and returned. Normalization is skipped
     * when the sum of exponentials is not positive (for example, for an
     * empty array).
     *
     * @param logits raw scores; modified in place
     * @return the same array instance, now holding the softmax values
     */
    public static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            if (v > max) {
                max = v;
            }
        }
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            logits[i] = (float) Math.exp(logits[i] - max);
            sum += logits[i];
        }
        if (sum > 0f) {
            for (int i = 0; i < logits.length; i++) {
                logits[i] /= sum;
            }
        }
        return logits;
    }

    /**
     * Shannon entropy (in bits) of a probability distribution.
     * Entries that are not strictly positive are ignored.
     *
     * @param probs probability values
     * @return entropy in bits
     */
    public static float entropy(float[] probs) {
        double h = 0.0;
        for (float p : probs) {
            if (p > 0f) {
                h -= p * (Math.log(p) / LN_2);
            }
        }
        return (float) h;
    }

    // ================================================================
    // Accessors
    // ================================================================

    public int getNumBuckets() {
        return numBuckets;
    }

    public int getNumClasses() {
        return numClasses;
    }

    /**
     * Returns the label array held by this model. The array is returned
     * without copying, so modifications are visible to the model.
     */
    public String[] getLabels() {
        return labels;
    }

    public String getLabel(int classIndex) {
        return labels[classIndex];
    }

    /**
     * Returns the per-class scale array held by this model. The array is
     * returned without copying, so modifications are visible to the model.
     */
    public float[] getScales() {
        return scales;
    }

    /**
     * Returns the per-class bias array held by this model. The array is
     * returned without copying, so modifications are visible to the model.
     */
    public float[] getBiases() {
        return biases;
    }

    /**
     * Return weights in class-major {@code [class][bucket]}
     * layout. Creates a new array each call.
     */
    public byte[][] getWeights() {
        byte[][] cm = new byte[numClasses][numBuckets];
        for (int b = 0; b < numBuckets; b++) {
            int off = b * numClasses;
            for (int c = 0; c < numClasses; c++) {
                cm[c][b] = flatWeights[off + c];
            }
        }
        return cm;
    }

    /**
     * Create the production {@link FeatureExtractor} for this model by dispatching
     * on the {@link #featureFlags} embedded in the binary.
     * <p>
     * Supported flag sets:
     * <ul>
     *   <li>{@link ScriptAwareFeatureExtractor#FEATURE_FLAGS} - general model</li>
     *   <li>{@link ShortTextFeatureExtractor#FEATURE_FLAGS} - short-text model</li>
     * </ul>
     *
     * @return a new extractor sized to this model's {@code numBuckets}
     * @throws IllegalStateException if the flags do not match any known production extractor.
     *         Experimental configs should use {@code ResearchFeatureExtractor} in the test module.
     */
    public FeatureExtractor createExtractor() {
        if (featureFlags == ScriptAwareFeatureExtractor.FEATURE_FLAGS) {
            return new ScriptAwareFeatureExtractor(numBuckets);
        }
        if (featureFlags == ShortTextFeatureExtractor.FEATURE_FLAGS) {
            return new ShortTextFeatureExtractor(numBuckets);
        }
        throw new IllegalStateException(String.format(
                Locale.ROOT,
                "No production FeatureExtractor for featureFlags=0x%03x. "
                        + "Known: ScriptAware=0x%03x, ShortText=0x%03x. "
                        + "Use ResearchFeatureExtractor (test scope) for experimental configs.",
                featureFlags,
                ScriptAwareFeatureExtractor.FEATURE_FLAGS,
                ShortTextFeatureExtractor.FEATURE_FLAGS));
    }

    public int getFeatureFlags() {
        return featureFlags;
    }

    /**
     * Returns a new model with the same weights but a different feature-flags bitmask.
     * Useful for correcting flags on models saved before this field was properly set.
     * All arrays are cloned, so the copy shares no array with this model.
     *
     * @param newFlags bitmask of {@code FLAG_*} constants
     * @return copy of this model with updated feature flags
     */
    public CharSoupModel withFeatureFlags(int newFlags) {
        return new CharSoupModel(numBuckets, numClasses, labels.clone(),
                scales.clone(), biases.clone(), flatWeights.clone(), newFlags);
    }

    // ================================================================
    // Internal I/O helpers
    // ================================================================

    /** Reads {@code numClasses} labels, each a 2-byte unsigned length plus UTF-8 bytes. */
    private static String[] readLabels(DataInputStream dis,
                                       int numClasses)
            throws IOException {
        String[] labels = new String[numClasses];
        for (int c = 0; c < numClasses; c++) {
            int len = dis.readUnsignedShort();
            byte[] utf8 = new byte[len];
            dis.readFully(utf8);
            labels[c] = new String(utf8, StandardCharsets.UTF_8);
        }
        return labels;
    }

    /** Reads {@code count} big-endian 4-byte floats. */
    private static float[] readFloats(DataInputStream dis,
                                      int count)
            throws IOException {
        float[] arr = new float[count];
        for (int i = 0; i < count; i++) {
            arr[i] = dis.readFloat();
        }
        return arr;
    }

    /**
     * Writes the first {@code numClasses} labels, each as a 2-byte length
     * followed by its UTF-8 bytes.
     *
     * @throws IOException if a label's UTF-8 encoding exceeds 65535 bytes,
     *                     which the 2-byte length prefix cannot represent
     */
    private void writeLabels(DataOutputStream dos)
            throws IOException {
        for (int c = 0; c < numClasses; c++) {
            byte[] utf8 =
                    labels[c].getBytes(StandardCharsets.UTF_8);
            // writeShort keeps only the low 16 bits; a longer label would
            // be written with a truncated length and corrupt the file.
            if (utf8.length > MAX_LABEL_BYTES) {
                throw new IOException(
                        "Label " + c + " is " + utf8.length
                                + " UTF-8 bytes; maximum is " + MAX_LABEL_BYTES);
            }
            dos.writeShort(utf8.length);
            dos.write(utf8);
        }
    }

    /** Writes the first {@code count} entries of {@code arr} as big-endian 4-byte floats. */
    private static void writeFloats(DataOutputStream dos,
                                    float[] arr, int count)
            throws IOException {
        for (int i = 0; i < count; i++) {
            dos.writeFloat(arr[i]);
        }
    }
}