QuickF1Eval.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.BufferedInputStream;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.function.UnaryOperator;

import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.FeatureExtractor;
import org.apache.tika.language.detect.LanguageDetector;
import org.apache.tika.language.detect.LanguageResult;

/**
 * Quick evaluation tool: compute macro F1, accuracy, and throughput
 * for multiple CharSoup models and the OpenNLP / Optimaize baselines
 * on the same raw test set.
 * <p>
 * Usage:
 * <pre>
 *   QuickF1Eval &lt;testFile&gt; &lt;model1.ldm&gt; [model2.ldm ...]
 *       [xxx_langs.txt] [reportFile] [--no-collapse] [--per-lang]
 * </pre>
 * The test file is raw (unpreprocessed) lang\ttext format.
 * Arguments after the test file are classified by their shape:
 * <ul>
 *   <li>{@code *.ldm} - a CharSoup model to evaluate (at least one is
 *       required);</li>
 *   <li>{@code *_langs.txt} - a language whitelist, one code per line;
 *       when given, the OpenNLP and Optimaize baselines are skipped;</li>
 *   <li>{@code --no-collapse} - score confusable languages separately
 *       instead of mapping each to its group's first member;</li>
 *   <li>{@code --per-lang} - append a per-language F1/precision/recall
 *       table for each CharSoup model;</li>
 *   <li>anything else not starting with {@code --} - the report file.</li>
 * </ul>
 * Without a whitelist, evaluation is restricted to the languages that
 * both the first CharSoup model and OpenNLP support.
 */
public class QuickF1Eval {

    private static final String USAGE =
            "Usage: QuickF1Eval <testFile> <model1.ldm>"
                    + " [model2.ldm ...] [langFile]"
                    + " [reportFile] [--no-collapse] [--per-lang]";

    /** Fully-qualified class name of Tika's OpenNLP language detector. */
    private static final String OPENNLP_DETECTOR =
            "org.apache.tika.langdetect.opennlp.OpenNLPDetector";

    /** Fully-qualified class name of Tika's Optimaize language detector. */
    private static final String OPTIMAIZE_DETECTOR =
            "org.apache.tika.langdetect.optimaize.OptimaizeLangDetector";

    /** Maximum number of sentences run through a detector before timing. */
    private static final int WARMUP_SENTENCES = 500;

    /**
     * Orders per-language score entries ({@code {f1, p, r, support}}) by
     * ascending F1, breaking ties by language code so reports are
     * reproducible regardless of hash-map iteration order.
     */
    private static final Comparator<Map.Entry<String, double[]>>
            BY_F1_THEN_LANG = (a, b) -> {
                int byF1 = Double.compare(
                        a.getValue()[0], b.getValue()[0]);
                return byF1 != 0 ? byF1 : a.getKey().compareTo(b.getKey());
            };

    /**
     * Maps every confusable language to a canonical representative
     * (the first member of its group). Non-confusable languages
     * map to themselves.
     */
    private static final Map<String, String> CANONICAL = buildCanonical();

    private static Map<String, String> buildCanonical() {
        Map<String, String> map = new HashMap<>();
        for (String[] group : CompareDetectors.CONFUSABLE_GROUPS) {
            String canon = group[0];
            for (String lang : group) {
                map.put(lang, canon);
            }
        }
        return map;
    }

    /**
     * Returns the canonical representative of {@code lang}'s confusable
     * group, or {@code lang} itself when it belongs to no group.
     */
    static String canonicalize(String lang) {
        return CANONICAL.getOrDefault(lang, lang);
    }

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            exitWithUsage(null);
        }

        Path testFile = Paths.get(args[0]);
        int threads = Runtime.getRuntime().availableProcessors();

        List<Path> modelFiles = new ArrayList<>();
        Path reportFile = null;
        Path langFile = null;
        // The OpenNLP/Optimaize baselines only run when the language set
        // comes from OpenNLP, i.e. when no explicit whitelist is given.
        boolean includeBaselines = true;
        boolean doCollapse = true;
        boolean perLang = false;
        for (int i = 1; i < args.length; i++) {
            String arg = args[i];
            if (arg.endsWith(".ldm")) {
                modelFiles.add(Paths.get(arg));
            } else if (arg.endsWith("_langs.txt")) {
                langFile = Paths.get(arg);
                includeBaselines = false;
            } else if ("--no-collapse".equals(arg)) {
                doCollapse = false;
            } else if ("--per-lang".equals(arg)) {
                perLang = true;
            } else if (arg.startsWith("--")) {
                // Otherwise a mistyped flag would silently become the
                // report file path.
                exitWithUsage("Unknown option: " + arg);
            } else {
                reportFile = Paths.get(arg);
            }
        }
        if (modelFiles.isEmpty()) {
            // At least one model is needed both for evaluation and for
            // computing the OpenNLP language intersection.
            exitWithUsage("At least one .ldm model file is required");
        }

        // Load test data (raw text)
        System.out.println("Loading test data: " + testFile);
        List<LabeledSentence> allData =
                TrainLanguageModel.readPreprocessedFile(testFile);
        System.out.printf(Locale.US,
                "Test sentences: %,d%n", allData.size());

        // Determine allowed canonical languages
        Set<String> allowedLangs;
        String langSource;
        if (langFile != null) {
            allowedLangs = readLangFile(langFile);
            langSource = langFile.getFileName().toString();
        } else {
            allowedLangs = openNlpSharedLangs(modelFiles.get(0));
            langSource = "OpenNLP shared";
        }

        // Filter to allowed languages
        List<LabeledSentence> sharedData = new ArrayList<>();
        Set<String> sharedCanonical = new HashSet<>();
        for (LabeledSentence s : allData) {
            String lang = s.getLanguage();
            String mapped = doCollapse ? canonicalize(lang) : lang;
            if (allowedLangs.contains(mapped)) {
                sharedData.add(s);
                sharedCanonical.add(mapped);
            }
        }
        System.out.printf(Locale.US,
                "Shared languages (%s%s): %d "
                        + "(%,d sentences)%n%n",
                langSource,
                doCollapse ? ", collapsed" : ", no-collapse",
                sharedCanonical.size(), sharedData.size());

        StringBuilder report = new StringBuilder();
        report.append(String.format(Locale.US,
                "=== Macro F1 Evaluation "
                        + "(shared langs%s) ===%n",
                doCollapse ? ", confusables collapsed"
                        : ", no-collapse"));
        report.append(String.format(Locale.US,
                "Test set:  %s%n",
                testFile.getFileName()));
        report.append(String.format(Locale.US,
                "Sentences: %,d (from %,d total, "
                        + "filtered to %d shared canonical langs)%n",
                sharedData.size(), allData.size(),
                sharedCanonical.size()));
        report.append(String.format(Locale.US,
                "Threads:   %d%n", threads));
        if (doCollapse) {
            appendConfusableGroups(report, allowedLangs);
        }
        report.append("\n");

        report.append(String.format(Locale.US,
                "%-20s  %8s  %8s  %10s  %12s%n",
                "Model", "MacroF1", "Accuracy", "Time(ms)",
                "Sent/sec"));
        report.append("-".repeat(66)).append("\n");

        boolean warmedUp = false;

        // Evaluate each CharSoup model
        for (Path mf : modelFiles) {
            System.out.println("Loading: " + mf);
            CharSoupModel model = loadModel(mf);
            FeatureExtractor extractor = model.createExtractor();
            // Label rows by the model's parent directory when there is one.
            String label = mf.getParent() != null
                    ? mf.getParent().getFileName().toString()
                    : mf.getFileName().toString();
            System.out.printf(Locale.US,
                    "  %s: %d classes, %d buckets%n",
                    label, model.getNumClasses(),
                    model.getNumBuckets());

            if (!warmedUp) {
                System.out.println("Warming up...");
                int warmupCount =
                        Math.min(WARMUP_SENTENCES, sharedData.size());
                for (int i = 0; i < warmupCount; i++) {
                    model.predict(extractor.extract(
                            sharedData.get(i).getText()));
                }
                warmedUp = true;
            }

            System.out.println("Evaluating " + label + "...");
            F1Result r = evalBigramParallel(
                    model, extractor, sharedData,
                    threads, doCollapse);
            appendResult(report, "charsoup-" + label, r);
            printResult(label, r);
            if (perLang) {
                appendPerLang(report, r);
            }
        }

        if (includeBaselines) {
            // Baselines are scored with the same collapse setting as the
            // CharSoup rows so the table compares like with like.
            System.out.println("Loading OpenNLP for eval...");
            LanguageDetector opennlp =
                    CompareDetectors.loadDetector(OPENNLP_DETECTOR);
            if (opennlp != null) {
                warmUp(opennlp, sharedData);
                System.out.println("Evaluating OpenNLP...");
                F1Result r = evalOpenNLPParallel(
                        opennlp, sharedData, threads, doCollapse);
                appendResult(report, "opennlp", r);
                printResult("opennlp", r);
            }

            System.out.println("Loading Optimaize for eval...");
            LanguageDetector optimaize =
                    CompareDetectors.loadDetector(OPTIMAIZE_DETECTOR);
            if (optimaize != null) {
                warmUp(optimaize, sharedData);
                System.out.println("Evaluating Optimaize...");
                F1Result r = evalOptimaizeParallel(
                        optimaize, sharedData, threads, doCollapse);
                appendResult(report, "optimaize", r);
                printResult("optimaize", r);
            }
        }

        report.append("\n");
        String reportStr = report.toString();
        System.out.println(reportStr);

        if (reportFile != null) {
            if (reportFile.getParent() != null) {
                Files.createDirectories(reportFile.getParent());
            }
            try (BufferedWriter writer = Files.newBufferedWriter(
                    reportFile, StandardCharsets.UTF_8)) {
                writer.write(reportStr);
            }
            System.out.println("Report written to: " + reportFile);
        }
    }

    // ---- main() helpers ----

    /**
     * Prints {@code error} (when non-null) and the usage line to stderr,
     * then terminates the JVM with exit status 1.
     */
    private static void exitWithUsage(String error) {
        if (error != null) {
            System.err.println(error);
        }
        System.err.println(USAGE);
        System.exit(1);
    }

    /** Loads a serialized CharSoup model from {@code path}. */
    private static CharSoupModel loadModel(Path path) throws Exception {
        try (InputStream is = new BufferedInputStream(
                Files.newInputStream(path))) {
            return CharSoupModel.load(is);
        }
    }

    /**
     * Reads a language whitelist: one code per line, surrounding
     * whitespace trimmed, blank lines ignored.
     */
    private static Set<String> readLangFile(Path langFile)
            throws IOException {
        Set<String> langs = new HashSet<>();
        for (String line : Files.readAllLines(
                langFile, StandardCharsets.UTF_8)) {
            String l = line.trim();
            if (!l.isEmpty()) {
                langs.add(l);
            }
        }
        return langs;
    }

    /**
     * Returns the canonical codes of the first model's labels that OpenNLP
     * supports (directly or via the canonical code). Returns an empty set,
     * with a warning, when the OpenNLP detector cannot be loaded.
     */
    private static Set<String> openNlpSharedLangs(Path firstModelFile)
            throws Exception {
        System.out.println("Loading OpenNLP...");
        LanguageDetector opennlp =
                CompareDetectors.loadDetector(OPENNLP_DETECTOR);
        Set<String> langs = new HashSet<>();
        if (opennlp == null) {
            System.err.println("WARNING: OpenNLP detector unavailable;"
                    + " no shared languages");
            return langs;
        }
        CharSoupModel firstModel = loadModel(firstModelFile);
        for (String lang : firstModel.getLabels()) {
            String canon = canonicalize(lang);
            // NOTE(review): only canonical codes are added, so with
            // --no-collapse the non-canonical members of confusable groups
            // are filtered out of the test data. Confirm this is intended.
            if (opennlp.hasModel(lang) || opennlp.hasModel(canon)) {
                langs.add(canon);
            }
        }
        return langs;
    }

    /** Runs {@code detector} on up to WARMUP_SENTENCES sentences. */
    private static void warmUp(LanguageDetector detector,
                               List<LabeledSentence> data) {
        int warmupCount = Math.min(WARMUP_SENTENCES, data.size());
        for (int i = 0; i < warmupCount; i++) {
            detector.reset();
            detector.addText(data.get(i).getText());
            detector.detectAll();
        }
    }

    /**
     * Lists every confusable group that has at least one member in
     * {@code allowedLangs} (after canonicalization).
     */
    private static void appendConfusableGroups(StringBuilder report,
                                               Set<String> allowedLangs) {
        report.append(String.format(Locale.US,
                "Confusable groups collapsed:%n"));
        for (String[] g : CompareDetectors.CONFUSABLE_GROUPS) {
            boolean relevant = false;
            for (String m : g) {
                if (allowedLangs.contains(canonicalize(m))) {
                    relevant = true;
                    break;
                }
            }
            if (relevant) {
                report.append(String.format(Locale.US,
                        "  {%s} -> %s%n",
                        String.join("/", g), g[0]));
            }
        }
    }

    private static void appendResult(StringBuilder report,
                                     String name, F1Result r) {
        report.append(String.format(Locale.US,
                "%-20s  %8.4f  %7.2f%%  %,10d  %,12.0f%n",
                name, r.macroF1,
                100.0 * r.correct / r.total,
                r.elapsedMs,
                r.total / (r.elapsedMs / 1000.0)));
        report.append(formatBottom10(r, "  "));
    }

    private static void printResult(String label, F1Result r) {
        System.out.printf(Locale.US,
                "  -> MacroF1=%.4f  Acc=%.2f%%  %,dms%n%n",
                r.macroF1,
                100.0 * r.correct / r.total,
                r.elapsedMs);
    }

    // ---- F1 accumulator ----

    /**
     * Computes {@code {f1, precision, recall, support}} for one class from
     * its {@code [TP, FP, FN]} counts, or returns {@code null} when the
     * class never occurs in the ground truth (TP + FN == 0), since such
     * classes are excluded from macro averaging.
     */
    private static double[] classScores(int[] c) {
        int tp = c[0], fp = c[1], fn = c[2];
        int support = tp + fn;
        if (support == 0) {
            return null;
        }
        double p = tp + fp > 0 ? (double) tp / (tp + fp) : 0;
        double rec = (double) tp / support;
        double f1 = p + rec > 0 ? 2 * p * rec / (p + rec) : 0;
        return new double[]{f1, p, rec, support};
    }

    /**
     * Per-thread confusion accumulator. Not thread-safe; each worker owns
     * one and the results are combined with {@link #merge}.
     */
    static class F1Stats {
        // lang -> [TP, FP, FN]
        final Map<String, int[]> counts = new HashMap<>();
        int correct = 0;
        int total = 0;
        final boolean collapse;

        F1Stats(boolean collapse) {
            this.collapse = collapse;
        }

        /**
         * Records one prediction. With {@code collapse}, both labels are
         * mapped to their canonical group representative first.
         */
        void record(String truth, String predicted) {
            if (collapse) {
                truth = canonicalize(truth);
                predicted = canonicalize(predicted);
            }
            total++;
            boolean hit = truth.equals(predicted);
            if (hit) {
                correct++;
            }
            int[] tc = counts.computeIfAbsent(
                    truth, k -> new int[3]);
            int[] pc = counts.computeIfAbsent(
                    predicted, k -> new int[3]);
            if (hit) {
                tc[0]++;
            } else {
                tc[2]++;
                pc[1]++;
            }
        }

        /** Adds all of {@code other}'s counts into this accumulator. */
        void merge(F1Stats other) {
            correct += other.correct;
            total += other.total;
            for (var e : other.counts.entrySet()) {
                int[] mine = counts.computeIfAbsent(
                        e.getKey(), k -> new int[3]);
                int[] theirs = e.getValue();
                mine[0] += theirs[0];
                mine[1] += theirs[1];
                mine[2] += theirs[2];
            }
        }

        /**
         * Unweighted mean F1 over classes present in the ground truth;
         * 0 when there are none.
         */
        double macroF1() {
            double sum = 0;
            int n = 0;
            for (int[] c : counts.values()) {
                double[] s = classScores(c);
                if (s == null) {
                    continue;
                }
                sum += s[0];
                n++;
            }
            return n > 0 ? sum / n : 0;
        }
    }

    /** Immutable-by-convention summary of one model's evaluation. */
    static class F1Result {
        double macroF1;
        int correct;
        int total;
        // Wall-clock time of the parallel scoring phase only.
        long elapsedMs;
        // lang -> [TP, FP, FN]
        Map<String, int[]> counts;
    }

    // ---- Parallel driver ----

    /**
     * Runs the per-chunk scoring tasks on a fixed thread pool, timing only
     * the scoring (all model/detector loading must happen before this
     * call), and merges the per-chunk statistics.
     */
    private static F1Result runTimed(List<Callable<F1Stats>> tasks,
                                     boolean collapse) throws Exception {
        // newFixedThreadPool rejects a size of 0 (e.g. empty test data).
        ExecutorService pool =
                Executors.newFixedThreadPool(Math.max(1, tasks.size()));
        try {
            List<Future<F1Stats>> futures = new ArrayList<>(tasks.size());
            long wallStart = System.nanoTime();
            for (Callable<F1Stats> task : tasks) {
                futures.add(pool.submit(task));
            }
            F1Stats merged = new F1Stats(collapse);
            for (Future<F1Stats> f : futures) {
                merged.merge(f.get());
            }
            long wallEnd = System.nanoTime();

            F1Result r = new F1Result();
            r.macroF1 = merged.macroF1();
            r.correct = merged.correct;
            r.total = merged.total;
            r.elapsedMs = (wallEnd - wallStart) / 1_000_000;
            r.counts = merged.counts;
            return r;
        } finally {
            // Cancels outstanding tasks if one failed; harmless otherwise.
            pool.shutdownNow();
        }
    }

    /**
     * Loads a detector by class name, failing loudly rather than silently
     * scoring a chunk with no predictions.
     */
    private static LanguageDetector loadRequiredDetector(String className)
            throws Exception {
        LanguageDetector detector = CompareDetectors.loadDetector(className);
        if (detector == null) {
            throw new IllegalStateException(
                    "Could not load detector: " + className);
        }
        return detector;
    }

    // ---- CharSoup evaluation ----

    /**
     * Evaluates a CharSoup model over {@code data} split into
     * {@code threads} chunks. Each chunk gets its own extractor, created
     * before timing starts. The {@code extractor} argument is not used
     * here; it is kept for signature compatibility.
     */
    static F1Result evalBigramParallel(
            CharSoupModel model, FeatureExtractor extractor,
            List<LabeledSentence> data, int threads,
            boolean collapse) throws Exception {
        List<List<LabeledSentence>> chunks =
                CompareDetectors.partition(data, threads);
        List<Callable<F1Stats>> tasks = new ArrayList<>(chunks.size());
        for (List<LabeledSentence> chunk : chunks) {
            FeatureExtractor te = model.createExtractor();
            tasks.add(() -> evalBigramChunk(model, te, chunk, collapse));
        }
        return runTimed(tasks, collapse);
    }

    /** Scores one chunk by taking the argmax class of each prediction. */
    static F1Stats evalBigramChunk(
            CharSoupModel model, FeatureExtractor extractor,
            List<LabeledSentence> data, boolean collapse) {
        F1Stats stats = new F1Stats(collapse);
        for (LabeledSentence s : data) {
            int[] features = extractor.extract(s.getText());
            float[] probs = model.predict(features);
            int predIdx = 0;
            for (int c = 1; c < probs.length; c++) {
                if (probs[c] > probs[predIdx]) {
                    predIdx = c;
                }
            }
            stats.record(s.getLanguage(),
                    model.getLabel(predIdx));
        }
        return stats;
    }

    // ---- Baseline (OpenNLP / Optimaize) evaluation ----

    /**
     * Scores one chunk with a detector that must be used by a single thread
     * only: {@code reset}/{@code addText}/{@code detectAll} mutate its
     * state. The top result (or "unk" when there is none) is passed through
     * {@code toLabel} before scoring.
     */
    private static F1Stats scoreDetectorChunk(
            LanguageDetector detector, List<LabeledSentence> data,
            boolean collapse, UnaryOperator<String> toLabel) {
        F1Stats stats = new F1Stats(collapse);
        for (LabeledSentence s : data) {
            detector.reset();
            detector.addText(s.getText());
            List<LanguageResult> results = detector.detectAll();
            String raw = results.isEmpty()
                    ? "unk" : results.get(0).getLanguage();
            stats.record(s.getLanguage(), toLabel.apply(raw));
        }
        return stats;
    }

    /**
     * Evaluates OpenNLP in parallel. The first chunk reuses
     * {@code detector}; every other chunk gets a freshly loaded instance.
     * All loading happens before timing starts.
     */
    static F1Result evalOpenNLPParallel(
            LanguageDetector detector,
            List<LabeledSentence> data, int threads,
            boolean collapse) throws Exception {
        List<List<LabeledSentence>> chunks =
                CompareDetectors.partition(data, threads);
        List<Callable<F1Stats>> tasks = new ArrayList<>(chunks.size());
        for (int i = 0; i < chunks.size(); i++) {
            LanguageDetector local = (i == 0 && detector != null)
                    ? detector : loadRequiredDetector(OPENNLP_DETECTOR);
            List<LabeledSentence> chunk = chunks.get(i);
            tasks.add(() -> scoreDetectorChunk(
                    local, chunk, collapse, UnaryOperator.identity()));
        }
        return runTimed(tasks, collapse);
    }

    /**
     * Scores one chunk with a newly loaded OpenNLP detector (the
     * {@code detector} argument is ignored). Returns empty stats if the
     * detector cannot be loaded.
     */
    static F1Stats evalOpenNLPChunk(
            LanguageDetector detector,
            List<LabeledSentence> data, boolean collapse) {
        LanguageDetector local;
        try {
            local = CompareDetectors.loadDetector(OPENNLP_DETECTOR);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
        if (local == null) {
            return new F1Stats(collapse);
        }
        return scoreDetectorChunk(
                local, data, collapse, UnaryOperator.identity());
    }

    /**
     * Evaluates Optimaize in parallel. Detectors are stateful, so each
     * chunk gets its own instance: the first chunk reuses
     * {@code optimaize}, the others load fresh ones before timing starts.
     */
    static F1Result evalOptimaizeParallel(
            LanguageDetector optimaize,
            List<LabeledSentence> data, int threads, boolean collapse)
            throws Exception {
        List<List<LabeledSentence>> chunks =
                CompareDetectors.partition(data, threads);
        List<Callable<F1Stats>> tasks = new ArrayList<>(chunks.size());
        for (int i = 0; i < chunks.size(); i++) {
            LanguageDetector local = (i == 0 && optimaize != null)
                    ? optimaize : loadRequiredDetector(OPTIMAIZE_DETECTOR);
            List<LabeledSentence> chunk = chunks.get(i);
            tasks.add(() -> evalOptimaizeChunk(local, chunk, collapse));
        }
        return runTimed(tasks, collapse);
    }

    /**
     * Scores one chunk with {@code detector}, mapping Optimaize codes to
     * ISO 639-3 via {@code CompareDetectors.optimaizePredToIso3}. Returns
     * empty stats for a null detector. The detector must not be shared with
     * another thread during this call.
     */
    static F1Stats evalOptimaizeChunk(
            LanguageDetector detector,
            List<LabeledSentence> data, boolean collapse) {
        if (detector == null) {
            return new F1Stats(collapse);
        }
        return scoreDetectorChunk(detector, data, collapse,
                raw -> CompareDetectors.optimaizePredToIso3(raw));
    }

    // ---- Formatting ----

    /**
     * Returns {@code (lang, {f1, p, r, support})} for every class present in
     * the ground truth, sorted by ascending F1 then language code.
     */
    private static List<Map.Entry<String, double[]>> perLanguageScores(
            F1Result r) {
        List<Map.Entry<String, double[]>> langF1 = new ArrayList<>();
        for (var e : r.counts.entrySet()) {
            double[] s = classScores(e.getValue());
            if (s != null) {
                langF1.add(Map.entry(e.getKey(), s));
            }
        }
        langF1.sort(BY_F1_THEN_LANG);
        return langF1;
    }

    /** Formats the ten lowest-F1 languages on one indented line. */
    static String formatBottom10(F1Result r, String indent) {
        List<Map.Entry<String, double[]>> langF1 = perLanguageScores(r);

        StringBuilder sb = new StringBuilder();
        sb.append(indent).append("Bottom 10: ");
        int show = Math.min(10, langF1.size());
        for (int i = 0; i < show; i++) {
            var e = langF1.get(i);
            if (i > 0) {
                sb.append(", ");
            }
            sb.append(String.format(Locale.US,
                    "%s=%.3f", e.getKey(), e.getValue()[0]));
        }
        sb.append("\n");
        return sb.toString();
    }

    /** Appends a per-language F1/precision/recall/support table. */
    static void appendPerLang(StringBuilder report,
                              F1Result r) {
        List<Map.Entry<String, double[]>> langF1 = perLanguageScores(r);
        report.append(String.format(Locale.US,
                "  %-16s  %8s  %8s  %8s  %8s%n",
                "Language", "F1", "Prec", "Recall",
                "Count"));
        report.append("  ").append("-".repeat(56))
                .append("\n");
        for (var e : langF1) {
            double[] v = e.getValue();
            report.append(String.format(Locale.US,
                    "  %-16s  %8.4f  %8.4f  %8.4f  %8.0f%n",
                    e.getKey(), v[0], v[1], v[2], v[3]));
        }
        report.append("\n");
    }
}