ModelQuantizer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.ScriptAwareFeatureExtractor;

/**
 * Quantizes float32 model weights to INT8 for compact storage.
 * <p>
 * Per-class quantization: each class row is independently scaled
 * to fit into [-127, 127]. The scale factor is stored alongside
 * the quantized weights to allow dequantization at inference time.
 * </p>
 */
public class ModelQuantizer {

    private ModelQuantizer() {
    }

    /**
     * Quantize a Phase2Trainer's float32 weights to INT8.
     * <p>
     * The feature flags are taken directly from
     * {@link ScriptAwareFeatureExtractor#FEATURE_FLAGS} because training
     * always uses {@code ScriptAwareFeatureExtractor} to match the inference
     * path in {@link CharSoupModel#createExtractor()}.
     *
     * @param trainer the trained Phase2Trainer
     * @return a CharSoupModel with INT8 quantized weights and correct feature flags
     */
    public static CharSoupModel quantize(Phase2Trainer trainer) {
        return quantize(trainer.getLabels(),
                trainer.getWeightsClassMajor(),
                trainer.getBiases(),
                trainer.getNumBuckets(),
                ScriptAwareFeatureExtractor.FEATURE_FLAGS);
    }

    /**
     * Quantize float32 weights to INT8.
     *
     * @param labels       class labels
     * @param weights      float32 weights [numClasses][numBuckets]
     * @param biases       float32 biases [numClasses]
     * @param numBuckets   number of feature buckets
     * @param featureFlags bitmask of {@link CharSoupModel}{@code .FLAG_*} constants
     * @return a CharSoupModel with INT8 quantized weights
     */
    public static CharSoupModel quantize(String[] labels,
                                       float[][] weights,
                                       float[] biases,
                                       int numBuckets,
                                       int featureFlags) {
        int numClasses = labels.length;
        float[] scales = new float[numClasses];
        byte[][] quantizedWeights =
                new byte[numClasses][numBuckets];

        for (int c = 0; c < numClasses; c++) {
            float maxAbs = 0f;
            for (int b = 0; b < numBuckets; b++) {
                float abs = Math.abs(weights[c][b]);
                if (abs > maxAbs) {
                    maxAbs = abs;
                }
            }

            if (maxAbs == 0f) {
                scales[c] = 1f;
            } else {
                scales[c] = maxAbs / 127f;
            }

            for (int b = 0; b < numBuckets; b++) {
                int q = Math.round(weights[c][b] / scales[c]);
                quantizedWeights[c][b] =
                        (byte) Math.max(-127, Math.min(127, q));
            }
        }

        return new CharSoupModel(numBuckets, numClasses,
                labels.clone(), scales, biases.clone(),
                quantizedWeights, featureFlags);
    }
}