CalibrateConfidence.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.FeatureExtractor;

/**
 * Empirically calibrates per-class confidence by computing the distribution
 * of (max_logit, margin) over true-positive predictions on a Flores TSV file.
 *
 * <p>For each sentence:
 * <ul>
 *   <li>predicted label = argmax(logits)</li>
 *   <li>max_logit = logits[predicted]</li>
 *   <li>margin = highest logit minus second-highest logit
 *       (top-1 minus top-2 after ranking, not raw positions 0 and 1)</li>
 * </ul>
 *
 * <p>True positives (ground truth == predicted) build the per-class calibration
 * distributions. Mispredicted sentences of the requested focus languages are
 * then z-scored against the predicted class's true-positive distribution to
 * surface "implausible" predictions.
 *
 * <p>Usage:
 * <pre>
 *   CalibrateConfidence &lt;flores_tsv&gt; &lt;model_bin&gt; [focus_lang ...]
 * </pre>
 *
 * <p>Each input line is {@code <flores_lang>\t<text>}; lines without a tab or
 * with empty text are skipped. If focus_lang values are given, detailed
 * false-positive z-scores are printed only for Flores language codes starting
 * with one of those prefixes (e.g. "bod", "dzo"), followed by the calibration
 * stats of the classes those sentences were predicted as. Without focus_lang,
 * a full per-class calibration table is printed.
 */
public class CalibrateConfidence {

    // Non-ASCII output characters are written as Unicode escapes so the
    // source file stays correct regardless of the encoding it is stored in.

    /** Box-drawing horizontal line (U+2500) used for table rules. */
    private static final String RULE = "\u2500";

    /** Column title for the logit standard deviation (Greek small sigma, U+03C3). */
    private static final String LOGIT_SIGMA = "logit_\u03c3";

    /** Column title for the margin standard deviation (Greek small sigma, U+03C3). */
    private static final String MARGIN_SIGMA = "margin_\u03c3";

    /** Width of the rule printed under a calibration table header. */
    private static final int STATS_RULE_WIDTH = 62;

    /** Width of the rule printed under the false-positive table header. */
    private static final int FOCUS_RULE_WIDTH = 120;

    /**
     * Runs the calibration report.
     *
     * @param args {@code <flores_tsv> <model_bin> [focus_lang ...]}
     * @throws Exception if the model or the TSV file cannot be read
     * @throws IllegalArgumentException if the model yields fewer than two logits,
     *         in which case no margin can be computed
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println(
                    "Usage: CalibrateConfidence <flores_tsv> <model_bin> "
                    + "[focus_lang ...]");
            System.exit(1);
            return;
        }
        String floresPath = args[0];
        String modelPath = args[1];
        List<String> focusLangs = args.length > 2
                ? Arrays.asList(Arrays.copyOfRange(args, 2, args.length))
                : List.of();

        CharSoupModel model;
        try (DataInputStream dis = new DataInputStream(
                new FileInputStream(modelPath))) {
            model = CharSoupModel.load(dis);
        }
        String[] labels = model.getLabels();
        FeatureExtractor ext = model.createExtractor();
        int[] buf = new int[model.getNumBuckets()];

        // label -> model index
        Map<String, Integer> labelIndex = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            labelIndex.put(labels[i], i);
        }

        // true-positive samples per predicted class; each sample is {maxLogit, margin}
        Map<Integer, List<float[]>> tpSamples = new HashMap<>();

        // mispredicted sentences of the focus languages, in input order
        List<FalsePositive> focusFalsePositives = new ArrayList<>();

        for (String line : Files.readAllLines(Paths.get(floresPath))) {
            int tab = line.indexOf('\t');
            if (tab < 0) {
                continue;
            }
            String floresLang = line.substring(0, tab);
            String text = line.substring(tab + 1).trim();
            if (text.isEmpty()) {
                continue;
            }

            // map flores lang code (e.g. "kor_Hang") -> model label (e.g. "kor")
            int underscore = floresLang.indexOf('_');
            String trueLabel = underscore >= 0
                    ? floresLang.substring(0, underscore)
                    : floresLang;

            ext.extractFromPreprocessed(text, buf, true);
            float[] logits = model.predictLogits(buf);

            int[] top = topTwo(logits);
            int predIdx = top[0];
            float maxLogit = logits[predIdx];
            float margin = maxLogit - logits[top[1]];

            Integer trueIdx = labelIndex.get(trueLabel);
            boolean correct = trueIdx != null && trueIdx == predIdx;

            if (correct) {
                tpSamples.computeIfAbsent(predIdx, k -> new ArrayList<>())
                        .add(new float[]{maxLogit, margin});
            } else if (hasFocusPrefix(floresLang, focusLangs)) {
                focusFalsePositives.add(new FalsePositive(
                        floresLang, text, predIdx, maxLogit, margin));
            }
        }

        Map<Integer, ClassStats> classStats = new HashMap<>();
        for (Map.Entry<Integer, List<float[]>> e : tpSamples.entrySet()) {
            classStats.put(e.getKey(), ClassStats.of(e.getValue()));
        }

        if (focusLangs.isEmpty()) {
            printFullTable(labels, classStats);
        } else {
            printFocusReport(labels, classStats, focusFalsePositives);
        }
    }

    /**
     * Finds the indices of the largest and second-largest logits in one pass.
     * Ordering follows {@link Float#compare(float, float)}; among equal values
     * the lower index ranks first, which matches a stable descending sort.
     *
     * @param logits raw model scores, at least two entries
     * @return {@code {top1Index, top2Index}}
     * @throws IllegalArgumentException if fewer than two logits are supplied
     */
    private static int[] topTwo(float[] logits) {
        if (logits.length < 2) {
            throw new IllegalArgumentException(
                    "Need at least 2 logits to compute a margin, got "
                    + logits.length);
        }
        int first = 0;
        int second = 1;
        if (Float.compare(logits[1], logits[0]) > 0) {
            first = 1;
            second = 0;
        }
        for (int i = 2; i < logits.length; i++) {
            if (Float.compare(logits[i], logits[first]) > 0) {
                second = first;
                first = i;
            } else if (Float.compare(logits[i], logits[second]) > 0) {
                second = i;
            }
        }
        return new int[]{first, second};
    }

    /** Returns true if {@code floresLang} starts with any of the focus prefixes. */
    private static boolean hasFocusPrefix(String floresLang,
                                          List<String> focusLangs) {
        for (String prefix : focusLangs) {
            if (floresLang.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    /** Prints the calibration table for every class that has true positives, in label order. */
    private static void printFullTable(String[] labels,
                                       Map<Integer, ClassStats> classStats) {
        printStatsHeader();
        for (int i = 0; i < labels.length; i++) {
            printStatsRow(labels[i], classStats.get(i));
        }
    }

    /**
     * Prints one z-scored row per focus false positive, then the calibration
     * stats of each distinct predicted class in first-seen order.
     */
    private static void printFocusReport(String[] labels,
                                         Map<Integer, ClassStats> classStats,
                                         List<FalsePositive> falsePositives) {
        // "Text" is left-aligned so it sits above the left-aligned text column.
        System.out.printf(Locale.US,
                "%-14s  %-12s  %9s  %7s  %9s  %7s  %s%n",
                "FloresLang", "Predicted",
                "maxLogit", "z_logit", "margin", "z_marg", "Text");
        System.out.println(RULE.repeat(FOCUS_RULE_WIDTH));
        for (FalsePositive fp : falsePositives) {
            ClassStats st = classStats.get(fp.predIdx);
            // z-score is undefined without a true-positive distribution or
            // when that distribution has zero spread
            double zLogit = st != null && st.stdLogit > 0
                    ? (fp.maxLogit - st.meanLogit) / st.stdLogit : Double.NaN;
            double zMargin = st != null && st.stdMargin > 0
                    ? (fp.margin - st.meanMargin) / st.stdMargin : Double.NaN;

            // %.60s truncates the sentence to its first 60 chars
            System.out.printf(Locale.US,
                    "%-14s  %-12s  %9.2f  %7.2f  %9.2f  %7.2f  %.60s%n",
                    fp.floresLang, labels[fp.predIdx],
                    fp.maxLogit, zLogit, fp.margin, zMargin, fp.text);
        }

        // also show the calibration stats for predicted classes involved
        System.out.println();
        System.out.println("Calibration stats for predicted classes:");
        printStatsHeader();
        falsePositives.stream()
                .mapToInt(fp -> fp.predIdx)
                .distinct()
                .forEach(predIdx ->
                        printStatsRow(labels[predIdx], classStats.get(predIdx)));
    }

    /** Prints the column titles and rule of a calibration table. */
    private static void printStatsHeader() {
        System.out.printf(Locale.US, "%-12s  %5s  %8s  %7s  %8s  %7s%n",
                "Label", "N_tp",
                "logit_mu", LOGIT_SIGMA, "margin_mu", MARGIN_SIGMA);
        System.out.println(RULE.repeat(STATS_RULE_WIDTH));
    }

    /** Prints one calibration row; classes without true positives ({@code null}) are skipped. */
    private static void printStatsRow(String label, ClassStats st) {
        if (st == null) {
            return;
        }
        System.out.printf(Locale.US,
                "%-12s  %5d  %8.2f  %7.2f  %8.2f  %7.2f%n",
                label, st.n,
                st.meanLogit, st.stdLogit, st.meanMargin, st.stdMargin);
    }

    /** A mispredicted focus-language sentence together with its top-1 score and margin. */
    private static final class FalsePositive {
        final String floresLang;
        final String text;
        final int predIdx;
        final float maxLogit;
        final float margin;

        FalsePositive(String floresLang, String text, int predIdx,
                      float maxLogit, float margin) {
            this.floresLang = floresLang;
            this.text = text;
            this.predIdx = predIdx;
            this.maxLogit = maxLogit;
            this.margin = margin;
        }
    }

    /** Mean and sample standard deviation of max logit and margin for one class. */
    private static final class ClassStats {
        final double meanLogit;
        final double stdLogit;
        final double meanMargin;
        final double stdMargin;
        final int n;

        private ClassStats(double meanLogit, double stdLogit,
                           double meanMargin, double stdMargin, int n) {
            this.meanLogit = meanLogit;
            this.stdLogit = stdLogit;
            this.meanMargin = meanMargin;
            this.stdMargin = stdMargin;
            this.n = n;
        }

        /**
         * Two-pass mean / sample standard deviation (n - 1 denominator).
         * A class with a single sample gets a standard deviation of 0.
         *
         * @param samples non-empty list of {@code {maxLogit, margin}} pairs
         */
        static ClassStats of(List<float[]> samples) {
            int n = samples.size();
            double sumLogit = 0;
            double sumMargin = 0;
            for (float[] s : samples) {
                sumLogit += s[0];
                sumMargin += s[1];
            }
            double meanLogit = sumLogit / n;
            double meanMargin = sumMargin / n;

            double sqLogit = 0;
            double sqMargin = 0;
            for (float[] s : samples) {
                sqLogit += (s[0] - meanLogit) * (s[0] - meanLogit);
                sqMargin += (s[1] - meanMargin) * (s[1] - meanMargin);
            }
            double stdLogit = n > 1 ? Math.sqrt(sqLogit / (n - 1)) : 0;
            double stdMargin = n > 1 ? Math.sqrt(sqMargin / (n - 1)) : 0;
            return new ClassStats(meanLogit, stdLogit, meanMargin, stdMargin, n);
        }
    }
}