ConfusableDiff.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.BufferedReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TreeMap;


/**
 * Trains two models (bigrams-only vs bigrams+wordUnigrams) and
 * compares per-language F1 to see where word unigrams help most.
 * Highlights known confusable language groups.
 * <p>
 * Usage: {@code ConfusableDiff <prepDir>}
 * <p>
 * {@code prepDir} must contain {@code dev.txt} and {@code test.txt}, plus
 * {@code train_5m.txt} or, if that file does not exist, {@code train.txt}.
 * The dev and test files are read as two tab-separated fields per line
 * (presumably language label, then text -- confirm against
 * {@code LabeledSentence}).
 */
public class ConfusableDiff {

    /**
     * Groups of language codes that are treated as mutually confusable.
     * Any language listed here is flagged with {@code <<} in the report and
     * counted in the "confusable langs" average.
     */
    private static final String[][] CONFUSABLE_GROUPS = {
        {"hrv", "bos", "srp"},
        {"nob", "nno", "dan", "swe"},
        {"ces", "slk"},
        {"ind", "msa", "zlm"},
        {"por", "glg", "spa"},
        {"aze", "tur", "tuk"},
        {"ukr", "rus", "bel"},
        {"bul", "mkd"},
        {"hin", "mar"},
        {"urd", "fas"},
        {"nld", "afr"},
        {"cat", "oci"},
        {"zho", "jpn", "kor"},
    };

    /**
     * Entry point: trains both models, prints the per-language F1 table
     * sorted by delta (largest gain first), the average delta for
     * confusable vs. other languages, and the top-20 confusion pairs of
     * each model.
     *
     * @param args {@code args[0]} is the directory holding the prepared
     *             train/dev/test files
     * @throws Exception if reading the data files or training fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println("Usage: ConfusableDiff <prepDir>");
            System.exit(1);
        }
        Path prepDir = Paths.get(args[0]);
        // Use train_5m.txt when it exists, otherwise fall back to train.txt.
        Path trainFile = prepDir.resolve("train_5m.txt");
        if (!Files.exists(trainFile)) {
            trainFile = prepDir.resolve("train.txt");
        }

        int threads = Runtime.getRuntime().availableProcessors();
        int buckets = 16384;

        System.out.println("Loading dev + test...");
        List<LabeledSentence> dev = readReservoir(
                prepDir.resolve("dev.txt"), 100_000);
        List<LabeledSentence> test = readReservoir(
                prepDir.resolve("test.txt"), 200_000);
        System.out.printf(Locale.US, "dev=%,d (%d langs)  test=%,d (%d langs)%n%n",
                dev.size(), countLangs(dev), test.size(), countLangs(test));

        // Train bigrams-only model
        System.out.println("=== Training: bigrams-only ===");
        Phase2Trainer bigramOnly = newTrainer(buckets, threads);
        bigramOnly.train(trainFile, dev);

        // Train bigrams+wordUnigrams model
        // NOTE(review): this trainer is configured exactly like the
        // bigrams-only one; nothing visible in this file enables word
        // unigrams. Unless Phase2Trainer switches features through a
        // mechanism not shown here, both models use the same feature set and
        // the deltas below only reflect run-to-run variation. Confirm, and
        // apply the word-unigram setting to this trainer.
        System.out.println("=== Training: bigrams+wordUnigrams ===");
        Phase2Trainer withWords = newTrainer(buckets, threads);
        withWords.train(trainFile, dev);

        // Compute per-language F1 on test set
        Map<String, Double> f1Bigram = perLanguageF1(bigramOnly, test);
        Map<String, Double> f1Words = perLanguageF1(withWords, test);

        Map<String, String> groupLabel = confusableGroupLabels();
        List<LangDelta> rows = rankByDelta(f1Bigram, f1Words);

        System.out.println();
        System.out.printf(Locale.US,
                "%-8s  %8s  %8s  %8s  %-25s%n",
                "lang", "bigram", "+words", "delta", "confusable_group");
        System.out.println("-".repeat(75));

        double confGainSum = 0;
        int confGainCount = 0;
        double nonConfGainSum = 0;
        int nonConfGainCount = 0;

        for (LangDelta row : rows) {
            String group = groupLabel.get(row.lang);
            String marker;
            if (group != null) {
                marker = " <<";
                confGainSum += row.delta;
                confGainCount++;
            } else {
                marker = "";
                group = "";
                nonConfGainSum += row.delta;
                nonConfGainCount++;
            }
            // Values are rounded only here, at print time; sums above use
            // the exact deltas.
            System.out.printf(Locale.US,
                    "%-8s  %8.4f  %8.4f  %+8.4f  %-25s%s%n",
                    row.lang, row.bigramF1, row.wordsF1, row.delta,
                    group, marker);
        }

        System.out.println();
        System.out.printf(Locale.US,
                "Avg delta (confusable langs, n=%d): %+.4f%n",
                confGainCount,
                confGainCount > 0 ? confGainSum / confGainCount : 0.0);
        System.out.printf(Locale.US,
                "Avg delta (other langs, n=%d):       %+.4f%n",
                nonConfGainCount,
                nonConfGainCount > 0 ? nonConfGainSum / nonConfGainCount : 0.0);

        // Top confusion pairs for each model
        System.out.println("\n=== Top-20 confusion pairs: bigrams-only ===");
        printTopConfusions(bigramOnly, test, 20);

        System.out.println("\n=== Top-20 confusion pairs: +wordUnigrams ===");
        printTopConfusions(withWords, test, 20);
    }

    /**
     * Builds a trainer with the hyper-parameters shared by both runs, so
     * that any difference between the two models has to be applied
     * explicitly by the caller.
     *
     * @param buckets value passed to the {@code Phase2Trainer} constructor
     * @param threads value passed to {@code setNumThreads}
     * @return the configured trainer
     */
    private static Phase2Trainer newTrainer(int buckets, int threads) {
        return new Phase2Trainer(buckets)
                .setAdamLr(0.001f)
                .setSgdLr(0.01f, 0.001f)
                .setAdamEpochs(2)
                .setMaxEpochs(6)
                .setPatience(2)
                .setCheckpointInterval(500_000)
                .setDevSubsampleSize(10_000)
                .setNumThreads(threads)
                .setVerbose(false)
                .setPreprocessed(true);
    }

    /**
     * Maps every language in {@link #CONFUSABLE_GROUPS} to a display label
     * made of its group's members joined with {@code "/"}.
     *
     * @return language code to group label; languages in no group are absent
     */
    private static Map<String, String> confusableGroupLabels() {
        Map<String, String> groupLabel = new HashMap<>();
        for (String[] group : CONFUSABLE_GROUPS) {
            String label = String.join("/", group);
            for (String lang : group) {
                groupLabel.put(lang, label);
            }
        }
        return groupLabel;
    }

    /**
     * Pairs up the two F1 maps per language and orders the result by
     * descending delta ({@code +words} minus {@code bigram}). A language
     * missing from one map is given an F1 of 0 for that model.
     *
     * @param f1Bigram per-language F1 of the bigrams-only model
     * @param f1Words  per-language F1 of the bigrams+wordUnigrams model
     * @return one row per language seen in either map, largest gain first
     */
    private static List<LangDelta> rankByDelta(
            Map<String, Double> f1Bigram, Map<String, Double> f1Words) {
        Set<String> allLangs = new HashSet<>(f1Bigram.keySet());
        allLangs.addAll(f1Words.keySet());

        // Alphabetical base order; the stable sort below then keeps
        // languages with equal deltas in a deterministic order.
        Map<String, LangDelta> byLang = new TreeMap<>();
        for (String lang : allLangs) {
            double fb = f1Bigram.getOrDefault(lang, 0.0);
            double fw = f1Words.getOrDefault(lang, 0.0);
            byLang.put(lang, new LangDelta(lang, fb, fw));
        }

        List<LangDelta> rows = new ArrayList<>(byLang.values());
        // Compare the exact deltas rather than their 4-decimal renderings.
        rows.sort((a, b) -> Double.compare(b.delta, a.delta));
        return rows;
    }

    /**
     * Computes F1 per true language by running {@code model.predict} on
     * every sentence of {@code data}.
     *
     * @param model trained model used for prediction
     * @param data  labeled sentences to evaluate on
     * @return F1 for every language that occurs as a true label in
     *         {@code data}; labels that only ever appear as predictions
     *         are omitted
     */
    private static Map<String, Double> perLanguageF1(
            Phase2Trainer model, List<LabeledSentence> data) {
        // Slots of the per-language counter array.
        final int tpIdx = 0;
        final int fpIdx = 1;
        final int fnIdx = 2;
        Map<String, int[]> counts = new HashMap<>();

        for (LabeledSentence s : data) {
            String trueLabel = s.getLanguage();
            String predicted = model.predict(s.getText());

            int[] truthCounts = counts.computeIfAbsent(trueLabel,
                    k -> new int[3]);
            if (predicted.equals(trueLabel)) {
                truthCounts[tpIdx]++;
            } else {
                // A miss is a false negative for the true language and a
                // false positive for the predicted one.
                truthCounts[fnIdx]++;
                counts.computeIfAbsent(predicted,
                        k -> new int[3])[fpIdx]++;
            }
        }

        Map<String, Double> f1Map = new HashMap<>();
        for (Map.Entry<String, int[]> e : counts.entrySet()) {
            int tp = e.getValue()[tpIdx];
            int fp = e.getValue()[fpIdx];
            int fn = e.getValue()[fnIdx];
            if (tp + fn == 0) {
                // Never a true label in the data: recall is undefined.
                continue;
            }
            double p = tp + fp > 0 ? (double) tp / (tp + fp) : 0;
            double r = (double) tp / (tp + fn);
            double f1 = p + r > 0 ? 2 * p * r / (p + r) : 0;
            f1Map.put(e.getKey(), f1);
        }
        return f1Map;
    }

    /**
     * Prints the {@code topN} most frequent misclassifications as
     * {@code "true -> predicted"} lines, most frequent first. Pairs with
     * the same count are printed in alphabetical order of the pair text.
     *
     * @param model trained model used for prediction
     * @param data  labeled sentences to evaluate on
     * @param topN  maximum number of pairs to print
     */
    private static void printTopConfusions(
            Phase2Trainer model, List<LabeledSentence> data,
            int topN) {
        Map<String, Integer> confPairs = new HashMap<>();

        for (LabeledSentence s : data) {
            String predicted = model.predict(s.getText());
            if (!predicted.equals(s.getLanguage())) {
                String key = s.getLanguage() + " -> " + predicted;
                confPairs.merge(key, 1, Integer::sum);
            }
        }

        confPairs.entrySet().stream()
                .sorted((a, b) -> {
                    int byCount = Integer.compare(b.getValue(), a.getValue());
                    return byCount != 0
                            ? byCount
                            : a.getKey().compareTo(b.getKey());
                })
                .limit(topN)
                .forEach(e -> System.out.printf(
                        Locale.US, "  %5d  %s%n",
                        e.getValue(), e.getKey()));
    }

    /**
     * Counts the distinct values of {@code getLanguage()} in {@code data}.
     *
     * @param data labeled sentences
     * @return number of distinct languages
     */
    private static int countLangs(List<LabeledSentence> data) {
        Set<String> langs = new HashSet<>();
        for (LabeledSentence s : data) {
            langs.add(s.getLanguage());
        }
        return langs.size();
    }

    /**
     * Reads a UTF-8 file of tab-separated records and keeps at most
     * {@code maxLines} of them, replacing earlier picks at random once the
     * buffer is full. The random generator is seeded with a constant, so the
     * same file yields the same sample on every run. Lines without a tab are
     * skipped and do not count towards the sample.
     *
     * @param file     file to read
     * @param maxLines maximum number of records to return
     * @return the sampled records; all of them if the file has at most
     *         {@code maxLines} valid lines
     * @throws Exception if the file cannot be read
     */
    private static List<LabeledSentence> readReservoir(
            Path file, int maxLines) throws Exception {
        LabeledSentence[] reservoir =
                new LabeledSentence[maxLines];
        Random rng = new Random(42);
        int seen = 0;
        try (BufferedReader br = Files.newBufferedReader(
                file, StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                // Split on the first tab only; later tabs stay in the
                // second field.
                LabeledSentence s = new LabeledSentence(
                        line.substring(0, tab),
                        line.substring(tab + 1));
                if (seen < maxLines) {
                    reservoir[seen] = s;
                } else {
                    // Draw a slot from all records seen so far (including
                    // this one); the record is kept only if the slot falls
                    // inside the buffer.
                    int j = rng.nextInt(seen + 1);
                    if (j < maxLines) {
                        reservoir[j] = s;
                    }
                }
                seen++;
            }
        }
        int fill = Math.min(seen, maxLines);
        List<LabeledSentence> result = new ArrayList<>(fill);
        for (int i = 0; i < fill; i++) {
            result.add(reservoir[i]);
        }
        return result;
    }

    /** One report row: a language with both F1 scores and their difference. */
    private static final class LangDelta {
        private final String lang;
        private final double bigramF1;
        private final double wordsF1;
        /** {@code wordsF1 - bigramF1}, unrounded. */
        private final double delta;

        LangDelta(String lang, double bigramF1, double wordsF1) {
            this.lang = lang;
            this.bigramF1 = bigramF1;
            this.wordsF1 = wordsF1;
            this.delta = wordsF1 - bigramF1;
        }
    }
}