Phase2Trainer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.VarHandle;
import java.nio.channels.Channels;
import java.nio.channels.FileChannel;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.DoubleAdder;

import org.apache.tika.langdetect.charsoup.FeatureExtractor;
import org.apache.tika.langdetect.charsoup.ScriptAwareFeatureExtractor;

/**
 * Phase 2 multinomial logistic regression trainer.
 * <p>
 * Streams training data from disk each epoch to avoid loading
 * the full corpus into memory. Data must be pre-shuffled during
 * data preparation.
 * <p>
 * Key features:
 * <ul>
 *   <li><b>Streaming</b>: reads training file from disk each
 *       epoch; flat memory usage regardless of corpus size</li>
 *   <li><b>Bucket-major layout</b>: weights as
 *       {@code [bucket][class]} for cache-friendly sparse
 *       feature access</li>
 *   <li><b>Adam ��� SGD</b>: Adam for fast convergence, then
 *       SGD for fine-tuning</li>
 *   <li><b>Within-epoch early stopping</b>: checkpoints on
 *       dev subsample every N sentences</li>
 *   <li><b>Across-epoch early stopping</b>: full dev eval
 *       after each epoch, patience-based</li>
 *   <li><b>Hogwild!</b>: lock-free parallel weight updates
 *       via batched line dispatch to worker threads</li>
 *   <li><b>Buffer reuse</b>: thread-local feature/logit
 *       buffers to minimize GC pressure</li>
 * </ul>
 */
public class Phase2Trainer {

    /**
     * VarHandle for {@code float[]} element access.
     * Swap between plain get/set and getOpaque/setOpaque
     * to experiment with visibility guarantees vs JIT
     * optimization. Plain access lets the JIT vectorize;
     * opaque prevents register caching across iterations.
     */
    private static final VarHandle FA =
            MethodHandles.arrayElementVarHandle(float[].class);

    private static float getO(float[] a, int i) {
        return (float) FA.get(a, i);  // plain access
    }

    private static void setO(float[] a, int i, float v) {
        FA.set(a, i, v);  // plain access
    }

    private final int numBuckets;

    // --- Optimizer hyperparameters ---
    private float adamLr = 0.001f;
    private float adamBeta1 = 0.9f;
    private float adamBeta2 = 0.999f;
    private float adamEpsilon = 1e-8f;

    private float sgdLrStart = 0.01f;
    private float sgdLrEnd = 0.001f;

    private float l2Lambda = 1e-5f;

    // --- Training schedule ---
    private int adamEpochs = 2;
    private int maxEpochs = 8;

    // --- Early stopping: within-epoch ---
    private int checkpointInterval = 200_000;
    private int rollingWindow = 5;
    private double withinEpochThreshold = 0.002;

    // --- Early stopping: across-epoch ---
    private int patience = 3;
    private double acrossEpochThreshold = 0.001;
    private int devSubsampleSize = 20_000;

    // --- Mini-batch ---
    /**
     * Gradient mini-batch size. Gradients are accumulated
     * over this many samples before a single Adam/SGD
     * update. Standard ML mini-batch ��� NOT just I/O
     * batching.
     */
    private int miniBatchSize = 64;

    // --- I/O batching ---
    /**
     * Lines read from disk before dispatching to threads.
     * Must be large enough that each thread gets a
     * meaningful contiguous slice.
     */
    private int batchSize = 100_000;

    // --- Threading ---
    /**
     * Threads for SGD epochs (Hogwild-safe).
     * Defaults to all available cores.
     */
    private int sgdThreads = Runtime.getRuntime()
            .availableProcessors();
    /**
     * Threads for Adam epochs. Each thread gets its own
     * moment arrays (m, v), so memory scales linearly:
     * {@code adamThreads * 2 * numBuckets * numClasses * 4}
     * bytes. Default 1 (single-threaded, no extra memory).
     */
    private int adamThreads = 1;
    private long seed = 42L;
    private boolean verbose = true;
    private boolean preprocessed = false;

    // --- Model parameters: BUCKET-MAJOR layout ---
    private float[][] weights; // [bucket][class]
    private float[] biases;    // [class]
    private String[] labels;
    private Map<String, Integer> labelIndex;
    private int numClasses;

    // --- Chunk byte offsets for seek-based shuffling ---
    private long[] chunkByteOffsets;

    // --- Adam state (same bucket-major layout) ---
    // Shared state (single-threaded Adam)
    private float[][] mW;      // [bucket][class] first moment
    private float[][] vW;      // [bucket][class] second moment
    private float[] mBias;     // [class]
    private float[] vBias;     // [class]
    private AtomicLong globalStep;

    // Per-thread Adam state (Hogwild Adam)
    private float[][][] perThreadMW;   // [thread][bucket][class]
    private float[][][] perThreadVW;   // [thread][bucket][class]
    private float[][] perThreadMBias;  // [thread][class]
    private float[][] perThreadVBias;  // [thread][class]
    private long[] perThreadStep;      // [thread]

    public Phase2Trainer(int numBuckets) {
        this.numBuckets = numBuckets;
    }

    // --- Builder-style setters ---

    public Phase2Trainer setAdamLr(float lr) {
        this.adamLr = lr;
        return this;
    }

    public Phase2Trainer setSgdLr(float start, float end) {
        this.sgdLrStart = start;
        this.sgdLrEnd = end;
        return this;
    }

    public Phase2Trainer setL2Lambda(float lambda) {
        this.l2Lambda = lambda;
        return this;
    }

    public Phase2Trainer setAdamEpochs(int epochs) {
        this.adamEpochs = epochs;
        return this;
    }

    public Phase2Trainer setMaxEpochs(int epochs) {
        this.maxEpochs = epochs;
        return this;
    }

    public Phase2Trainer setCheckpointInterval(int interval) {
        this.checkpointInterval = interval;
        return this;
    }

    public Phase2Trainer setPatience(int patience) {
        this.patience = patience;
        return this;
    }

    public Phase2Trainer setDevSubsampleSize(int size) {
        this.devSubsampleSize = size;
        return this;
    }

    public Phase2Trainer setMiniBatchSize(int size) {
        this.miniBatchSize = size;
        return this;
    }

    public Phase2Trainer setBatchSize(int size) {
        this.batchSize = size;
        return this;
    }

    public Phase2Trainer setNumThreads(int threads) {
        return setSgdThreads(threads).setAdamThreads(1);
    }

    public Phase2Trainer setSgdThreads(int threads) {
        if (threads < 1) {
            throw new IllegalArgumentException(
                    "sgdThreads must be >= 1");
        }
        this.sgdThreads = threads;
        return this;
    }

    public Phase2Trainer setAdamThreads(int threads) {
        if (threads < 1) {
            throw new IllegalArgumentException(
                    "adamThreads must be >= 1");
        }
        this.adamThreads = threads;
        return this;
    }

    public Phase2Trainer setSeed(long seed) {
        this.seed = seed;
        return this;
    }

    public Phase2Trainer setVerbose(boolean verbose) {
        this.verbose = verbose;
        return this;
    }

    public Phase2Trainer setPreprocessed(boolean preprocessed) {
        this.preprocessed = preprocessed;
        return this;
    }

    // ================================================================
    //  Training entry point
    // ================================================================

    /**
     * Train the model by streaming training data from disk.
     * <p>
     * The training file is read sequentially each epoch.
     * Data must be pre-shuffled during data preparation.
     *
     * @param trainFile path to preprocessed training file
     *                  (tab-separated: {@code lang\ttext})
     * @param devData   dev set (kept in memory for eval)
     * @throws IOException if reading fails
     */
    public void train(Path trainFile,
                      List<LabeledSentence> devData)
            throws IOException {
        long scanStart = System.nanoTime();
        int totalLines = scanLabels(trainFile);
        if (verbose) {
            System.out.printf(Locale.US,
                    "Scanned: %,d lines, %d classes  "
                            + "[%.1f s]%n",
                    totalLines, numClasses,
                    elapsed(scanStart));
        }
        initializeWeights();
        trainEpochs(trainFile, devData);
    }

    /**
     * Initialize the model with known labels, then train
     * with epoch-level resampling. Before each epoch, the
     * caller provides a freshly sampled and shuffled training
     * file via {@code epochFileSupplier}.
     *
     * @param allLabels         all language labels
     * @param epochFileSupplier called before each epoch to
     *                          produce a training file
     * @param devData           fixed dev set for evaluation
     */
    public void trainWithResampling(
            String[] allLabels,
            EpochFileSupplier epochFileSupplier,
            List<LabeledSentence> devData)
            throws IOException {
        initializeLabels(allLabels);
        initializeWeights();
        trainEpochs(null, devData, epochFileSupplier);
    }

    @FunctionalInterface
    public interface EpochFileSupplier {
        Path supply(int epochNum) throws IOException;
    }

    private void initializeLabels(String[] allLabels) {
        this.labels = allLabels.clone();
        Arrays.sort(this.labels);
        this.labelIndex = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            this.labelIndex.put(labels[i], i);
        }
        this.numClasses = labels.length;
    }

    private void initializeWeights() {
        weights = new float[numBuckets][numClasses];
        biases = new float[numClasses];
        globalStep = new AtomicLong(0);

        if (adamThreads > 1) {
            perThreadMW =
                    new float[adamThreads][numBuckets][numClasses];
            perThreadVW =
                    new float[adamThreads][numBuckets][numClasses];
            perThreadMBias =
                    new float[adamThreads][numClasses];
            perThreadVBias =
                    new float[adamThreads][numClasses];
            perThreadStep = new long[adamThreads];
            mW = null;
            vW = null;
            mBias = null;
            vBias = null;
        } else {
            mW = new float[numBuckets][numClasses];
            vW = new float[numBuckets][numClasses];
            mBias = new float[numClasses];
            vBias = new float[numClasses];
            perThreadMW = null;
            perThreadVW = null;
            perThreadMBias = null;
            perThreadVBias = null;
            perThreadStep = null;
        }
    }

    /**
     * Run the epoch loop. If {@code staticFile} is non-null,
     * trains from that file every epoch (original behavior).
     * If {@code supplier} is non-null, calls it before each
     * epoch to get a fresh training file.
     */
    private void trainEpochs(Path staticFile,
                              List<LabeledSentence> devData)
            throws IOException {
        trainEpochs(staticFile, devData, null);
    }

    private void trainEpochs(Path staticFile,
                              List<LabeledSentence> devData,
                              EpochFileSupplier supplier)
            throws IOException {

        // For static file, scan chunk offsets once
        if (staticFile != null) {
            chunkByteOffsets = scanChunkOffsets(staticFile);
        }

        List<LabeledSentence> devSubsample =
                sampleDevSubset(devData, devSubsampleSize);

        if (verbose) {
            System.out.printf(Locale.US,
                    "Training: %d classes, %,d buckets, "
                            + "Adam=%d thread(s), "
                            + "SGD=%d thread(s)%n",
                    numClasses, numBuckets,
                    adamThreads, sgdThreads);
            System.out.printf(Locale.US,
                    "Schedule: Adam(lr=%.4f) x%d epochs, "
                            + "SGD(lr=%.4f->%.4f) x%d max%n",
                    adamLr, adamEpochs,
                    sgdLrStart, sgdLrEnd,
                    maxEpochs - adamEpochs);
            System.out.printf(Locale.US,
                    "Early stop: checkpoint every %,d sents "
                            + "(window=%d, thresh=%.4f), "
                            + "patience=%d%n",
                    checkpointInterval, rollingWindow,
                    withinEpochThreshold, patience);
            System.out.printf(Locale.US,
                    "Dev subsample: %,d sents, "
                            + "ioBatch=%,d lines, "
                            + "miniBatch=%d%n",
                    devSubsample.size(), batchSize,
                    miniBatchSize);
        }

        int maxThreads = Math.max(adamThreads, sgdThreads);
        ExecutorService pool = maxThreads > 1
                ? Executors.newFixedThreadPool(maxThreads)
                : null;

        double bestDevF1 = Double.NEGATIVE_INFINITY;
        int epochsWithoutImprovement = 0;

        try {
            for (int epoch = 0; epoch < maxEpochs; epoch++) {
                long epochStart = System.nanoTime();
                boolean useAdam = epoch < adamEpochs;
                float sgdLr = 0f;
                if (!useAdam) {
                    int sgdEpoch = epoch - adamEpochs;
                    int totalSgd = Math.max(1,
                            maxEpochs - adamEpochs);
                    float frac = totalSgd == 1 ? 0f
                            : (float) sgdEpoch
                            / (totalSgd - 1);
                    sgdLr = sgdLrStart
                            + frac * (sgdLrEnd - sgdLrStart);
                }

                // Resolve training file for this epoch
                Path trainFile;
                if (supplier != null) {
                    trainFile = supplier.supply(epoch);
                    chunkByteOffsets =
                            scanChunkOffsets(trainFile);
                } else {
                    trainFile = staticFile;
                }

                int totalLines = 0;
                for (int i = 0;
                     i < chunkByteOffsets.length; i++) {
                    totalLines += chunkSize;
                }

                String optLabel = useAdam
                        ? String.format(Locale.US,
                        "Adam(lr=%.4f)", adamLr)
                        : String.format(Locale.US,
                        "SGD(lr=%.4f)", sgdLr);

                EpochResult result = trainEpochStreaming(
                        pool, trainFile, useAdam, sgdLr,
                        devSubsample, epoch);

                long epochMs = (System.nanoTime()
                        - epochStart) / 1_000_000;

                F1Result devResult = devData != null
                        ? evaluateMacroF1(devData) : null;
                double devF1 = devResult != null
                        ? devResult.f1 : Double.NaN;

                if (verbose) {
                    System.out.printf(Locale.US,
                            "Epoch %d/%d  %s  "
                                    + "avgLoss=%.4f  "
                                    + "devF1=%.4f (%d langs)  "
                                    + "processed=%,d"
                                    + "%s  [%,d ms]%n",
                            epoch + 1, maxEpochs, optLabel,
                            result.avgLoss, devF1,
                            devResult != null
                                    ? devResult.numLangs : 0,
                            result.sentencesProcessed,
                            result.earlyStopped
                                    ? " (early-stopped)"
                                    : "",
                            epochMs);
                }

                // Across-epoch early stopping
                if (!Double.isNaN(devF1)) {
                    if (devF1 > bestDevF1
                            + acrossEpochThreshold) {
                        bestDevF1 = devF1;
                        epochsWithoutImprovement = 0;
                    } else {
                        epochsWithoutImprovement++;
                        if (epochsWithoutImprovement
                                >= patience) {
                            if (verbose) {
                                System.out.printf(Locale.US,
                                        "Stopping: no "
                                                + "improvement "
                                                + "for %d epochs "
                                                + "(best=%.4f)%n",
                                        patience, bestDevF1);
                            }
                            break;
                        }
                    }
                }
            }
        } finally {
            if (pool != null) {
                pool.shutdown();
            }
        }

        // Free Adam state
        mW = null;
        vW = null;
        mBias = null;
        vBias = null;
        perThreadMW = null;
        perThreadVW = null;
        perThreadMBias = null;
        perThreadVBias = null;
        perThreadStep = null;
    }

    // ================================================================
    //  File scanning
    // ================================================================

    /**
     * Scan the training file to discover all language labels
     * and count lines. Builds {@link #labels} and
     * {@link #labelIndex}.
     *
     * @return total number of lines
     */
    private int scanLabels(Path file) throws IOException {
        Map<String, Integer> idx = new HashMap<>();
        List<String> labelList = new ArrayList<>();
        int count = 0;
        try (BufferedReader br = Files.newBufferedReader(
                file, StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                String lang = line.substring(0, tab);
                if (allowedLangs != null
                        && !allowedLangs.contains(lang)) {
                    continue;
                }
                if (!idx.containsKey(lang)) {
                    idx.put(lang, labelList.size());
                    labelList.add(lang);
                }
                count++;
            }
        }
        this.labels = labelList.toArray(new String[0]);
        Arrays.sort(this.labels);
        this.labelIndex = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            this.labelIndex.put(labels[i], i);
        }
        this.numClasses = labels.length;
        return count;
    }

    /**
     * Scan the file for byte offsets at chunk boundaries.
     * Records the byte position after every {@link #chunkSize}
     * newlines. Fast sequential scan ��� only counts newline
     * bytes, no UTF-8 decoding needed.
     *
     * @return array of byte offsets (first element is 0)
     */
    private long[] scanChunkOffsets(Path file)
            throws IOException {
        List<Long> offsets = new ArrayList<>();
        offsets.add(0L);
        long bytePos = 0;
        int lineCount = 0;
        try (BufferedInputStream bis =
                     new BufferedInputStream(
                             Files.newInputStream(file),
                             1 << 16)) {
            byte[] buf = new byte[8192];
            int n;
            while ((n = bis.read(buf)) != -1) {
                for (int i = 0; i < n; i++) {
                    if (buf[i] == '\n') {
                        lineCount++;
                        if (lineCount % chunkSize == 0) {
                            offsets.add(bytePos + i + 1);
                        }
                    }
                }
                bytePos += n;
            }
        }
        return offsets.stream()
                .mapToLong(Long::longValue).toArray();
    }

    // ================================================================
    //  Streaming epoch
    // ================================================================

    private static class EpochResult {
        double avgLoss;
        int sentencesProcessed;
        boolean earlyStopped;
    }

    /**
     * Size of chunks for chunk-level shuffling. Each epoch
     * the file is read into chunks of this size, then the
     * chunks are shuffled to vary data ordering across
     * epochs. Within each chunk, lines keep their original
     * (pre-shuffled) order.
     */
    private int chunkSize = 100_000;

    public Phase2Trainer setChunkSize(int size) {
        this.chunkSize = size;
        return this;
    }

    /**
     * Train one epoch with chunk-level shuffling.
     * <p>
     * Uses pre-computed byte offsets to seek to chunks in
     * shuffled order. Only one chunk (~{@link #chunkSize}
     * lines) is in memory at a time, so this scales to
     * arbitrarily large files.
     */
    private EpochResult trainEpochStreaming(
            ExecutorService pool, Path trainFile,
            boolean useAdam, float sgdLr,
            List<LabeledSentence> devSubsample,
            int epochNum)
            throws IOException {

        // Shuffle chunk order (different per epoch)
        int numChunks = chunkByteOffsets.length;
        int[] order = new int[numChunks];
        for (int i = 0; i < numChunks; i++) {
            order[i] = i;
        }
        Random rng = new Random(seed + epochNum * 31L);
        for (int i = numChunks - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = order[i];
            order[i] = order[j];
            order[j] = tmp;
        }

        DoubleAdder totalLoss = new DoubleAdder();
        int processed = 0;
        double[] recentF1 = new double[rollingWindow];
        int checkCount = 0;
        boolean earlyStopped = false;
        int nextCheckpoint = checkpointInterval;

        // Reusable buffers for one chunk
        String[] cTexts = new String[chunkSize];
        int[] cLabels = new int[chunkSize];

        for (int ci = 0; ci < numChunks && !earlyStopped;
             ci++) {
            int chunkIdx = order[ci];

            // Read one chunk by seeking to its byte offset
            int fill = readChunk(trainFile,
                    chunkByteOffsets[chunkIdx],
                    cTexts, cLabels);

            // Shuffle lines within this chunk so that
            // mini-batches mix languages instead of seeing
            // one language at a time (data is written
            // language-by-language during prep).
            shuffleParallel(cTexts, cLabels, fill,
                    new Random(seed + epochNum * 31L
                            + chunkIdx * 7L));

            // Process chunk in batches
            for (int off = 0; off < fill && !earlyStopped;
                 off += batchSize) {
                int end = Math.min(off + batchSize, fill);
                int bLen = end - off;

                if (off == 0 && bLen == fill) {
                    processBatch(pool, cTexts, cLabels,
                            fill, useAdam, sgdLr,
                            totalLoss);
                } else {
                    String[] bTexts =
                            Arrays.copyOfRange(
                                    cTexts, off, end);
                    int[] bLabs =
                            Arrays.copyOfRange(
                                    cLabels, off, end);
                    processBatch(pool, bTexts, bLabs,
                            bLen, useAdam, sgdLr,
                            totalLoss);
                }
                processed += bLen;

                if (processed >= nextCheckpoint
                        && devSubsample != null
                        && !devSubsample.isEmpty()) {
                    earlyStopped = checkpoint(
                            processed, devSubsample,
                            recentF1, checkCount);
                    checkCount++;
                    nextCheckpoint = processed
                            + checkpointInterval;
                }
            }
        }

        EpochResult r = new EpochResult();
        r.avgLoss = processed > 0
                ? totalLoss.sum() / processed : 0;
        r.sentencesProcessed = processed;
        r.earlyStopped = earlyStopped;
        return r;
    }

    /**
     * Read one chunk starting at the given byte offset.
     * Reads up to {@link #chunkSize} valid lines into
     * the provided buffers.
     *
     * @return number of lines read
     */
    private int readChunk(Path file, long byteOffset,
                          String[] texts, int[] labels)
            throws IOException {
        int fill = 0;
        try (FileChannel fc = FileChannel.open(file,
                StandardOpenOption.READ)) {
            fc.position(byteOffset);
            BufferedReader br = new BufferedReader(
                    new InputStreamReader(
                            Channels.newInputStream(fc),
                            StandardCharsets.UTF_8));
            String line;
            while (fill < chunkSize
                    && (line = br.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                String lang = line.substring(0, tab);
                Integer idx = labelIndex.get(lang);
                if (idx == null) {
                    continue;
                }
                texts[fill] = line.substring(tab + 1);
                labels[fill] = idx;
                fill++;
            }
        }
        return fill;
    }

    /**
     * Check dev subsample F1 and return true if epoch should
     * stop early.
     */
    /** Compute L2 norm of all weight values. */
    private double weightNorm() {
        double sum = 0;
        for (int b = 0; b < numBuckets; b++) {
            float[] wb = weights[b];
            for (int c = 0; c < numClasses; c++) {
                float w = getO(wb, c);
                sum += (double) w * w;
            }
        }
        return Math.sqrt(sum);
    }

    /** Max absolute weight value (detects blowup). */
    private float maxAbsWeight() {
        float max = 0;
        for (int b = 0; b < numBuckets; b++) {
            float[] wb = weights[b];
            for (int c = 0; c < numClasses; c++) {
                float a = Math.abs(getO(wb, c));
                if (a > max) {
                    max = a;
                }
            }
        }
        return max;
    }

    private boolean checkpoint(int processed,
                               List<LabeledSentence> devSub,
                               double[] recentF1,
                               int checkCount) {
        F1Result r = evaluateMacroF1(devSub);
        int slot = checkCount % rollingWindow;
        recentF1[slot] = r.f1;

        if (verbose) {
            System.out.printf(Locale.US,
                    "    checkpoint %,d sents: "
                            + "devSubF1=%.4f (%d langs) "
                            + "wNorm=%.1f maxW=%.3f%n",
                    processed, r.f1, r.numLangs,
                    weightNorm(), maxAbsWeight());
        }

        if (checkCount + 1 >= rollingWindow) {
            double minF1 = Double.MAX_VALUE;
            double maxF1 = Double.MIN_VALUE;
            for (double d : recentF1) {
                if (d < minF1) {
                    minF1 = d;
                }
                if (d > maxF1) {
                    maxF1 = d;
                }
            }
            if (maxF1 - minF1 < withinEpochThreshold) {
                if (verbose) {
                    System.out.printf(Locale.US,
                            "    early-stop at %,d sents "
                                    + "(F1 range=%.5f "
                                    + "< %.4f)%n",
                            processed, maxF1 - minF1,
                            withinEpochThreshold);
                }
                return true;
            }
        }
        return false;
    }

    // ================================================================
    //  Batch dispatch
    // ================================================================

    /**
     * Process a batch of lines. Uses the appropriate thread
     * count based on optimizer: adamThreads for Adam epochs,
     * sgdThreads for SGD epochs.
     */
    private void processBatch(ExecutorService pool,
                              String[] texts, int[] labels,
                              int count,
                              boolean useAdam, float sgdLr,
                              DoubleAdder totalLoss) {
        int threads = useAdam ? adamThreads : sgdThreads;
        if (threads > 1 && pool != null) {
            processBatchHogwild(pool, texts, labels, count,
                    useAdam, sgdLr, totalLoss, threads);
        } else {
            processBatchSingle(texts, labels, count,
                    useAdam, sgdLr, totalLoss);
        }
    }

    private void processBatchSingle(
            String[] texts, int[] batchLabels, int count,
            boolean useAdam, float sgdLr,
            DoubleAdder totalLoss) {
        FeatureExtractor ext = createExtractor();
        int[] featureBuf = new int[numBuckets];
        float[] logitBuf = new float[numClasses];
        int[] nzBuf = new int[numBuckets];

        if (useAdam) {
            // Mini-batch Adam: accumulate gradients, then
            // apply one update per mini-batch.
            float[][] gradAccumW =
                    new float[numBuckets][numClasses];
            float[] gradAccumB = new float[numClasses];
            int mbCount = 0;

            for (int i = 0; i < count; i++) {
                extractInto(ext, texts[i], featureBuf);
                double loss = forwardGrad(featureBuf,
                        batchLabels[i], logitBuf, nzBuf);
                totalLoss.add(loss);

                // Accumulate: gradW[b][c] += grad[c]*feat[b]
                int nnz = sparseIndex(featureBuf, nzBuf);
                for (int j = 0; j < nnz; j++) {
                    int b = nzBuf[j];
                    float fv = featureBuf[b];
                    float[] ab = gradAccumW[b];
                    for (int c = 0; c < numClasses; c++) {
                        ab[c] += logitBuf[c] * fv;
                    }
                }
                for (int c = 0; c < numClasses; c++) {
                    gradAccumB[c] += logitBuf[c];
                }
                mbCount++;

                if (mbCount == miniBatchSize) {
                    applyAdamMiniBatch(gradAccumW,
                            gradAccumB, mbCount, -1);
                    mbCount = 0;
                }
            }
            if (mbCount > 0) {
                applyAdamMiniBatch(gradAccumW,
                        gradAccumB, mbCount, -1);
            }
        } else {
            // Online SGD (Hogwild-safe with multi-thread)
            for (int i = 0; i < count; i++) {
                extractInto(ext, texts[i], featureBuf);
                double loss = forwardGrad(featureBuf,
                        batchLabels[i], logitBuf, nzBuf);
                totalLoss.add(loss);
                int nnz = sparseIndex(featureBuf, nzBuf);
                sgdUpdate(featureBuf, logitBuf, nnz,
                        nzBuf, sgdLr);
            }
        }
    }

    private void processBatchHogwild(
            ExecutorService pool,
            String[] texts, int[] batchLabels, int count,
            boolean useAdam, float sgdLr,
            DoubleAdder totalLoss, int threads) {

        List<Future<?>> futures =
                new ArrayList<>(threads);
        for (int t = 0; t < threads; t++) {
            int from = (int) ((long) count * t
                    / threads);
            int to = (int) ((long) count * (t + 1)
                    / threads);
            int tid = t;
            futures.add(pool.submit(() -> {
                FeatureExtractor ext = createExtractor();
                int[] featureBuf = new int[numBuckets];
                float[] logitBuf = new float[numClasses];
                int[] nzBuf = new int[numBuckets];
                double threadLoss = 0;

                if (useAdam) {
                    // Per-thread mini-batch Adam
                    float[][] gradAccumW =
                            new float[numBuckets]
                                    [numClasses];
                    float[] gradAccumB =
                            new float[numClasses];
                    int mbCount = 0;

                    for (int i = from; i < to; i++) {
                        extractInto(ext, texts[i],
                                featureBuf);
                        threadLoss += forwardGrad(
                                featureBuf,
                                batchLabels[i],
                                logitBuf, nzBuf);

                        int nnz = sparseIndex(
                                featureBuf, nzBuf);
                        for (int j = 0; j < nnz; j++) {
                            int b = nzBuf[j];
                            float fv = featureBuf[b];
                            float[] ab = gradAccumW[b];
                            for (int c = 0;
                                 c < numClasses; c++) {
                                ab[c] +=
                                        logitBuf[c] * fv;
                            }
                        }
                        for (int c = 0;
                             c < numClasses; c++) {
                            gradAccumB[c] +=
                                    logitBuf[c];
                        }
                        mbCount++;

                        if (mbCount == miniBatchSize) {
                            applyAdamMiniBatch(
                                    gradAccumW,
                                    gradAccumB,
                                    mbCount, tid);
                            mbCount = 0;
                        }
                    }
                    if (mbCount > 0) {
                        applyAdamMiniBatch(gradAccumW,
                                gradAccumB, mbCount, tid);
                    }
                } else {
                    // Online SGD (Hogwild)
                    for (int i = from; i < to; i++) {
                        extractInto(ext, texts[i],
                                featureBuf);
                        threadLoss += forwardGrad(
                                featureBuf,
                                batchLabels[i],
                                logitBuf, nzBuf);
                        int nnz = sparseIndex(
                                featureBuf, nzBuf);
                        sgdUpdate(featureBuf, logitBuf,
                                nnz, nzBuf, sgdLr);
                    }
                }
                totalLoss.add(threadLoss);
            }));
        }
        for (Future<?> f : futures) {
            try {
                f.get();
            } catch (Exception e) {
                throw new RuntimeException(
                        "Hogwild thread failed", e);
            }
        }
    }

    // ================================================================
    //  Forward pass + gradient (no weight update)
    // ================================================================

    /**
     * Forward pass and gradient computation for one sample.
     * After return, {@code logitBuf} contains the gradient:
     * {@code grad[c] = prob[c] - 1{c == trueClass}}.
     *
     * @return cross-entropy loss for this sample
     */
    private double forwardGrad(int[] features,
                               int trueClass,
                               float[] logitBuf,
                               int[] nzBuf) {
        int nnz = sparseIndex(features, nzBuf);

        for (int c = 0; c < numClasses; c++) {
            logitBuf[c] = getO(biases, c);
        }
        for (int i = 0; i < nnz; i++) {
            int b = nzBuf[i];
            float fv = features[b];
            float[] wb = weights[b];
            for (int c = 0; c < numClasses; c++) {
                logitBuf[c] += getO(wb, c) * fv;
            }
        }

        softmaxInPlace(logitBuf);

        double loss = -Math.log(
                Math.max(logitBuf[trueClass], 1e-10));

        logitBuf[trueClass] -= 1.0f;
        return loss;
    }

    /**
     * Build sparse index of non-zero features.
     *
     * @return count of non-zero entries
     */
    private int sparseIndex(int[] features, int[] nzBuf) {
        int nnz = 0;
        for (int b = 0; b < numBuckets; b++) {
            if (features[b] != 0) {
                nzBuf[nnz++] = b;
            }
        }
        return nnz;
    }

    // ================================================================
    //  Mini-batch Adam update
    // ================================================================

    /**
     * Apply one AdamW update from accumulated mini-batch
     * gradients. Averages the accumulated gradient by
     * {@code mbCount}, then runs a single Adam step.
     * Clears the accumulator buffers after the update.
     *
     * @param gradAccumW accumulated weight gradients
     *                   [bucket][class]
     * @param gradAccumB accumulated bias gradients [class]
     * @param mbCount    number of samples in this
     *                   mini-batch
     * @param threadId   thread ID for per-thread moments,
     *                   or -1 for shared moments
     */
    private void applyAdamMiniBatch(
            float[][] gradAccumW, float[] gradAccumB,
            int mbCount, int threadId) {
        float scale = 1.0f / mbCount;

        long t;
        float bc1, bc2;
        float[][] lMW, lVW;
        float[] lMB, lVB;

        if (threadId >= 0 && perThreadMW != null) {
            t = ++perThreadStep[threadId];
            lMW = perThreadMW[threadId];
            lVW = perThreadVW[threadId];
            lMB = perThreadMBias[threadId];
            lVB = perThreadVBias[threadId];
        } else {
            t = globalStep.incrementAndGet();
            lMW = mW;
            lVW = vW;
            lMB = mBias;
            lVB = vBias;
        }
        bc1 = 1f - (float) Math.pow(adamBeta1, t);
        bc2 = 1f - (float) Math.pow(adamBeta2, t);

        // Weight update ��� only touch buckets that have
        // accumulated gradient (sparse)
        for (int b = 0; b < numBuckets; b++) {
            float[] ab = gradAccumW[b];
            boolean touched = false;
            for (int c = 0; c < numClasses; c++) {
                if (ab[c] != 0f) {
                    touched = true;
                    break;
                }
            }
            if (!touched) {
                continue;
            }

            float[] wb = weights[b];
            float[] mb = lMW[b];
            float[] vb = lVW[b];
            for (int c = 0; c < numClasses; c++) {
                float g = ab[c] * scale;
                float m = adamBeta1 * mb[c]
                        + (1 - adamBeta1) * g;
                float v = adamBeta2 * vb[c]
                        + (1 - adamBeta2) * g * g;
                mb[c] = m;
                vb[c] = v;
                float mHat = m / bc1;
                float vHat = v / bc2;
                float w = getO(wb, c);
                w -= adamLr * mHat
                        / ((float) Math.sqrt(vHat)
                        + adamEpsilon);
                w -= adamLr * l2Lambda * w;
                setO(wb, c, w);
                ab[c] = 0f; // clear as we go
            }
        }

        // Bias update
        for (int c = 0; c < numClasses; c++) {
            float g = gradAccumB[c] * scale;
            float m = adamBeta1 * lMB[c]
                    + (1 - adamBeta1) * g;
            float v = adamBeta2 * lVB[c]
                    + (1 - adamBeta2) * g * g;
            lMB[c] = m;
            lVB[c] = v;
            float mHat = m / bc1;
            float vHat = v / bc2;
            float bi = getO(biases, c);
            bi -= adamLr * mHat
                    / ((float) Math.sqrt(vHat)
                    + adamEpsilon);
            setO(biases, c, bi);
            gradAccumB[c] = 0f; // clear
        }
    }

    /**
     * Online SGD update for one sample. Gradient is
     * {@code grad[c] = prob[c] - 1{c == trueClass}}.
     * Weight decay is coupled (standard L2).
     */
    private void sgdUpdate(int[] features, float[] grad,
                           int nnz, int[] nzIdx,
                           float lr) {
        for (int i = 0; i < nnz; i++) {
            int b = nzIdx[i];
            float fv = features[b];
            float[] wb = weights[b];
            for (int c = 0; c < numClasses; c++) {
                float w = getO(wb, c);
                setO(wb, c, w - lr * (grad[c] * fv
                        + l2Lambda * w));
            }
        }
        for (int c = 0; c < numClasses; c++) {
            float bi = getO(biases, c);
            setO(biases, c, bi - lr * grad[c]);
        }
    }

    // ================================================================
    //  Evaluation
    // ================================================================

    /**
     * Compute macro-averaged F1 on a dataset using current
     * float32 weights.
     */
    /** Result of a macro F1 evaluation. */
    public static class F1Result {
        public final double f1;
        public final int numLangs;

        F1Result(double f1, int numLangs) {
            this.f1 = f1;
            this.numLangs = numLangs;
        }
    }

    public F1Result evaluateMacroF1(
            List<LabeledSentence> data) {
        return evaluateMacroF1(data, null);
    }

    /**
     * Macro F1 evaluation. If {@code perLangOut} is non-null, it is
     * populated with per-language F1 scores keyed by language label.
     */
    public F1Result evaluateMacroF1(
            List<LabeledSentence> data,
            java.util.Map<String, Double> perLangOut) {
        int[][] counts = new int[numClasses][3];
        FeatureExtractor ext = createExtractor();
        int[] featureBuf = new int[numBuckets];
        float[] logitBuf = new float[numClasses];

        for (LabeledSentence s : data) {
            Integer trueIdx =
                    labelIndex.get(s.getLanguage());
            if (trueIdx == null) {
                continue;
            }
            extractInto(ext, s.getText(), featureBuf);
            int predicted =
                    predictFromBuf(featureBuf, logitBuf);
            if (predicted == trueIdx) {
                counts[trueIdx][0]++;
            } else {
                counts[trueIdx][2]++;
                counts[predicted][1]++;
            }
        }

        double f1Sum = 0;
        int n = 0;
        for (int c = 0; c < numClasses; c++) {
            int tp = counts[c][0];
            int fp = counts[c][1];
            int fn = counts[c][2];
            if (tp + fn == 0) {
                continue;
            }
            double p = tp + fp > 0
                    ? (double) tp / (tp + fp) : 0;
            double r = (double) tp / (tp + fn);
            double f1 = p + r > 0
                    ? 2 * p * r / (p + r) : 0;
            f1Sum += f1;
            n++;
            if (perLangOut != null) {
                perLangOut.put(labels[c], f1);
            }
        }
        return new F1Result(
                n > 0 ? f1Sum / n : 0, n);
    }

    private int predictFromBuf(int[] features,
                               float[] logitBuf) {
        for (int c = 0; c < numClasses; c++) {
            logitBuf[c] = getO(biases, c);
        }
        for (int b = 0; b < numBuckets; b++) {
            if (features[b] != 0) {
                float fv = features[b];
                float[] wb = weights[b];
                for (int c = 0; c < numClasses; c++) {
                    logitBuf[c] += getO(wb, c) * fv;
                }
            }
        }
        int best = 0;
        for (int c = 1; c < numClasses; c++) {
            if (logitBuf[c] > logitBuf[best]) {
                best = c;
            }
        }
        return best;
    }

    /**
     * Predict the language label for a text string.
     */
    public String predict(String text) {
        FeatureExtractor ext = createExtractor();
        int[] features = new int[numBuckets];
        extractInto(ext, text, features);
        float[] logits = new float[numClasses];
        return labels[predictFromBuf(features, logits)];
    }

    // ================================================================
    //  Accessors
    // ================================================================

    /**
     * Return weights transposed to class-major
     * {@code [class][bucket]} for {@link ModelQuantizer}.
     */
    public float[][] getWeightsClassMajor() {
        float[][] cm = new float[numClasses][numBuckets];
        for (int b = 0; b < numBuckets; b++) {
            float[] wb = weights[b];
            for (int c = 0; c < numClasses; c++) {
                cm[c][b] = getO(wb, c);
            }
        }
        return cm;
    }

    public float[] getBiases() {
        return biases;
    }

    public String[] getLabels() {
        return labels;
    }

    public int getNumBuckets() {
        return numBuckets;
    }

    public int getNumClasses() {
        return numClasses;
    }

    public FeatureExtractor getExtractor() {
        return createExtractor();
    }

    public Map<String, Integer> getLabelIndex() {
        return labelIndex;
    }

    /**
     * Predict using caller-supplied buffers to avoid per-call allocation.
     * Safe to call from multiple threads provided each thread has its own
     * {@code ext}, {@code featureBuf}, and {@code logitBuf}.
     */
    public String predictBuffered(String text, FeatureExtractor ext,
                                  int[] featureBuf, float[] logitBuf) {
        extractInto(ext, text, featureBuf);
        return labels[predictFromBuf(featureBuf, logitBuf)];
    }

    // ================================================================
    //  Helpers
    // ================================================================

    private Set<String> allowedLangs = null; // null = all langs

    public Phase2Trainer setAllowedLangs(Set<String> langs) {
        this.allowedLangs = langs == null || langs.isEmpty()
                ? null : langs;
        return this;
    }

    private boolean useTrigrams    = false;
    private boolean useSkipBigrams = false;
    private boolean useWordSuffixes = false;
    private boolean useWordSuffix4  = false;
    private boolean useWordPrefix   = false;
    private boolean useWordUnigrams = true;
    private boolean useCharUnigrams = false;
    private boolean use4grams       = false;
    private boolean use5grams       = false;

    public Phase2Trainer setUseTrigrams(boolean v) {
        this.useTrigrams = v;
        return this;
    }

    public Phase2Trainer setUseSkipBigrams(boolean v) {
        this.useSkipBigrams = v;
        return this;
    }

    public Phase2Trainer setUseWordSuffixes(boolean v) {
        this.useWordSuffixes = v;
        return this;
    }

    public Phase2Trainer setUseWordSuffix4(boolean v) {
        this.useWordSuffix4 = v;
        return this;
    }

    public Phase2Trainer setUseWordPrefix(boolean v) {
        this.useWordPrefix = v;
        return this;
    }

    public Phase2Trainer setUseWordUnigrams(boolean v) {
        this.useWordUnigrams = v;
        return this;
    }

    public Phase2Trainer setUseCharUnigrams(boolean v) {
        this.useCharUnigrams = v;
        return this;
    }

    public Phase2Trainer setUse4grams(boolean v) {
        this.use4grams = v;
        return this;
    }

    public Phase2Trainer setUse5grams(boolean v) {
        this.use5grams = v;
        return this;
    }

    public boolean isUseTrigrams()     { return useTrigrams; }
    public boolean isUseSkipBigrams()  { return useSkipBigrams; }
    public boolean isUseWordSuffixes() { return useWordSuffixes; }
    public boolean isUseWordSuffix4()  { return useWordSuffix4; }
    public boolean isUseWordPrefix()   { return useWordPrefix; }
    public boolean isUseWordUnigrams() { return useWordUnigrams; }
    public boolean isUseCharUnigrams() { return useCharUnigrams; }
    public boolean isUse4grams()       { return use4grams; }
    public boolean isUse5grams()       { return use5grams; }

    private FeatureExtractor createExtractor() {
        boolean research = useTrigrams || useSkipBigrams
                || useWordSuffixes || useWordSuffix4
                || useWordPrefix || useCharUnigrams
                || use4grams || use5grams;
        if (research) {
            return new ResearchFeatureExtractor(numBuckets,
                    useTrigrams, useSkipBigrams,
                    useWordSuffixes, useWordSuffix4,
                    useWordPrefix, useWordUnigrams,
                    useCharUnigrams, use4grams, use5grams);
        }
        return new ScriptAwareFeatureExtractor(numBuckets);
    }

    private void extractInto(FeatureExtractor ext,
                             String text, int[] buf) {
        if (preprocessed) {
            ext.extractFromPreprocessed(text, buf, true);
        } else {
            ext.extract(text, buf);
        }
    }

    private List<LabeledSentence> sampleDevSubset(
            List<LabeledSentence> devData, int maxSize) {
        if (devData == null || devData.size() <= maxSize) {
            return devData != null
                    ? devData : Collections.emptyList();
        }
        Map<String, List<LabeledSentence>> byLang =
                new HashMap<>();
        for (LabeledSentence s : devData) {
            byLang.computeIfAbsent(s.getLanguage(),
                    k -> new ArrayList<>()).add(s);
        }
        Random rng = new Random(seed + 7);
        List<LabeledSentence> sample = new ArrayList<>();
        double ratio = (double) maxSize / devData.size();
        for (List<LabeledSentence> langSents :
                byLang.values()) {
            int take = Math.max(1,
                    (int) (langSents.size() * ratio));
            Collections.shuffle(langSents, rng);
            sample.addAll(langSents.subList(
                    0, Math.min(take, langSents.size())));
        }
        return sample;
    }

    private static void softmaxInPlace(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            if (v > max) {
                max = v;
            }
        }
        float sum = 0f;
        for (int i = 0; i < logits.length; i++) {
            logits[i] = (float) Math.exp(logits[i] - max);
            sum += logits[i];
        }
        if (sum > 0f) {
            for (int i = 0; i < logits.length; i++) {
                logits[i] /= sum;
            }
        }
    }

    /**
     * Fisher���Yates shuffle of parallel arrays (texts and
     * labels) in-place, up to {@code len} elements.
     */
    private static void shuffleParallel(
            String[] texts, int[] labels, int len,
            Random rng) {
        for (int i = len - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            String tmpT = texts[i];
            texts[i] = texts[j];
            texts[j] = tmpT;
            int tmpL = labels[i];
            labels[i] = labels[j];
            labels[j] = tmpL;
        }
    }

    private static double elapsed(long startNanos) {
        return (System.nanoTime() - startNanos)
                / 1_000_000_000.0;
    }
}