// TrainLanguageModel.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;

import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.FeatureExtractor;

/**
 * Two-pass training pipeline for the CharSoup language detector.
 *
 * <p><strong>WARNING — feature extraction must stay in sync with
 * {@link org.apache.tika.langdetect.charsoup.CharSoupFeatureExtractor}.</strong>
 * If you change training hyper-parameters (bucket count, n-gram order, etc.)
 * you must also update the corresponding inference constants and retrain from
 * scratch.
 *
 * <p>Data preparation (corpus → pool/dev/test splits) is handled by
 * {@link PrepareCorpus}, which owns the language-inclusion policy. Run
 * {@code PrepareCorpus} first, then point this trainer at the output directory
 * via {@code --prep-dir}. If {@code --prep-dir} already contains the three
 * expected outputs the prep step is skipped automatically.
 *
 * <h3>Pipeline</h3>
 * <ol>
 *   <li>Data prep via {@link PrepareCorpus#prepareData} (skipped if
 *       {@code --prep-dir} already populated)</li>
 *   <li><b>Pass 1</b>: Train with epoch-level resampling</li>
 *   <li><b>Filter</b>: Remove mislabeled sentences (respecting confusable
 *       groups)</li>
 *   <li><b>Pass 2</b>: Retrain on filtered pool</li>
 *   <li>Quantize to INT8, evaluate on raw test set, export</li>
 * </ol>
 *
 * <p>Usage:
 * <pre>
 *   TrainLanguageModel --corpus &lt;dir&gt; --output &lt;file&gt;
 *       [--prep-dir &lt;dir&gt;] [--buckets N] [--max-train N]
 *       [--skip-bigrams] [--trigrams] [--eval-only &lt;model&gt;]
 * </pre>
 */
public class TrainLanguageModel {

    private static final int DEFAULT_NUM_BUCKETS = 8_192;
    private static final long DEFAULT_TARGET_EPOCH_TOTAL = 5_000_000L;
    private static final int DEFAULT_MAX_TEST_PER_LANG  = 20_000;
    private static final int DEFAULT_MAX_DEV_PER_LANG   = 20_000;
    private static final int DEFAULT_MAX_TRAIN_PER_LANG = 0; // 0 = unlimited

    // Language-inclusion policy (exclusion list, merge aliases, thresholds)
    // lives in PrepareCorpus — see that class for documentation.

    /**
     * CLI entry point for the two-pass training pipeline: data prep,
     * Pass 1 training, mislabel filtering, Pass 2 retraining, INT8
     * quantization, evaluation on the raw test set, and model export.
     *
     * @param args see {@link #printUsage()} for the accepted flags
     * @throws IOException if corpus, prep, or model files cannot be
     *         read or written
     */
    public static void main(String[] args) throws IOException {
        Path corpusDir        = null;
        Path outputFile       = null;
        int  numBuckets       = DEFAULT_NUM_BUCKETS;
        long targetEpochTotal = DEFAULT_TARGET_EPOCH_TOTAL;
        int  maxTrainPerLang  = DEFAULT_MAX_TRAIN_PER_LANG;
        int  maxDevPerLang    = DEFAULT_MAX_DEV_PER_LANG;
        int  maxTestPerLang   = DEFAULT_MAX_TEST_PER_LANG;
        boolean useTrigrams    = false;
        boolean useSkipBigrams = false;
        boolean singlePass     = false;
        Path    evalOnlyModel  = null;
        Path    prepDirOverride = null;

        // ---- Argument parsing ----
        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "--corpus":
                    corpusDir = Paths.get(args[++i]);
                    break;
                case "--output":
                    outputFile = Paths.get(args[++i]);
                    break;
                case "--buckets":
                    numBuckets = Integer.parseInt(args[++i]);
                    break;
                case "--max-train":
                    maxTrainPerLang = Integer.parseInt(args[++i]);
                    break;
                case "--max-dev":
                    maxDevPerLang = Integer.parseInt(args[++i]);
                    break;
                case "--max-test":
                    maxTestPerLang = Integer.parseInt(args[++i]);
                    break;
                case "--epoch-total":
                    targetEpochTotal = Long.parseLong(args[++i]);
                    break;
                case "--trigrams":
                    useTrigrams = true;
                    break;
                case "--skip-bigrams":
                    useSkipBigrams = true;
                    break;
                case "--single-pass":
                    singlePass = true;
                    break;
                case "--prep-dir":
                    prepDirOverride = Paths.get(args[++i]);
                    break;
                case "--eval-only":
                    evalOnlyModel = Paths.get(args[++i]);
                    break;
                default:
                    System.err.println("Unknown argument: " + args[i]);
                    printUsage();
                    System.exit(1);
            }
        }

        if (corpusDir == null || (outputFile == null && evalOnlyModel == null)) {
            printUsage();
            System.exit(1);
        }

        // ---- Eval-only mode: load an existing model and evaluate ----
        if (evalOnlyModel != null) {
            Path prepDir2 = prepDirOverride != null ? prepDirOverride
                    : (evalOnlyModel.getParent() != null
                        ? evalOnlyModel.getParent()
                        : Paths.get(".")).resolve("preprocessed");
            System.out.println("=== Eval-only mode ===");
            System.out.println("Model:    " + evalOnlyModel);
            System.out.println("Prep dir: " + prepDir2);
            CharSoupModel evalModel;
            try (java.io.InputStream is = new java.io.BufferedInputStream(
                    Files.newInputStream(evalOnlyModel))) {
                evalModel = CharSoupModel.load(is);
            }
            runEval(evalModel, prepDir2.resolve("test_raw.txt"), "test");
            return;
        }

        Path prepDir = prepDirOverride != null ? prepDirOverride
                : (outputFile.getParent() != null
                    ? outputFile.getParent()
                    : Paths.get(".")).resolve("preprocessed");

        System.out.println("=== Language Model Training Pipeline ===");
        System.out.println("Corpus:          " + corpusDir);
        System.out.println("Output:          " + outputFile);
        System.out.println("Prep dir:        " + prepDir);
        System.out.println("Buckets:         " + numBuckets);
        System.out.printf(Locale.US,
                "TargetEpochTotal: %,d%n", targetEpochTotal);
        System.out.println();

        long pipelineStart = System.nanoTime();
        long stepStart;

        // ---- Data preparation ----
        Path poolDir = prepDir.resolve("pool");
        Path devFile = prepDir.resolve("dev.txt");
        Path testFile = prepDir.resolve("test_raw.txt");

        // Skip prep entirely if all three expected outputs already exist.
        if (Files.isDirectory(poolDir)
                && Files.exists(devFile)
                && Files.exists(testFile)) {
            System.out.println(
                    "--- Preprocessed data found — skipping data prep ---");
        } else {
            stepStart = System.nanoTime();
            System.out.println("--- Step 1: Data preparation (PrepareCorpus) ---");
            Files.createDirectories(poolDir);
            int[] counts = PrepareCorpus.prepareData(corpusDir, prepDir,
                    maxTrainPerLang, maxDevPerLang, maxTestPerLang);
            System.out.printf(Locale.US,
                    "Prepared: pool=%,d  dev=%,d  test=%,d%n",
                    counts[0], counts[1], counts[2]);
            System.out.printf(Locale.US,
                    "  [%.1f s]%n", elapsed(stepStart));
        }

        // Collect all labels from pool directory
        String[] allLabels = collectLabels(poolDir);
        System.out.printf(Locale.US,
                "Languages in pool: %d%n%n", allLabels.length);

        // Load fixed dev data
        stepStart = System.nanoTime();
        System.out.println("--- Loading dev data ---");
        List<LabeledSentence> devData =
                readPreprocessedFile(devFile);
        System.out.printf(Locale.US,
                "Dev: %,d sentences  [%.1f s]%n",
                devData.size(), elapsed(stepStart));

        Path epochFile = prepDir.resolve("epoch_train.txt");

        // ---- Pass 1 ----
        stepStart = System.nanoTime();
        System.out.println(
                "\n--- Step 2: Pass 1 — Initial training ---");
        Map<String, Integer> pass1Targets =
                computePerLangTargets(
                        scanPoolSizes(poolDir),
                        targetEpochTotal);
        Phase2Trainer pass1 = new Phase2Trainer(numBuckets)
                .setPreprocessed(true)
                .setUseSkipBigrams(useSkipBigrams)
                .setUseTrigrams(useTrigrams);
        pass1.trainWithResampling(allLabels,
                epochNum -> createEpochFile(
                        poolDir, epochFile,
                        pass1Targets, epochNum),
                devData);
        System.out.printf(Locale.US, "  [%.1f s]%n",
                elapsed(stepStart));

        Phase2Trainer finalTrainer;
        if (singlePass) {
            System.out.println("\n--- Single-pass mode: skipping filter and Pass 2 ---");
            finalTrainer = pass1;
        } else {
            // ---- Filter training pool ----
            stepStart = System.nanoTime();
            System.out.println(
                    "\n--- Step 3: Filtering mislabeled sentences ---");
            Path filteredPoolDir = prepDir.resolve("pool_filtered");
            long[] filterCounts = filterPool(
                    pass1, poolDir, filteredPoolDir);
            long removed = filterCounts[1] - filterCounts[0];
            // Guard against an empty pool so the percentage never
            // prints as NaN.
            double removedPct = filterCounts[1] == 0
                    ? 0.0
                    : 100.0 * removed / filterCounts[1];
            System.out.printf(Locale.US,
                    "Kept %,d / %,d sentences "
                            + "(removed %,d = %.1f%%)%n",
                    filterCounts[0], filterCounts[1],
                    removed, removedPct);
            System.out.printf(Locale.US, "  [%.1f s]%n",
                    elapsed(stepStart));

            // ---- Pass 2 ----
            stepStart = System.nanoTime();
            System.out.println(
                    "\n--- Step 4: Pass 2 — Retraining on "
                            + "filtered data ---");
            String[] filteredLabels =
                    collectLabels(filteredPoolDir);
            Map<String, Integer> pass2Targets =
                    computePerLangTargets(
                            scanPoolSizes(filteredPoolDir),
                            targetEpochTotal);
            Phase2Trainer pass2 = new Phase2Trainer(numBuckets)
                    .setPreprocessed(true)
                    .setUseSkipBigrams(useSkipBigrams)
                    .setUseTrigrams(useTrigrams);
            pass2.trainWithResampling(filteredLabels,
                    epochNum -> createEpochFile(
                            filteredPoolDir, epochFile,
                            pass2Targets, epochNum),
                    devData);
            System.out.printf(Locale.US, "  [%.1f s]%n",
                    elapsed(stepStart));
            finalTrainer = pass2;
        }

        // ---- Quantize ----
        stepStart = System.nanoTime();
        System.out.println("\n--- Step 5: Quantizing to INT8 ---");
        CharSoupModel quantized = ModelQuantizer.quantize(finalTrainer);
        System.out.printf(Locale.US, "  [%.1f s]%n",
                elapsed(stepStart));

        // ---- Evaluate on raw test set ----
        stepStart = System.nanoTime();
        System.out.println(
                "\n--- Step 6: Evaluating on raw test set ---");
        List<LabeledSentence> testData =
                readPreprocessedFile(testFile);

        // Evaluate at truncated lengths
        int[] truncLengths = {20, 50, 100, 200, 500};
        List<Integer> evalLengths = new ArrayList<>();
        for (int l : truncLengths) {
            evalLengths.add(l);
        }

        System.out.printf(Locale.US,
                "%-10s  %8s  %8s  %12s  %8s  %8s%n",
                "length", "macroF1", "median", ">=0.90/total",
                "accuracy", "n");
        System.out.println(
                "----------  --------  --------"
                + "  ------------  --------  --------");

        // Evaluate each length exactly once; the results are reused for
        // the summary table, the worst-N report, and the TSV dump.
        List<EvalResult> allResults = new ArrayList<>();
        for (int maxChars : evalLengths) {
            List<LabeledSentence> subset =
                    truncateTestData(testData, maxChars);
            EvalResult r = evaluateQuantized(quantized, subset);
            allResults.add(r);
            String lenLabel = String.format(
                    Locale.US, "@%d", maxChars);
            System.out.printf(Locale.US,
                    "%-10s  %8.4f  %8.4f  %5d/%-6d  %8.4f  %,8d%n",
                    lenLabel, r.macroF1, r.medianF1,
                    r.numAbove90, r.numLangs,
                    r.accuracy, r.total);
        }

        // Worst-10 per length + full TSV dump
        int worstN = 10;
        System.out.println();

        for (int ri = 0; ri < evalLengths.size(); ri++) {
            int maxChars = evalLengths.get(ri);
            EvalResult r = allResults.get(ri);
            System.out.printf(Locale.US,
                    "Worst %d languages (@%d chars):%n",
                    worstN, maxChars);
            int limit = Math.min(worstN, r.perLang.size());
            for (int i = 0; i < limit; i++) {
                LangF1 lf = r.perLang.get(i);
                System.out.printf(Locale.US,
                        "  %-8s  F1=%.4f%n", lf.lang, lf.f1);
            }
            System.out.println();
        }

        // Full per-language TSV: lang, f1@20, f1@50, f1@100, f1@200, f1@500
        Path tsvFile = outputFile.resolveSibling(
                outputFile.getFileName().toString()
                        .replaceFirst("\\.bin$", "")
                + "-per-lang.tsv");
        try (BufferedWriter tsv = Files.newBufferedWriter(
                tsvFile, StandardCharsets.UTF_8)) {
            tsv.write("lang");
            for (int maxChars : evalLengths) {
                tsv.write("\tf1@" + maxChars);
            }
            tsv.newLine();
            // Use the first result's lang list (all same langs)
            EvalResult first = allResults.get(0);
            // Build a map from lang → f1 for each length
            List<Map<String, Double>> f1Maps = new ArrayList<>();
            for (EvalResult r : allResults) {
                Map<String, Double> m = new HashMap<>();
                for (LangF1 lf : r.perLang) {
                    m.put(lf.lang, lf.f1);
                }
                f1Maps.add(m);
            }
            // Collect all langs (sorted)
            List<String> allLangs = new ArrayList<>();
            for (LangF1 lf : first.perLang) {
                allLangs.add(lf.lang);
            }
            allLangs.sort(String::compareTo);
            for (String lang : allLangs) {
                tsv.write(lang);
                for (Map<String, Double> m : f1Maps) {
                    tsv.write(String.format(Locale.US,
                            "\t%.4f", m.getOrDefault(lang, 0.0)));
                }
                tsv.newLine();
            }
        }
        System.out.println("Per-language TSV: " + tsvFile);

        System.out.printf(Locale.US, "  [%.1f s]%n",
                elapsed(stepStart));

        // ---- Export ----
        stepStart = System.nanoTime();
        System.out.println("\n--- Step 7: Exporting model ---");
        if (outputFile.getParent() != null) {
            Files.createDirectories(outputFile.getParent());
        }
        try (OutputStream os = new BufferedOutputStream(
                Files.newOutputStream(outputFile))) {
            quantized.save(os);
        }
        long fileSize = Files.size(outputFile);
        System.out.printf(Locale.US,
                "Model saved: %s (%.1f MB)%n",
                outputFile, fileSize / (1024.0 * 1024.0));
        System.out.printf(Locale.US, "  [%.1f s]%n",
                elapsed(stepStart));

        double totalMin = (System.nanoTime() - pipelineStart)
                / 1_000_000_000.0 / 60.0;
        System.out.printf(Locale.US,
                "%nDone! Total time: %.1f min%n", totalMin);
    }

    /** Prints command-line usage information to stderr. */
    private static void printUsage() {
        String[] usageLines = {
                "Usage: TrainLanguageModel"
                        + " --corpus <dir> --output <file>"
                        + " [--prep-dir <dir>] [--buckets N] [--max-train N]"
                        + " [--skip-bigrams] [--trigrams] [--single-pass]"
                        + " [--eval-only <model>]",
                "  --single-pass  skip filterPool + Pass 2 (Pass 1 only)",
                "  Data preparation is handled by PrepareCorpus."
                        + " Run PrepareCorpus first, or provide --corpus so this"
                        + " trainer can call PrepareCorpus.prepareData automatically."
        };
        for (String line : usageLines) {
            System.err.println(line);
        }
    }

    /** Returns seconds elapsed since {@code startNanos} (a System.nanoTime value). */
    private static double elapsed(long startNanos) {
        long elapsedNanos = System.nanoTime() - startNanos;
        return elapsedNanos / 1_000_000_000.0;
    }

    // ================================================================
    //  Epoch file creation (resampling from pool)
    // ================================================================

    /**
     * Sample up to {@code maxPerLang} sentences from each
     * per-language pool file and write to a single epoch
     * training file with languages interleaved.
     * <p>
     * Two-phase approach to stay memory-efficient:
     * <ol>
     *   <li><b>Sample</b>: reservoir-sample each language
     *       one at a time into a per-language temp file.
     *       Peak memory = one language's reservoir.</li>
     *   <li><b>Interleave</b>: open all temp files, randomly
     *       pick a language for each line, write to epoch
     *       file. Memory = reader buffers only (~200 × 8 KB).
     *       </li>
     * </ol>
     * Interleaving is critical: without it, the epoch file
     * would contain single-language blocks that cause
     * catastrophic forgetting during SGD.
     *
     * @return the epoch file path
     */
    static Path createEpochFile(Path poolDir, Path epochFile,
                                Map<String, Integer> perLangTargets,
                                int epochNum)
            throws IOException {
        // Deterministic per-epoch seed: runs are reproducible, but each
        // epoch still draws a different sample.
        Random rng = new Random(42L + epochNum * 31L);

        // Discover per-language pool files; the file name is the
        // language label.
        List<String> langs = new ArrayList<>();
        List<Path> poolFiles = new ArrayList<>();
        try (DirectoryStream<Path> ds =
                     Files.newDirectoryStream(poolDir)) {
            for (Path langFile : ds) {
                if (!Files.isRegularFile(langFile)) {
                    continue;
                }
                langs.add(
                        langFile.getFileName().toString());
                poolFiles.add(langFile);
            }
        }

        // Phase 1: reservoir sample each language into a
        // temp file (one language in memory at a time)
        // NOTE(review): if an exception is thrown during Phase 1, the
        // temp files written so far are not deleted — acceptable for a
        // one-shot training tool, but worth confirming.
        Path sampledDir =
                Files.createTempDirectory("epoch_sampled_");
        List<Path> sampledFiles =
                new ArrayList<>(langs.size());

        for (int i = 0; i < poolFiles.size(); i++) {
            // Target sample size for this language (0 if absent from
            // the targets map, which yields an empty reservoir).
            int target = perLangTargets.getOrDefault(
                    langs.get(i), 0);
            List<String> reservoir = new ArrayList<>();
            try (BufferedReader reader =
                         Files.newBufferedReader(
                                 poolFiles.get(i),
                                 StandardCharsets.UTF_8)) {
                String line;
                int lineNum = 0;
                // Classic reservoir sampling (Algorithm R): fill the
                // reservoir, then replace entries with decreasing
                // probability so each line is kept uniformly at random.
                while ((line = reader.readLine()) != null) {
                    if (lineNum < target) {
                        reservoir.add(line);
                    } else {
                        int j = rng.nextInt(lineNum + 1);
                        if (j < target) {
                            reservoir.set(j, line);
                        }
                    }
                    lineNum++;
                }
            }

            // Shuffle so the temp file is not in pool order.
            Collections.shuffle(reservoir, rng);
            Path sampledFile =
                    sampledDir.resolve(langs.get(i));
            try (BufferedWriter sw =
                         Files.newBufferedWriter(
                                 sampledFile,
                                 StandardCharsets.UTF_8)) {
                for (String text : reservoir) {
                    sw.write(text);
                    sw.newLine();
                }
            }
            sampledFiles.add(sampledFile);
        }

        // Phase 2: interleave into epoch file by randomly
        // picking among active languages for each line
        int numLangs = langs.size();
        BufferedReader[] readers =
                new BufferedReader[numLangs];
        // pending[i] holds the next unwritten line for language i;
        // null means that language's sampled file is exhausted.
        String[] pending = new String[numLangs];
        List<Integer> active = new ArrayList<>(numLangs);
        int totalWritten = 0;

        try (BufferedWriter w = Files.newBufferedWriter(
                epochFile, StandardCharsets.UTF_8)) {
            for (int i = 0; i < numLangs; i++) {
                readers[i] = Files.newBufferedReader(
                        sampledFiles.get(i),
                        StandardCharsets.UTF_8);
                pending[i] = readers[i].readLine();
                if (pending[i] != null) {
                    active.add(i);
                }
            }

            while (!active.isEmpty()) {
                // Pick a random still-active language for the next line.
                int pick = rng.nextInt(active.size());
                int idx = active.get(pick);

                // Epoch file format: "<label>\t<text>".
                w.write(langs.get(idx));
                w.write('\t');
                w.write(pending[idx]);
                w.newLine();
                totalWritten++;

                pending[idx] = readers[idx].readLine();
                if (pending[idx] == null) {
                    // Language exhausted: close its reader and remove it
                    // from the active list via swap-with-last (O(1)).
                    readers[idx].close();
                    readers[idx] = null;
                    int last = active.size() - 1;
                    active.set(pick, active.get(last));
                    active.remove(last);
                }
            }
        } finally {
            // Close any readers still open (e.g. if writing failed),
            // then remove all temp files and the temp directory.
            for (BufferedReader r : readers) {
                if (r != null) {
                    r.close();
                }
            }
            for (Path f : sampledFiles) {
                Files.deleteIfExists(f);
            }
            Files.deleteIfExists(sampledDir);
        }

        System.out.printf(Locale.US,
                "  Epoch %d: sampled %,d sentences%n",
                epochNum + 1, totalWritten);
        return epochFile;
    }

    // ================================================================
    //  Pool size scanning and per-language target computation
    // ================================================================

    /**
     * Count lines in each per-language pool file.
     */
    static Map<String, Long> scanPoolSizes(Path poolDir)
            throws IOException {
        // Map from pool file name (the language label) to line count.
        Map<String, Long> lineCounts = new HashMap<>();
        try (DirectoryStream<Path> dir =
                     Files.newDirectoryStream(poolDir)) {
            for (Path langFile : dir) {
                if (!Files.isRegularFile(langFile)) {
                    // Skip subdirectories and other non-file entries.
                    continue;
                }
                long numLines = 0;
                try (BufferedReader in =
                             Files.newBufferedReader(
                                     langFile,
                                     StandardCharsets.UTF_8)) {
                    for (String line = in.readLine();
                            line != null;
                            line = in.readLine()) {
                        numLines++;
                    }
                }
                lineCounts.put(
                        langFile.getFileName().toString(),
                        numLines);
            }
        }
        return lineCounts;
    }

    /**
     * Compute per-language epoch targets by binary-searching
     * for a flat cap C such that {@code Σ min(n_i, C) ≈ targetTotal}.
     * Languages with fewer sentences than C contribute all their
     * data; larger languages are uniformly capped.
     */
    static Map<String, Integer> computePerLangTargets(
            Map<String, Long> poolSizes,
            long targetTotal) {
        // Total lines available across all languages.
        long totalAvailable = 0;
        for (long n : poolSizes.values()) {
            totalAvailable += n;
        }

        if (totalAvailable <= targetTotal) {
            // Everything fits under the budget: each language
            // contributes its entire pool (clamped to int range).
            System.out.printf(Locale.US,
                    "  Pool total=%,d <= target=%,d;"
                            + " using all data%n",
                    totalAvailable, targetTotal);
            Map<String, Integer> targets = new HashMap<>();
            for (Map.Entry<String, Long> e : poolSizes.entrySet()) {
                targets.put(e.getKey(), (int) Math.min(
                        e.getValue(), Integer.MAX_VALUE));
            }
            return targets;
        }

        // Binary search for the flat cap. Invariant:
        // cappedSum(lo) < targetTotal <= cappedSum(hi).
        long lo = 0;
        long hi = 0;
        for (long n : poolSizes.values()) {
            hi = Math.max(hi, n);
        }
        while (lo < hi - 1) {
            long mid = lo + (hi - lo) / 2;
            long cappedSum = 0;
            for (long n : poolSizes.values()) {
                cappedSum += Math.min(n, mid);
            }
            if (cappedSum < targetTotal) {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        long cap = hi;

        // Report how close the capped total gets to the target and how
        // many languages were actually capped.
        long actualTotal = 0;
        long cappedLangs = 0;
        for (long n : poolSizes.values()) {
            actualTotal += Math.min(n, cap);
            if (n > cap) {
                cappedLangs++;
            }
        }
        System.out.printf(Locale.US,
                "  Epoch cap=%,d  actual=%,d  target=%,d"
                        + "  (%d/%d langs capped)%n",
                cap, actualTotal, targetTotal,
                cappedLangs, poolSizes.size());

        Map<String, Integer> targets = new HashMap<>();
        for (Map.Entry<String, Long> e : poolSizes.entrySet()) {
            targets.put(e.getKey(),
                    (int) Math.min(e.getValue(), cap));
        }
        return targets;
    }

    // ================================================================
    //  Mislabeled sentence filtering (pool-based)
    // ================================================================

    /**
     * Filter all per-language pool files in parallel. For each
     * sentence, if the model's prediction doesn't match the label
     * (respecting confusable groups), it is removed.
     * <p>
     * Each worker thread reuses its own extractor and feature/logit
     * buffers to avoid per-sentence allocation overhead.
     *
     * @return long[2]: {kept, total}
     */
    static long[] filterPool(Phase2Trainer trainer,
                             Path poolDir, Path filteredDir)
            throws IOException {
        Files.createDirectories(filteredDir);

        // Collect per-language pool files, sorted by name so work is
        // submitted in a deterministic order.
        List<Path> langFiles = new ArrayList<>();
        try (DirectoryStream<Path> ds =
                     Files.newDirectoryStream(poolDir)) {
            for (Path p : ds) {
                if (Files.isRegularFile(p)) {
                    langFiles.add(p);
                }
            }
        }
        langFiles.sort(Comparator.comparing(
                p -> p.getFileName().toString()));

        int numLangs = langFiles.size();
        // Shared counters updated by worker threads.
        AtomicLong keptTotal = new AtomicLong();
        AtomicLong grandTotal = new AtomicLong();
        AtomicLong langsProcessed = new AtomicLong();

        int threads = Runtime.getRuntime().availableProcessors();
        ExecutorService exec =
                Executors.newFixedThreadPool(threads);
        List<Future<?>> futures = new ArrayList<>(numLangs);

        // One task per language: read the pool file, keep only the
        // sentences whose predicted language matches the label
        // (confusable-group aware), write them to the filtered dir.
        for (Path langFile : langFiles) {
            futures.add(exec.submit(() -> {
                String lang =
                        langFile.getFileName().toString();
                Path outFile = filteredDir.resolve(lang);
                // Per-thread extractor and buffers: avoids per-sentence
                // allocation and cross-thread sharing.
                FeatureExtractor ext = trainer.getExtractor();
                int[] featureBuf =
                        new int[trainer.getNumBuckets()];
                float[] logitBuf =
                        new float[trainer.getNumClasses()];
                long langKept = 0;
                long langTotal = 0;
                try (BufferedReader reader =
                             Files.newBufferedReader(
                                     langFile,
                                     StandardCharsets.UTF_8);
                     BufferedWriter writer =
                             Files.newBufferedWriter(
                                     outFile,
                                     StandardCharsets.UTF_8)) {
                    String text;
                    while ((text = reader.readLine())
                            != null) {
                        langTotal++;
                        String predicted =
                                trainer.predictBuffered(
                                        text, ext,
                                        featureBuf, logitBuf);
                        if (CompareDetectors.isGroupMatch(
                                lang, predicted)) {
                            writer.write(text);
                            writer.newLine();
                            langKept++;
                        }
                    }
                } catch (IOException e) {
                    throw new RuntimeException(
                            "Filter failed: " + lang, e);
                }
                keptTotal.addAndGet(langKept);
                grandTotal.addAndGet(langTotal);
                long done = langsProcessed.incrementAndGet();
                // Progress line; output from different workers may
                // interleave, but each printf is a single call.
                System.out.printf(Locale.US,
                        "  %s: kept %,d/%,d"
                                + "  [%d/%d langs done]%n",
                        lang, langKept, langTotal,
                        done, numLangs);
            }));
        }

        exec.shutdown();
        // Wait for every task; surface the first failure.
        for (Future<?> f : futures) {
            try {
                f.get();
            } catch (Exception e) {
                if (e instanceof InterruptedException) {
                    // Restore the interrupt flag so callers further up
                    // the stack can still observe the interruption.
                    Thread.currentThread().interrupt();
                }
                throw new RuntimeException(
                        "Filter thread failed", e);
            }
        }

        return new long[]{keptTotal.get(), grandTotal.get()};
    }

    // ================================================================
    //  Evaluation
    // ================================================================

    /**
     * Print a full evaluation report for {@code model} on the labeled
     * sentences in {@code dataFile}: summary metrics at several
     * truncation lengths, the ten worst languages at 500 chars, a
     * per-language F1 TSV dump, and confusion analyses for a set of
     * high-resource plus worst-performing languages.
     *
     * @param model    quantized model to evaluate
     * @param dataFile preprocessed "lang\ttext" test file
     * @param label    data-set name used in headings and to derive the
     *                 TSV file name
     * @throws IOException if the test data cannot be read or the TSV
     *                     file cannot be written
     */
    private static void runEval(CharSoupModel model,
                                Path dataFile,
                                String label) throws IOException {
        List<LabeledSentence> data = readPreprocessedFile(dataFile);
        System.out.printf(Locale.US,
                "%n--- Eval on %s (%,d sentences, %d langs) ---%n",
                label, data.size(),
                data.stream().map(LabeledSentence::getLanguage)
                        .distinct().count());
        int[] lengths = {20, 50, 100, 200, 500};
        System.out.printf(Locale.US,
                "%-10s  %8s  %8s  %12s  %8s%n",
                "length", "macroF1", "median",
                ">=0.90/total", "accuracy");
        System.out.println(
                "----------  --------  --------"
                        + "  ------------  --------");
        List<EvalResult> results = new ArrayList<>();
        for (int maxChars : lengths) {
            List<LabeledSentence> subset =
                    truncateTestData(data, maxChars);
            EvalResult r = evaluateQuantized(model, subset);
            results.add(r);
            System.out.printf(Locale.US,
                    "%-10s  %8.4f  %8.4f  %5d/%-6d  %8.4f%n",
                    "@" + maxChars, r.macroF1, r.medianF1,
                    r.numAbove90, r.numLangs, r.accuracy);
        }
        // Worst-10 at @500 (the last entry of `lengths`); perLang is
        // sorted worst-first by evaluateQuantized().
        EvalResult last = results.get(results.size() - 1);
        System.out.printf(Locale.US,
                "%nWorst 10 languages (@500 chars, %s):%n", label);
        int limit = Math.min(10, last.perLang.size());
        for (int i = 0; i < limit; i++) {
            LangF1 lf = last.perLang.get(i);
            System.out.printf(Locale.US,
                    "  %-8s  F1=%.4f%n", lf.lang, lf.f1);
        }
        // TSV dump: one row per language, one F1 column per length.
        Path tsvFile = dataFile.resolveSibling(
                label.replace("-", "_") + "-per-lang.tsv");
        try (BufferedWriter tsv = Files.newBufferedWriter(
                tsvFile, StandardCharsets.UTF_8)) {
            tsv.write("lang");
            for (int l : lengths) {
                tsv.write("\tf1@" + l);
            }
            tsv.newLine();
            // One lang -> F1 lookup per evaluated length, parallel to
            // `results` (typed list instead of a raw generic array).
            List<Map<String, Double>> f1Maps =
                    new ArrayList<>(results.size());
            for (EvalResult r : results) {
                Map<String, Double> m = new HashMap<>();
                for (LangF1 lf : r.perLang) {
                    m.put(lf.lang, lf.f1);
                }
                f1Maps.add(m);
            }
            List<String> allLangs = new ArrayList<>();
            for (LangF1 lf : last.perLang) {
                allLangs.add(lf.lang);
            }
            allLangs.sort(String::compareTo);
            for (String lang : allLangs) {
                tsv.write(lang);
                for (Map<String, Double> m : f1Maps) {
                    tsv.write(String.format(Locale.US,
                            "\t%.4f", m.getOrDefault(lang, 0.0)));
                }
                tsv.newLine();
            }
        }
        System.out.println("TSV written: " + tsvFile);

        // Confusion analysis at @500 — high-resource langs + worst 5
        List<LabeledSentence> at500 = truncateTestData(data, 500);
        Set<String> analysisTargets = new LinkedHashSet<>(
                Arrays.asList("eng", "deu", "fra", "spa", "rus",
                        "zho", "ara", "por"));
        for (int i = 0; i < Math.min(5, last.perLang.size()); i++) {
            analysisTargets.add(last.perLang.get(i).lang);
        }
        for (String lang : analysisTargets) {
            printConfusion(model, at500, lang, 15);
        }
    }

    /**
     * Print a confusion analysis for one target language: precision,
     * recall, the top languages the target is mistaken for, and the
     * top languages mistaken for the target (false positives).
     *
     * @param model      quantized model to query
     * @param data       labeled test sentences (already truncated)
     * @param targetLang language code to analyze; if it is not in the
     *                   model, a notice is printed and nothing else
     * @param topN       number of confusion entries per table
     */
    private static void printConfusion(CharSoupModel model,
                                       List<LabeledSentence> data,
                                       String targetLang,
                                       int topN) {
        FeatureExtractor extractor = model.createExtractor();
        Map<String, Integer> labelIndex = new HashMap<>();
        for (int i = 0; i < model.getNumClasses(); i++) {
            labelIndex.put(model.getLabel(i), i);
        }
        Integer trueIdx = labelIndex.get(targetLang);
        if (trueIdx == null) {
            System.out.println("Language not in model: " + targetLang);
            return;
        }

        // predicted -> count, for sentences whose true label = targetLang
        Map<String, Integer> predictedWhen = new HashMap<>();
        // true -> count, for sentences whose predicted label = targetLang (false positives)
        Map<String, Integer> trueWhen = new HashMap<>();
        int total = 0;
        int correct = 0;
        for (LabeledSentence s : data) {
            if (targetLang.equals(s.getLanguage())) {
                total++;
                int[] features = extractor.extract(s.getText());
                float[] probs = model.predict(features);
                int predicted = argmax(probs);
                String predLabel = model.getLabel(predicted);
                if (predicted == trueIdx) {
                    correct++;
                }
                predictedWhen.merge(predLabel, 1, Integer::sum);
            } else {
                // Only count sentences whose true language the model
                // knows; others cannot be classified correctly anyway.
                Integer ti = labelIndex.get(s.getLanguage());
                if (ti == null) {
                    continue;
                }
                int[] features = extractor.extract(s.getText());
                float[] probs = model.predict(features);
                int predicted = argmax(probs);
                if (predicted == trueIdx) {
                    trueWhen.merge(s.getLanguage(), 1, Integer::sum);
                }
            }
        }

        // False-positive count is needed for precision and the FP table
        // header — compute it once (was recomputed three times).
        int fpTotal = trueWhen.values().stream()
                .mapToInt(Integer::intValue).sum();
        double prec = correct + fpTotal > 0
                ? (double) correct / (correct + fpTotal)
                : 0.0;
        double rec = total > 0 ? (double) correct / total : 0.0;
        System.out.printf(Locale.US,
                "%n--- Confusion for '%s' @500 chars"
                        + " (correct=%d/%d, prec=%.3f, rec=%.3f) ---%n",
                targetLang, correct, total, prec, rec);

        // Top misclassifications (predicted X when true=targetLang)
        final int finalTotal = total;
        System.out.printf(Locale.US,
                "  True=%s, predicted as (top %d):%n",
                targetLang, topN);
        predictedWhen.entrySet().stream()
                .filter(e -> !e.getKey().equals(targetLang))
                .sorted((a, b) -> Integer.compare(b.getValue(), a.getValue()))
                .limit(topN)
                .forEach(e -> System.out.printf(Locale.US,
                        "    %-8s  %5d (%.1f%%)%n",
                        e.getKey(), e.getValue(),
                        100.0 * e.getValue() / finalTotal));

        // Top false positives (predicted targetLang when true=X)
        System.out.printf(Locale.US,
                "  Predicted=%s when actually (top %d, %d total FP):%n",
                targetLang, topN, fpTotal);
        trueWhen.entrySet().stream()
                .sorted((a, b) -> Integer.compare(b.getValue(), a.getValue()))
                .limit(topN)
                .forEach(e -> System.out.printf(Locale.US,
                        "    %-8s  %5d%n", e.getKey(), e.getValue()));
    }

    /**
     * Immutable (language, F1) pair used in per-language evaluation
     * breakdowns.
     */
    static class LangF1 {
        /** Language code (file-name label, e.g. "eng"). */
        final String lang;
        /** F1 score for that language. */
        final double f1;

        LangF1(String language, double score) {
            this.lang = language;
            this.f1 = score;
        }
    }

    /**
     * Aggregate metrics from one evaluation pass over a test subset.
     */
    static class EvalResult {
        /** Mean of per-language F1 scores. */
        final double macroF1;
        /** Median of per-language F1 scores. */
        final double medianF1;
        /** Micro accuracy over all scored sentences. */
        final double accuracy;
        /** Number of languages with at least one scored sentence. */
        final int numLangs;
        /** Number of languages whose F1 is at least 0.90. */
        final int numAbove90;
        /** Total number of scored sentences. */
        final int total;
        /** Per-language F1 scores, sorted worst-first. */
        final List<LangF1> perLang;

        EvalResult(double macroF1, double medianF1,
                   double accuracy, int numLangs,
                   int numAbove90, int total,
                   List<LangF1> perLang) {
            this.macroF1 = macroF1;
            this.medianF1 = medianF1;
            this.accuracy = accuracy;
            this.numLangs = numLangs;
            this.numAbove90 = numAbove90;
            this.total = total;
            this.perLang = perLang;
        }
    }

    /**
     * Evaluate the quantized model on {@code data}, reporting macro-F1,
     * median per-language F1, micro accuracy, and a per-language
     * breakdown sorted worst-first. Test data is raw text; the
     * extractor runs the full feature pipeline. Sentences whose true
     * language is unknown to the model are skipped.
     */
    static EvalResult evaluateQuantized(CharSoupModel model,
                                        List<LabeledSentence> data) {
        FeatureExtractor extractor = model.createExtractor();
        int numClasses = model.getNumClasses();
        Map<String, Integer> indexByLabel = new HashMap<>();
        for (int c = 0; c < numClasses; c++) {
            indexByLabel.put(model.getLabel(c), c);
        }

        int[] truePos = new int[numClasses];
        int[] falsePos = new int[numClasses];
        int[] falseNeg = new int[numClasses];
        int hits = 0;
        int scored = 0;

        for (LabeledSentence sentence : data) {
            Integer gold = indexByLabel.get(sentence.getLanguage());
            if (gold == null) {
                continue; // language not covered by the model
            }
            float[] probs = model.predict(
                    extractor.extract(sentence.getText()));
            int guess = argmax(probs);
            if (guess == gold) {
                truePos[gold]++;
                hits++;
            } else {
                falseNeg[gold]++;
                falsePos[guess]++;
            }
            scored++;
        }

        // Per-language F1 for every language with test support.
        List<LangF1> perLang = new ArrayList<>();
        for (int c = 0; c < numClasses; c++) {
            int support = truePos[c] + falseNeg[c];
            if (support == 0) {
                continue;
            }
            double precision = truePos[c] + falsePos[c] > 0
                    ? (double) truePos[c] / (truePos[c] + falsePos[c])
                    : 0.0;
            double recall = (double) truePos[c] / support;
            double f1 = precision + recall > 0
                    ? 2.0 * precision * recall / (precision + recall)
                    : 0.0;
            perLang.add(new LangF1(model.getLabel(c), f1));
        }
        // Worst-first ordering makes problem languages easy to spot.
        perLang.sort(Comparator.comparingDouble(x -> x.f1));

        int active = perLang.size();
        double sumF1 = 0.0;
        int above90 = 0;
        for (LangF1 entry : perLang) {
            sumF1 += entry.f1;
            if (entry.f1 >= 0.90) {
                above90++;
            }
        }
        double macro = active > 0 ? sumF1 / active : 0.0;

        double median = 0.0;
        if (active > 0) {
            int mid = active / 2;
            median = active % 2 == 1
                    ? perLang.get(mid).f1
                    : (perLang.get(mid - 1).f1
                            + perLang.get(mid).f1) / 2.0;
        }

        double accuracy = scored > 0 ? (double) hits / scored : 0.0;
        return new EvalResult(macro, median, accuracy,
                active, above90, scored, perLang);
    }

    /**
     * Return a copy of {@code data} where each sentence's text is
     * truncated to at most {@code maxChars} UTF-16 code units. Only
     * sentences that have at least one character after truncation are
     * included; shorter sentences are included as-is.
     *
     * <p>Truncation never splits a surrogate pair: if the cut would
     * leave a dangling high surrogate (malformed UTF-16 fed to the
     * feature extractor), the cut backs off by one unit.
     */
    static List<LabeledSentence> truncateTestData(
            List<LabeledSentence> data, int maxChars) {
        List<LabeledSentence> result =
                new ArrayList<>(data.size());
        for (LabeledSentence s : data) {
            String text = s.getText();
            if (text.length() > maxChars) {
                int cut = maxChars;
                // Back off one unit rather than split a supplementary
                // code point (common in zho/ emoji-laden test data).
                if (cut > 0
                        && Character.isHighSurrogate(text.charAt(cut - 1))
                        && Character.isLowSurrogate(text.charAt(cut))) {
                    cut--;
                }
                text = text.substring(0, cut);
            }
            if (!text.isEmpty()) {
                result.add(new LabeledSentence(
                        s.getLanguage(), text));
            }
        }
        return result;
    }

    /**
     * Index of the largest value in {@code arr}; ties resolve to the
     * earliest index. Returns 0 for an empty array.
     */
    private static int argmax(float[] arr) {
        int indexOfMax = 0;
        for (int i = 1; i < arr.length; i++) {
            indexOfMax = arr[i] > arr[indexOfMax] ? i : indexOfMax;
        }
        return indexOfMax;
    }

    // ================================================================
    //  Helpers
    // ================================================================

    /**
     * Collect all language labels from the pool directory (file names
     * are the language codes), sorted lexicographically.
     *
     * @param poolDir directory whose regular files name the languages
     * @return sorted array of language labels
     * @throws IOException if the directory cannot be listed
     */
    static String[] collectLabels(Path poolDir)
            throws IOException {
        List<String> labels = new ArrayList<>();
        try (DirectoryStream<Path> entries =
                     Files.newDirectoryStream(poolDir)) {
            for (Path entry : entries) {
                if (!Files.isRegularFile(entry)) {
                    continue; // skip sub-directories and the like
                }
                labels.add(entry.getFileName().toString());
            }
        }
        String[] out = labels.toArray(new String[0]);
        Arrays.sort(out);
        return out;
    }

    /**
     * Read a preprocessed test/train file: one "lang\ttext" pair per
     * line, UTF-8. Lines without a tab separator are silently skipped.
     *
     * @param file path to the tab-separated file
     * @return labeled sentences in file order
     * @throws IOException if the file cannot be read
     */
    static List<LabeledSentence> readPreprocessedFile(Path file)
            throws IOException {
        List<LabeledSentence> out = new ArrayList<>();
        try (BufferedReader in = Files.newBufferedReader(
                file, StandardCharsets.UTF_8)) {
            for (String line = in.readLine(); line != null;
                    line = in.readLine()) {
                int sep = line.indexOf('\t');
                if (sep >= 0) {
                    out.add(new LabeledSentence(
                            line.substring(0, sep),
                            line.substring(sep + 1)));
                }
            }
        }
        return out;
    }
}