// LinearModel.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.ml;

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.Locale;
import java.util.zip.GZIPInputStream;

/**
 * INT8-quantized multinomial logistic regression model for classification.
 * <p>
 * Binary format (big-endian, magic "LDM1"):
 * <pre>
 *   Offset  Field
 *   0       4B magic: 0x4C444D31
 *   4       4B version: 1
 *   8       4B numBuckets (B)
 *   12      4B numClasses (C)
 *   16+     Labels: C entries of [2B length + UTF-8 bytes]
 *           Scales: C x 4B float (per-class dequantization)
 *           Biases: C x 4B float (per-class bias term)
 *           Weights: B x C bytes (bucket-major, INT8 signed)
 * </pre>
 * <p>
 * Weights are stored in bucket-major order:
 * {@code weights[bucket * numClasses + class]}. This layout
 * is optimal for the sparse dot-product in {@link #predict}
 * - each non-zero bucket reads a contiguous run of
 * {@code numClasses} bytes, ideal for SIMD and cache
 * prefetching.
 */
public class LinearModel {

    public static final int MAGIC = 0x4C444D31; // "LDM1"
    public static final int VERSION = 1;

    /** Largest label size (in UTF-8 bytes) representable by the 2-byte length prefix. */
    private static final int MAX_LABEL_BYTES = 0xFFFF;

    /** Natural log of 2, used to convert nats to bits in {@link #entropy}. */
    private static final double LN_2 = Math.log(2.0);

    private final int numBuckets;
    private final int numClasses;
    private final String[] labels;
    private final float[] scales;
    private final float[] biases;

    /**
     * Flat INT8 weight array in bucket-major order:
     * {@code [bucket * numClasses + class]}.
     */
    private final byte[] flatWeights;

    /**
     * Construct from class-major {@code byte[][]} weights.
     * Transposes to bucket-major flat layout internally. The label, scale and
     * bias arrays are stored as given (no defensive copy is made).
     *
     * @param numBuckets number of feature buckets (B)
     * @param numClasses number of output classes (C)
     * @param labels     class labels
     * @param scales     per-class dequantization scales
     * @param biases     per-class bias terms
     * @param weights    weights in {@code [class][bucket]} layout
     * @throws IllegalArgumentException if a dimension is negative or
     *         {@code numBuckets * numClasses} does not fit in an {@code int}
     */
    public LinearModel(int numBuckets, int numClasses,
                       String[] labels, float[] scales,
                       float[] biases, byte[][] weights) {
        this(numBuckets, numClasses, labels, scales, biases,
                transposeToBucketMajor(weights, numBuckets, numClasses));
    }

    /** Internal constructor taking weights that are already flat and bucket-major. */
    private LinearModel(int numBuckets, int numClasses,
                        String[] labels, float[] scales,
                        float[] biases, byte[] flatWeights) {
        this.numBuckets = numBuckets;
        this.numClasses = numClasses;
        this.labels = labels;
        this.scales = scales;
        this.biases = biases;
        this.flatWeights = flatWeights;
    }

    /**
     * Compute {@code numBuckets * numClasses} as an array length, rejecting
     * negative dimensions and products that overflow {@code int}.
     *
     * @throws IllegalArgumentException if the dimensions are unusable
     */
    private static int checkedWeightCount(int numBuckets, int numClasses) {
        if (numBuckets < 0 || numClasses < 0) {
            throw new IllegalArgumentException(
                    "Negative dimension: numBuckets=" + numBuckets
                            + ", numClasses=" + numClasses);
        }
        // Multiply in 64 bits so an overflowing product is detected rather than wrapped.
        long total = (long) numBuckets * numClasses;
        if (total > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Weight matrix too large: numBuckets=" + numBuckets
                            + ", numClasses=" + numClasses);
        }
        return (int) total;
    }

    /** Transpose {@code [class][bucket]} weights into the flat bucket-major layout. */
    private static byte[] transposeToBucketMajor(
            byte[][] classMajor, int numBuckets, int numClasses) {
        byte[] flat = new byte[checkedWeightCount(numBuckets, numClasses)];
        for (int c = 0; c < numClasses; c++) {
            byte[] row = classMajor[c];
            for (int b = 0; b < numBuckets; b++) {
                flat[b * numClasses + c] = row[b];
            }
        }
        return flat;
    }

    // ================================================================
    //  Loading
    // ================================================================

    /**
     * Load a model from the classpath.  Transparently handles both plain
     * LDM1 binaries and gzip-compressed LDM1 binaries (detected by magic bytes).
     *
     * @param resourcePath resource path passed to {@link Class#getResourceAsStream(String)}
     * @throws IOException if the resource is missing or is not a valid model
     */
    public static LinearModel loadFromClasspath(String resourcePath) throws IOException {
        try (InputStream is = LinearModel.class.getResourceAsStream(resourcePath)) {
            if (is == null) {
                throw new IOException("Model resource not found: " + resourcePath);
            }
            return load(is);
        }
    }

    /**
     * Load a model from a file on disk.  Transparently handles both plain
     * and gzip-compressed LDM1 files.
     *
     * @param path file to read
     * @throws IOException if the file cannot be read or is not a valid model
     */
    public static LinearModel loadFromPath(java.nio.file.Path path) throws IOException {
        try (InputStream is = new BufferedInputStream(
                java.nio.file.Files.newInputStream(path))) {
            return load(is);
        }
    }

    /**
     * Load a model from an input stream.  Transparently handles both plain
     * LDM1 binaries and gzip-compressed ones: if the first two bytes are the
     * gzip magic {@code 0x1F 0x8B} the stream is wrapped in a
     * {@link GZIPInputStream} before reading.
     * <p>
     * This method does not close {@code is}.
     *
     * @param is stream positioned at the start of the model
     * @throws IOException on read failure or if the content is not a valid model
     */
    public static LinearModel load(InputStream is) throws IOException {
        // Buffer so we can peek at the magic without consuming it
        BufferedInputStream buf = is instanceof BufferedInputStream
                ? (BufferedInputStream) is : new BufferedInputStream(is);
        buf.mark(2);
        int b0 = buf.read();
        int b1 = buf.read();
        buf.reset();
        boolean gzipped = b0 == 0x1F && b1 == 0x8B;
        InputStream payload = gzipped ? new GZIPInputStream(buf) : buf;
        return loadRaw(payload);
    }

    /** Read LDM1 from an already-unwrapped (non-gzip) stream. */
    private static LinearModel loadRaw(InputStream is) throws IOException {
        DataInputStream dis = new DataInputStream(is);
        int magic = dis.readInt();
        if (magic != MAGIC) {
            throw new IOException(String.format(Locale.US,
                    "Invalid magic: expected 0x%08X, got 0x%08X", MAGIC, magic));
        }
        int version = dis.readInt();
        if (version != VERSION) {
            throw new IOException(
                    "Unsupported version: " + version + " (expected " + VERSION + ")");
        }

        int numBuckets = dis.readInt();
        int numClasses = dis.readInt();

        // Validate the header before allocating anything sized by it, so a
        // corrupt file surfaces as an IOException instead of a
        // NegativeArraySizeException or a silently wrapped array length.
        int weightCount;
        try {
            weightCount = checkedWeightCount(numBuckets, numClasses);
        } catch (IllegalArgumentException e) {
            throw new IOException("Corrupt model header: " + e.getMessage(), e);
        }

        String[] labels = readLabels(dis, numClasses);
        float[] scales = readFloats(dis, numClasses);
        float[] biases = readFloats(dis, numClasses);

        byte[] flat = new byte[weightCount];
        dis.readFully(flat);

        return new LinearModel(numBuckets, numClasses, labels, scales, biases, flat);
    }

    // ================================================================
    //  Saving
    // ================================================================

    /**
     * Write the model in LDM1 binary format. The stream is flushed but not closed.
     *
     * @param os destination stream
     * @throws IOException on write failure, or if a label is longer than
     *         65535 UTF-8 bytes and therefore cannot be represented in the
     *         format (in which case nothing is written)
     */
    public void save(OutputStream os) throws IOException {
        // Encode and validate labels first so an unrepresentable label does
        // not leave a partially written model behind.
        byte[][] encodedLabels = encodeLabels();

        DataOutputStream dos = new DataOutputStream(os);
        dos.writeInt(MAGIC);
        dos.writeInt(VERSION);
        dos.writeInt(numBuckets);
        dos.writeInt(numClasses);
        for (byte[] utf8 : encodedLabels) {
            dos.writeShort(utf8.length);
            dos.write(utf8);
        }
        writeFloats(dos, scales);
        writeFloats(dos, biases);
        dos.write(flatWeights);
        dos.flush();
    }

    // ================================================================
    //  Inference
    // ================================================================

    /**
     * Compute raw logits for the given feature vector (before softmax).
     * Uses a sparse inner loop - only non-zero buckets are visited.
     *
     * @param features int array of size {@code numBuckets}
     * @return float array of size {@code numClasses} (raw, unnormalized logits)
     * @throws ArrayIndexOutOfBoundsException if {@code features} has fewer
     *         than {@code numBuckets} elements
     */
    public float[] predictLogits(int[] features) {
        // Accumulate the integer dot products in 64 bits; dequantization to
        // float happens once per class at the end.
        long[] dots = new long[numClasses];
        for (int b = 0; b < numBuckets; b++) {
            int fv = features[b];
            if (fv == 0) {
                continue;
            }
            int off = b * numClasses;
            for (int c = 0; c < numClasses; c++) {
                dots[c] += (long) flatWeights[off + c] * fv;
            }
        }

        float[] logits = new float[numClasses];
        for (int c = 0; c < numClasses; c++) {
            logits[c] = biases[c] + scales[c] * dots[c];
        }
        return logits;
    }

    /**
     * Compute softmax probabilities for the given feature vector.
     *
     * @param features int array of size {@code numBuckets}
     * @return float array of size {@code numClasses} (softmax probabilities, sum ~= 1.0)
     */
    public float[] predict(int[] features) {
        return softmax(predictLogits(features));
    }

    /**
     * In-place softmax with numerical stability: the maximum logit is
     * subtracted before exponentiation.
     *
     * @param logits values to normalize; overwritten with the result
     * @return the same array instance that was passed in
     */
    public static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            if (v > max) {
                max = v;
            }
        }
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            logits[i] = (float) Math.exp(logits[i] - max);
            sum += logits[i];
        }
        // Skip normalization when the sum is not positive (e.g. empty input).
        if (sum > 0f) {
            for (int i = 0; i < logits.length; i++) {
                logits[i] /= sum;
            }
        }
        return logits;
    }

    /**
     * Shannon entropy (in bits) of a probability distribution.
     * Entries that are not strictly positive contribute nothing.
     *
     * @param probs probability values
     * @return entropy in bits
     */
    public static float entropy(float[] probs) {
        double h = 0.0;
        for (float p : probs) {
            if (p > 0f) {
                h -= p * (Math.log(p) / LN_2);
            }
        }
        return (float) h;
    }

    // ================================================================
    //  Accessors
    // ================================================================

    public int getNumBuckets() {
        return numBuckets;
    }

    public int getNumClasses() {
        return numClasses;
    }

    /** Returns the internal label array (not a copy). */
    public String[] getLabels() {
        return labels;
    }

    public String getLabel(int classIndex) {
        return labels[classIndex];
    }

    /** Returns the internal per-class scale array (not a copy). */
    public float[] getScales() {
        return scales;
    }

    /** Returns the internal per-class bias array (not a copy). */
    public float[] getBiases() {
        return biases;
    }

    /**
     * Return weights in class-major {@code [class][bucket]} layout.
     * Creates a new array each call.
     */
    public byte[][] getWeights() {
        byte[][] cm = new byte[numClasses][numBuckets];
        for (int b = 0; b < numBuckets; b++) {
            int off = b * numClasses;
            for (int c = 0; c < numClasses; c++) {
                cm[c][b] = flatWeights[off + c];
            }
        }
        return cm;
    }

    // ================================================================
    //  Internal I/O helpers
    // ================================================================

    /** Read {@code numClasses} labels, each a 2-byte unsigned length followed by UTF-8 bytes. */
    private static String[] readLabels(DataInputStream dis, int numClasses) throws IOException {
        String[] labels = new String[numClasses];
        for (int c = 0; c < numClasses; c++) {
            int len = dis.readUnsignedShort();
            byte[] utf8 = new byte[len];
            dis.readFully(utf8);
            labels[c] = new String(utf8, StandardCharsets.UTF_8);
        }
        return labels;
    }

    /** Read {@code count} big-endian 4-byte floats. */
    private static float[] readFloats(DataInputStream dis, int count) throws IOException {
        float[] arr = new float[count];
        for (int i = 0; i < count; i++) {
            arr[i] = dis.readFloat();
        }
        return arr;
    }

    /**
     * Encode the first {@code numClasses} labels as UTF-8, verifying each one
     * fits the format's 2-byte length prefix.
     *
     * @throws IOException if a label exceeds {@value #MAX_LABEL_BYTES} UTF-8 bytes
     */
    private byte[][] encodeLabels() throws IOException {
        byte[][] encoded = new byte[numClasses][];
        for (int c = 0; c < numClasses; c++) {
            byte[] utf8 = labels[c].getBytes(StandardCharsets.UTF_8);
            if (utf8.length > MAX_LABEL_BYTES) {
                // writeShort would keep only the low 16 bits of the length
                // and produce a file that cannot be read back.
                throw new IOException("Label " + c + " is " + utf8.length
                        + " UTF-8 bytes; the LDM1 format allows at most "
                        + MAX_LABEL_BYTES);
            }
            encoded[c] = utf8;
        }
        return encoded;
    }

    /** Write every element of {@code arr} as a big-endian 4-byte float. */
    private static void writeFloats(DataOutputStream dos, float[] arr) throws IOException {
        for (float v : arr) {
            dos.writeFloat(v);
        }
    }
}