ConfusionDump.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.BufferedReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.tika.langdetect.charsoup.CharSoupLanguageDetector;

/**
 * Confusion analysis for a CharSoup model against a TSV test file.
 *
 * <p>Each input line is read as {@code <label>\t<text>} (split on the first tab).
 * Lines without a tab, and lines whose text is empty after trimming, are skipped.
 *
 * <p>Two modes:
 * <ul>
 *   <li><b>Recall mode</b> (default): filters to rows where true label == targetLang
 *       and shows what the model predicted for them (where does targetLang bleed
 *       out).</li>
 *   <li><b>False-positive mode</b> ({@code --fp} flag): scans all rows, keeps rows
 *       where predicted == targetLang but true label != targetLang, and shows which
 *       true labels are bleeding into targetLang.</li>
 * </ul>
 *
 * <p>Prediction goes through the full CharSoupLanguageDetector pipeline
 * (script gating + length-gated confusables + confusable-group collapse),
 * so results reflect production behaviour. The modelFile argument is
 * accepted for API compatibility but ignored; the detector loads its
 * model from the classpath resource slot.
 *
 * <p>Usage:
 * <pre>
 *   ConfusionDump &lt;testFile&gt; &lt;modelFile&gt; &lt;targetLang&gt; [--show] [--predicted=&lt;lang&gt;] [maxChars]
 *   ConfusionDump &lt;testFile&gt; &lt;modelFile&gt; &lt;targetLang&gt; --fp [maxChars]
 * </pre>
 *
 * <p>Options (accepted in any order after the three positional arguments):
 * <ul>
 *   <li>{@code --fp}: switch to false-positive mode.</li>
 *   <li>{@code --show}: recall mode only; print every misclassified sentence
 *       (shortened to 120 chars for display).</li>
 *   <li>{@code --predicted=<lang>}: recall mode only; restrict {@code --show}
 *       output to sentences predicted as {@code <lang>}. Has no effect without
 *       {@code --show}.</li>
 *   <li>{@code maxChars}: positive integer; each text is truncated to at most this
 *       many UTF-16 chars before prediction.</li>
 * </ul>
 */
public class ConfusionDump {

    private static final String USAGE =
            "Usage: ConfusionDump <testFile> <modelFile> <targetLang>"
                    + " [--fp] [--show] [--predicted=<lang>] [maxChars]";

    /** Maximum number of chars of a misclassified sentence echoed by {@code --show}. */
    private static final int DISPLAY_CHARS = 120;

    /**
     * Command-line entry point; see the class documentation for the argument layout.
     *
     * @param args testFile, modelFile (ignored), targetLang, then optional flags
     * @throws Exception if the test file cannot be read or detection fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.err.println(USAGE);
            System.exit(1);
        }
        String testFile = args[0];
        // args[1] is the model file path - accepted for CLI compatibility but
        // the detector always loads from its classpath resource slot.
        String targetLang = args[2];

        boolean fpMode = false;
        boolean showSentences = false;
        String filterPredicted = null;
        int maxChars = Integer.MAX_VALUE;
        for (int i = 3; i < args.length; i++) {
            String arg = args[i];
            if ("--fp".equals(arg)) {
                fpMode = true;
            } else if ("--show".equals(arg)) {
                showSentences = true;
            } else if (arg.startsWith("--predicted=")) {
                filterPredicted = arg.substring("--predicted=".length());
            } else {
                // Anything else must be the numeric maxChars argument.
                maxChars = parseMaxChars(arg);
            }
        }

        CharSoupLanguageDetector detector = new CharSoupLanguageDetector();

        if (fpMode) {
            if (showSentences || filterPredicted != null) {
                // These flags are only wired into recall mode; say so rather than
                // silently dropping them.
                System.err.println(
                        "Note: --show and --predicted= only apply to recall mode; ignored with --fp");
            }
            runFalsePositiveMode(testFile, detector, targetLang, maxChars);
        } else {
            runRecallMode(testFile, detector, targetLang, maxChars,
                    showSentences, filterPredicted);
        }
    }

    /**
     * Parses the {@code maxChars} argument. Prints a diagnostic plus the usage line
     * and exits with status 1 when the value is not a positive integer (which also
     * covers unrecognised {@code --options}).
     */
    private static int parseMaxChars(String arg) {
        try {
            int value = Integer.parseInt(arg);
            if (value > 0) {
                return value;
            }
        } catch (NumberFormatException ignored) {
            // Not a number: handled by the shared error path below.
        }
        System.err.println("Invalid argument '" + arg
                + "': expected --fp, --show, --predicted=<lang> or a positive maxChars");
        System.err.println(USAGE);
        System.exit(1);
        return Integer.MAX_VALUE; // not reached; keeps the compiler satisfied
    }

    /**
     * Recall mode: evaluates only rows labelled {@code targetLang}, prints the recall
     * and a histogram of the predicted labels, and optionally echoes each
     * misclassified sentence.
     *
     * @param testFile        path of the TSV test file (UTF-8)
     * @param detector        detector used for every prediction
     * @param targetLang      true label to filter on
     * @param maxChars        maximum text length (UTF-16 chars) fed to the detector
     * @param showSentences   whether to print misclassified sentences
     * @param filterPredicted if non-null, only sentences predicted as this label are printed
     * @throws Exception if the file cannot be read or detection fails
     */
    private static void runRecallMode(String testFile, CharSoupLanguageDetector detector,
            String targetLang, int maxChars, boolean showSentences, String filterPredicted)
            throws Exception {
        int total = 0;
        int correct = 0;
        Map<String, Integer> predicted = new HashMap<>();

        try (BufferedReader br = Files.newBufferedReader(
                Paths.get(testFile), StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                String lang = line.substring(0, tab).trim();
                if (!lang.equals(targetLang)) {
                    continue;
                }
                String text = line.substring(tab + 1).trim();
                if (text.isEmpty()) {
                    continue;
                }
                text = truncate(text, maxChars);

                String pred = predict(detector, text);
                total++;
                if (pred.equals(targetLang)) {
                    correct++;
                }
                predicted.merge(pred, 1, Integer::sum);

                if (showSentences
                        && !pred.equals(targetLang)
                        && (filterPredicted == null || filterPredicted.equals(pred))) {
                    // \u2192 is a right arrow, \u2026 an ellipsis; written as escapes so
                    // the output does not depend on the source file's encoding.
                    String shown = text.length() > DISPLAY_CHARS
                            ? truncate(text, DISPLAY_CHARS) + "\u2026"
                            : text;
                    System.out.printf(Locale.US, "[%s\u2192%s] %s%n", targetLang, pred, shown);
                }
            }
        }

        // Guard against 0/0 (NaN) when no row carries the target label.
        double recall = total > 0 ? 100.0 * correct / total : 0.0;
        System.out.printf(Locale.US,
                "%nRecall for '%s': %d correct / %d total = %.2f%%%n%n",
                targetLang, correct, total, recall);
        printHistogram(predicted, total);
    }

    /**
     * False-positive mode: evaluates every row whose true label differs from
     * {@code targetLang} and prints a histogram of the true labels that were
     * predicted as {@code targetLang}.
     *
     * @param testFile   path of the TSV test file (UTF-8)
     * @param detector   detector used for every prediction
     * @param targetLang predicted label to look for
     * @param maxChars   maximum text length (UTF-16 chars) fed to the detector
     * @throws Exception if the file cannot be read or detection fails
     */
    private static void runFalsePositiveMode(String testFile,
            CharSoupLanguageDetector detector, String targetLang, int maxChars)
            throws Exception {
        int totalFp = 0;
        Map<String, Integer> trueLangs = new HashMap<>();

        try (BufferedReader br = Files.newBufferedReader(
                Paths.get(testFile), StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                String trueLang = line.substring(0, tab).trim();
                if (trueLang.equals(targetLang)) {
                    continue;
                }
                String text = line.substring(tab + 1).trim();
                if (text.isEmpty()) {
                    continue;
                }
                text = truncate(text, maxChars);

                String pred = predict(detector, text);
                if (pred.equals(targetLang)) {
                    totalFp++;
                    trueLangs.merge(trueLang, 1, Integer::sum);
                }
            }
        }

        System.out.printf(Locale.US,
                "False positives for '%s': %d sentences from other languages predicted as '%s'%n%n",
                targetLang, totalFp, targetLang);
        printHistogram(trueLangs, totalFp);
    }

    /**
     * Returns {@code text} limited to at most {@code maxChars} UTF-16 chars. If the cut
     * would fall between the two halves of a surrogate pair, the dangling high
     * surrogate is dropped as well so the result never ends in half a code point.
     */
    private static String truncate(String text, int maxChars) {
        if (text.length() <= maxChars) {
            return text;
        }
        int end = Math.max(0, maxChars);
        if (end > 0
                && Character.isHighSurrogate(text.charAt(end - 1))
                && Character.isLowSurrogate(text.charAt(end))) {
            end--;
        }
        return text.substring(0, end);
    }

    /**
     * Resets the detector, feeds it {@code text} and returns the language of the first
     * entry of {@code detectAll()}.
     *
     * @return the first result's language, or the empty string when there are no
     *         results or that language is null
     */
    private static String predict(CharSoupLanguageDetector detector, String text) {
        detector.reset();
        detector.addText(text.toCharArray(), 0, text.length());
        List<org.apache.tika.language.detect.LanguageResult> results = detector.detectAll();
        if (results.isEmpty()) {
            return "";
        }
        String lang = results.get(0).getLanguage();
        return lang == null ? "" : lang;
    }

    /**
     * Prints a table of labels with their counts and share of {@code total}, highest
     * count first. Ties are ordered by label so the output is stable across runs.
     *
     * @param counts label to number of occurrences
     * @param total  denominator for the share column; shares print as 0.0 when it is 0
     */
    private static void printHistogram(Map<String, Integer> counts, int total) {
        List<Map.Entry<String, Integer>> entries = new ArrayList<>(counts.entrySet());
        entries.sort(Map.Entry.<String, Integer>comparingByValue().reversed()
                .thenComparing(Map.Entry.<String, Integer>comparingByKey()));
        System.out.printf(Locale.US, "  %-16s  %6s  %7s%n", "Language", "Count", "Share");
        System.out.println("  " + "-".repeat(34));
        for (Map.Entry<String, Integer> e : entries) {
            System.out.printf(Locale.US, "  %-16s  %6d  %6.1f%%%n",
                    e.getKey(), e.getValue(),
                    total > 0 ? 100.0 * e.getValue() / total : 0.0);
        }
    }
}