GenerativeLanguageModel.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * Dense INT8 generative character n-gram model for languageness scoring.
 *
 * <p>Computes an approximate per-n-gram average log P(text | language).
 * Higher scores indicate the decoded text is more consistent with the named
 * language.  The score is used to detect charset errors (mojibake), garbled
 * text, and other corpus-quality issues.
 *
 * <h3>Feature types (v4 model)</h3>
 * <ul>
 *   <li><b>CJK languages</b>: character unigrams and bigrams from CJK/kana
 *       codepoints.</li>
 *   <li><b>Non-CJK languages</b>: character unigrams (per-letter streaming);
 *       character bigrams and trigrams with positional salts (BOW/MID/EOW/
 *       FULL_WORD) applied at the word level; bidirectional word bigrams
 *       (short-anchor forward and backward).</li>
 *   <li><b>Script distribution</b>: normalized per-script letter proportions
 *       using {@link GlmScriptCategory} (34 fine-grained categories, no
 *       OTHER catch-all).</li>
 * </ul>
 *
 * <h3>Positional salting</h3>
 * N-grams use a single salt byte (BOW/MID/EOW/FULL_WORD) as the first byte
 * of the FNV hash rather than sentinel characters.  This means:
 * <ul>
 *   <li>N-grams always contain N real characters ��� no {@code _} padding.</li>
 *   <li>The same bigram at the start vs. middle of a word maps to a different
 *       bucket, encoding positional information without polluting codepoint
 *       space.</li>
 * </ul>
 *
 * <h3>Bidirectional word bigrams</h3>
 * <ul>
 *   <li><b>Forward</b>: fired when the previous word is short (��� {@value
 *       #MAX_SHORT_WORD} chars) ��� captures function-word-in-context like
 *       "the X", "de X", "�� X".</li>
 *   <li><b>Backward</b>: fired when the current word is short ��� captures
 *       "X the", "X de", "X ��" (what precedes a function word).</li>
 * </ul>
 *
 * <h3>Binary format ({@code GLM1})</h3>
 * <pre>
 *   INT  magic    = 0x474C4D31  ("GLM1")
 *   INT  version                (3 = legacy, 4 = current)
 *   INT  numLangs
 *   INT  cjkUnigramBuckets
 *   INT  cjkBigramBuckets
 *   INT  noncjkUnigramBuckets
 *   INT  noncjkBigramBuckets
 *   INT  noncjkTrigramBuckets
 *   INT  scriptCategories
 *   INT  wordBigramBuckets      (v4+ only; 0 absent in v3)
 *   For each language:
 *     SHORT  codeLen
 *     BYTES  langCode (UTF-8)
 *     BYTE   isCjk (0|1)
 *     FLOAT  scoreMean
 *     FLOAT  scoreStdDev
 *     BYTES  unigramTable  [cjkUnigramBuckets | noncjkUnigramBuckets]
 *     BYTES  bigramTable   [cjkBigramBuckets  | noncjkBigramBuckets]
 *     BYTES  trigramTable  [noncjkTrigramBuckets] (absent for CJK)
 *     BYTES  wordBigramTable [wordBigramBuckets] (v4+, absent for CJK)
 *     BYTES  scriptTable   [scriptCategories]
 * </pre>
 */
public class GenerativeLanguageModel {

    // ---- Bucket counts (v4/v5) ----

    public static final int CJK_UNIGRAM_BUCKETS     =  8_192;
    public static final int CJK_BIGRAM_BUCKETS      = 16_384;
    public static final int NONCJK_UNIGRAM_BUCKETS  =  4_096;
    public static final int NONCJK_BIGRAM_BUCKETS   =  8_192;
    public static final int NONCJK_TRIGRAM_BUCKETS  = 16_384;
    public static final int WORD_BIGRAM_BUCKETS      =  8_192;  // v4: new

    /**
     * Script categories used for the script distribution feature.
     * Matches {@link GlmScriptCategory#COUNT} at model-build time; the actual
     * count is stored in the binary so older v3 readers still work.
     */
    public static final int SCRIPT_CATEGORIES = GlmScriptCategory.COUNT;

    /** Default classpath resource for the bundled generative model. */
    public static final String DEFAULT_MODEL_RESOURCE =
            "/org/apache/tika/langdetect/charsoup/langdetect-generative-v4-20260320.bin";

    /**
     * Quantization floor.  Log-probabilities below this are clamped before
     * quantizing; stored values never go lower.
     */
    public static final float LOGP_MIN = -18.0f;

    private static final int MAGIC   = 0x474C4D31;  // "GLM1"
    private static final int VERSION = 4;

    // ---- FNV constants ----

    static final int FNV_BASIS = 0x811c9dc5;

    /** Positional salt bytes ��� same scheme as {@link SaltedNgramFeatureExtractor}. */
    static final int SALT_MID            = 0x00;
    static final int SALT_BOW            = 0x01;
    static final int SALT_EOW            = 0x02;
    static final int SALT_FULL_WORD      = 0x03;
    static final int SALT_CJK_UNIGRAM    = 0x04;
    static final int SALT_NONCJK_UNIGRAM = 0x05;
    static final int SALT_WORD_FWD       = 0x06;  // short-prev ��� any-next
    static final int SALT_WORD_BWD       = 0x07;  // any-prev ��� short-next

    /**
     * Maximum anchor-word length for word bigrams.
     * Words of 1���{@value} characters are treated as anchors.
     */
    static final int MAX_SHORT_WORD = 3;

    // ---- Legacy v3 constants (kept for backward-compatible loading only) ----

    /** @deprecated Used only when loading v3 models. */
    @Deprecated
    static final int BIGRAM_BASIS          = ScriptAwareFeatureExtractor.BIGRAM_BASIS;
    /** @deprecated Used only when loading v3 models. */
    @Deprecated
    static final int CJK_UNIGRAM_BASIS     = ScriptAwareFeatureExtractor.UNIGRAM_BASIS;
    /** @deprecated Used only when loading v3 models. */
    @Deprecated
    static final int NONCJK_UNIGRAM_BASIS  = 0x1a3f7c4e;
    /** @deprecated Used only when loading v3 models. */
    @Deprecated
    static final int TRIGRAM_BASIS         = 0x7e3d9b21;
    /** @deprecated Sentinel used in v3 n-gram extraction. */
    @Deprecated
    static final int SENTINEL              = '_';

    // ---- Model state ----

    private final int             modelVersion;
    private final List<String>    langIds;
    private final Map<String, Integer> langIndex;
    private final boolean[]       isCjk;
    private final byte[][]        unigramTables;
    private final byte[][]        bigramTables;
    private final byte[][]        trigramTables;
    private final byte[][]        wordBigramTables;   // null for v3 models
    private final byte[][]        scriptTables;
    private final int             loadedScriptCats;   // actual count from binary
    private final float[]         scoreMeans;
    private final float[]         scoreStdDevs;

    private GenerativeLanguageModel(
            int          modelVersion,
            List<String> langIds,
            boolean[]    isCjk,
            byte[][]     unigramTables,
            byte[][]     bigramTables,
            byte[][]     trigramTables,
            byte[][]     wordBigramTables,
            byte[][]     scriptTables,
            int          loadedScriptCats,
            float[]      scoreMeans,
            float[]      scoreStdDevs) {
        this.modelVersion      = modelVersion;
        this.langIds           = Collections.unmodifiableList(new ArrayList<>(langIds));
        this.isCjk             = isCjk;
        this.unigramTables     = unigramTables;
        this.bigramTables      = bigramTables;
        this.trigramTables     = trigramTables;
        this.wordBigramTables  = wordBigramTables;
        this.scriptTables      = scriptTables;
        this.loadedScriptCats  = loadedScriptCats;
        this.scoreMeans        = scoreMeans;
        this.scoreStdDevs      = scoreStdDevs;
        Map<String, Integer> idx = new HashMap<>(langIds.size() * 2);
        for (int i = 0; i < langIds.size(); i++) {
            idx.put(langIds.get(i), i);
        }
        this.langIndex = Collections.unmodifiableMap(idx);
    }

    // ---- Public API ----

    public List<String> getLanguages() {
        return langIds;
    }

    public boolean isCjk(String language) {
        Integer i = langIndex.get(language);
        return i != null && isCjk[i];
    }

    /**
     * Per-n-gram average log-probability of {@code text} under {@code language}.
     *
     * @return a value in [{@link #LOGP_MIN}, 0], or {@link Float#NaN} if the
     *         language is unknown or the text yields no scorable n-grams.
     */
    public float score(String text, String language) {
        if (text == null || text.isEmpty()) return Float.NaN;
        Integer li = langIndex.get(language);
        if (li == null) return Float.NaN;
        String pp = CharSoupFeatureExtractor.preprocess(text);
        if (pp.isEmpty()) return Float.NaN;

        double[] sum = {0.0};
        int[]    cnt = {0};

        if (modelVersion >= 4) {
            scoreV4(pp, li, sum, cnt);
        } else {
            scoreV3(pp, li, sum, cnt);
        }

        return cnt[0] == 0 ? Float.NaN : (float) (sum[0] / cnt[0]);
    }

    // ---- Scoring ��� v4 (salted n-grams + word bigrams + fine-grained script) ----

    private void scoreV4(String pp, int li, double[] sum, int[] cnt) {
        if (isCjk[li]) {
            byte[] uniT = unigramTables[li];
            byte[] biT  = bigramTables[li];
            extractCjkFeaturesV4(pp,
                    h -> { sum[0] += dequantize(uniT[h % uniT.length]);
                           cnt[0]++; },
                    h -> { sum[0] += dequantize(biT[h % biT.length]);
                           cnt[0]++; });
        } else {
            byte[] uniT  = unigramTables[li];
            byte[] biT   = bigramTables[li];
            byte[] triT  = trigramTables[li];
            byte[] wbiT  = wordBigramTables != null ? wordBigramTables[li] : null;
            HashConsumer wbiSink = wbiT != null
                    ? h -> { sum[0] += dequantize(wbiT[h % wbiT.length]);
                             cnt[0]++; }
                    : null;
            extractNonCjkFeaturesV4(pp,
                    h -> { sum[0] += dequantize(uniT[h % uniT.length]);
                           cnt[0]++; },
                    h -> { sum[0] += dequantize(biT[h % biT.length]);
                           cnt[0]++; },
                    h -> { sum[0] += dequantize(triT[h % triT.length]);
                           cnt[0]++; },
                    wbiSink);
        }

        if (scriptTables != null && scriptTables[li] != null) {
            addScriptContributionsV4(pp, scriptTables[li], sum, cnt);
        }
    }

    // ---- Scoring ��� v3 (legacy sentinel n-grams) ----

    private void scoreV3(String pp, int li, double[] sum, int[] cnt) {
        if (isCjk[li]) {
            byte[] uniT = unigramTables[li];
            byte[] biT  = bigramTables[li];
            extractCjkNgrams(pp,
                    h -> { sum[0] += dequantize(uniT[h % CJK_UNIGRAM_BUCKETS]);
                           cnt[0]++; },
                    h -> { sum[0] += dequantize(biT[h % 32_768]);
                           cnt[0]++; });
        } else {
            byte[] uniT = unigramTables[li];
            byte[] biT  = bigramTables[li];
            byte[] triT = trigramTables[li];
            extractNonCjkNgrams(pp,
                    h -> { sum[0] += dequantize(uniT[h % uniT.length]);
                           cnt[0]++; },
                    h -> { sum[0] += dequantize(biT[h % biT.length]);
                           cnt[0]++; },
                    h -> { sum[0] += dequantize(triT[h % triT.length]);
                           cnt[0]++; });
        }

        if (scriptTables != null && scriptTables[li] != null) {
            addScriptContributionsV3(pp, scriptTables[li], sum, cnt);
        }
    }

    /**
     * Score {@code text} against all languages and return the best match.
     */
    public Map.Entry<String, Float> bestMatch(String text) {
        String best = null;
        float  bestScore = Float.NEGATIVE_INFINITY;
        for (String lang : langIds) {
            float s = score(text, lang);
            if (!Float.isNaN(s) && s > bestScore) {
                bestScore = s;
                best = lang;
            }
        }
        return best == null ? null : Map.entry(best, bestScore);
    }

    /**
     * Average raw score of {@code text} across all CJK languages in the model.
     */
    public float avgCjkScore(String text) {
        double sum = 0;
        int count = 0;
        for (int i = 0; i < langIds.size(); i++) {
            if (!isCjk[i]) continue;
            float s = score(text, langIds.get(i));
            if (!Float.isNaN(s)) {
                sum += s;
                count++;
            }
        }
        return count == 0 ? Float.NaN : (float) (sum / count);
    }

    // ---- Z-score API ----

    static final int CALIBRATION_CHAR_LENGTH  = 120;
    static final int MIN_ADJUSTED_CHAR_LENGTH = 10;

    public float zScore(String text, String language) {
        Integer li = langIndex.get(language);
        if (li == null || scoreStdDevs[li] <= 0.0f) return Float.NaN;
        float s = score(text, language);
        if (Float.isNaN(s)) return Float.NaN;
        return (s - scoreMeans[li]) / scoreStdDevs[li];
    }

    public float zScoreLengthAdjusted(String text, String language) {
        Integer li = langIndex.get(language);
        if (li == null || scoreStdDevs[li] <= 0.0f) return Float.NaN;
        float s = score(text, language);
        if (Float.isNaN(s)) return Float.NaN;
        int textLen = text.length();
        float adjustment = (float) Math.sqrt(
                (double) CALIBRATION_CHAR_LENGTH
                / Math.max(textLen, MIN_ADJUSTED_CHAR_LENGTH));
        float adjustedSigma = scoreStdDevs[li] * Math.max(1.0f, adjustment);
        return (s - scoreMeans[li]) / adjustedSigma;
    }

    public void setStats(String language, float mean, float stdDev) {
        Integer li = langIndex.get(language);
        if (li == null) throw new IllegalArgumentException("Unknown language: " + language);
        scoreMeans[li]   = mean;
        scoreStdDevs[li] = stdDev;
    }

    // ---- N-gram extraction: v4 (salted, word-buffer) ----

    /**
     * Callback receiving a non-negative FNV hash for a single feature.
     */
    @FunctionalInterface
    public interface HashConsumer {
        void consume(int hash);
    }

    /**
     * Extract CJK character unigrams and bigrams (v4: no script salt).
     */
    public static void extractCjkFeaturesV4(
            String text,
            HashConsumer unigramSink,
            HashConsumer bigramSink) {
        int prevCp = -1;
        int i = 0, len = text.length();
        while (i < len) {
            int cp = text.codePointAt(i);
            i += Character.charCount(cp);
            if (!Character.isLetter(cp)) {
                prevCp = -1;
                continue;
            }
            int lower = Character.toLowerCase(cp);
            if (!ScriptAwareFeatureExtractor.isCjkOrKana(lower)) {
                prevCp = -1;
                continue;
            }
            unigramSink.consume(hashV4(SALT_CJK_UNIGRAM, lower));
            if (prevCp >= 0) {
                bigramSink.consume(hashV4(SALT_MID, prevCp, lower));
            }
            prevCp = lower;
        }
    }

    /**
     * Extract non-CJK features (v4): per-letter unigrams (streaming) plus
     * word-buffer bigrams/trigrams with positional salt, and bidirectional
     * word bigrams.
     *
     * <p>A word is a maximal run of same-script non-CJK letter codepoints.
     * Script is determined by {@link GlmScriptCategory#of(int)}; codepoints
     * returning {@code -1} (unrecognized script) are treated as their own
     * single-character word so they don't pollute adjacent-word bigrams.
     *
     * @param wordBigramSink may be {@code null} to skip word-bigram features
     */
    public static void extractNonCjkFeaturesV4(
            String       text,
            HashConsumer unigramSink,
            HashConsumer bigramSink,
            HashConsumer trigramSink,
            HashConsumer wordBigramSink) {

        int[] word     = new int[256];
        int   wordLen  = 0;
        int   wordScript = -2;  // -2 = no word in progress; -1 = unrecognized script

        int[] prevWord    = new int[256];
        int   prevWordLen = 0;

        int i = 0, len = text.length();
        while (i < len) {
            int cp = text.codePointAt(i);
            i += Character.charCount(cp);

            if (cp >= 0x0300 && CharSoupFeatureExtractor.isTransparent(cp)) continue;

            if (Character.isLetter(cp)) {
                int lower  = Character.toLowerCase(cp);

                // CJK breaks non-CJK word stream
                if (ScriptAwareFeatureExtractor.isCjkOrKana(lower)) {
                    if (wordLen > 0) {
                        prevWordLen = flushWordV4(word, wordLen, bigramSink, trigramSink,
                                wordBigramSink, prevWord, prevWordLen);
                        wordLen = 0;
                        wordScript = -2;
                    }
                    prevWordLen = 0;  // CJK breaks word-bigram chain
                    continue;
                }

                // Unigram: every non-CJK letter
                unigramSink.consume(hashV4(SALT_NONCJK_UNIGRAM, lower));

                int script = GlmScriptCategory.of(lower);

                // Script change (or unrecognized) = word boundary
                boolean sameScript = (wordScript != -2)
                        && (script == wordScript)
                        && (script != -1);  // unrecognized is always its own word
                if (wordLen > 0 && !sameScript) {
                    prevWordLen = flushWordV4(word, wordLen, bigramSink, trigramSink,
                            wordBigramSink, prevWord, prevWordLen);
                    wordLen = 0;
                }

                if (wordLen < word.length) {
                    word[wordLen++] = lower;
                    wordScript = script;
                }
            } else {
                // Non-letter: flush word
                if (wordLen > 0) {
                    prevWordLen = flushWordV4(word, wordLen, bigramSink, trigramSink,
                            wordBigramSink, prevWord, prevWordLen);
                    wordLen = 0;
                    wordScript = -2;
                }
            }
        }

        if (wordLen > 0) {
            flushWordV4(word, wordLen, bigramSink, trigramSink,
                    wordBigramSink, prevWord, prevWordLen);
        }
    }

    /**
     * Emit bigrams/trigrams for a completed word and handle word bigrams.
     *
     * @return the new prevWordLen (= wordLen, since this word becomes the new prev)
     */
    private static int flushWordV4(
            int[] word, int wordLen,
            HashConsumer bigramSink,
            HashConsumer trigramSink,
            HashConsumer wordBigramSink,
            int[] prevWord, int prevWordLen) {

        emitWordNgramsV4(word, wordLen, bigramSink, trigramSink);

        if (wordBigramSink != null) {
            // Forward: short prev ��� any current
            if (prevWordLen >= 1 && prevWordLen <= MAX_SHORT_WORD) {
                emitWordBigram(wordBigramSink, SALT_WORD_FWD,
                        prevWord, prevWordLen, word, wordLen);
            }
            // Backward: any prev ��� short current
            if (wordLen >= 1 && wordLen <= MAX_SHORT_WORD && prevWordLen > 0) {
                emitWordBigram(wordBigramSink, SALT_WORD_BWD,
                        prevWord, prevWordLen, word, wordLen);
            }
        }

        System.arraycopy(word, 0, prevWord, 0, wordLen);
        return wordLen;
    }

    /**
     * Emit positionally-salted bigrams and trigrams for a completed word.
     *
     * <p>For each n-gram order k ��� {2, 3}:
     * <ul>
     *   <li>If wordLen == k: emit once with FULL_WORD salt.</li>
     *   <li>If wordLen > k: first n-gram gets BOW, last gets EOW, rest get MID.</li>
     * </ul>
     */
    static void emitWordNgramsV4(int[] word, int wordLen,
                                  HashConsumer bigramSink,
                                  HashConsumer trigramSink) {
        // Bigrams
        if (wordLen == 2) {
            bigramSink.consume(hashV4(SALT_FULL_WORD, word[0], word[1]));
        } else if (wordLen > 2) {
            bigramSink.consume(hashV4(SALT_BOW, word[0], word[1]));
            for (int j = 1; j < wordLen - 2; j++) {
                bigramSink.consume(hashV4(SALT_MID, word[j], word[j + 1]));
            }
            bigramSink.consume(hashV4(SALT_EOW, word[wordLen - 2], word[wordLen - 1]));
        }

        // Trigrams
        if (wordLen == 3) {
            trigramSink.consume(hashV4(SALT_FULL_WORD, word[0], word[1], word[2]));
        } else if (wordLen > 3) {
            trigramSink.consume(hashV4(SALT_BOW, word[0], word[1], word[2]));
            for (int j = 1; j < wordLen - 3; j++) {
                trigramSink.consume(hashV4(SALT_MID, word[j], word[j + 1], word[j + 2]));
            }
            trigramSink.consume(
                    hashV4(SALT_EOW, word[wordLen - 3], word[wordLen - 2], word[wordLen - 1]));
        }
    }

    /**
     * Emit a word bigram feature for (w1, w2) with the given directional salt.
     * A separator byte (0xFF) prevents collisions between, e.g., "ab"+"cd"
     * and "abc"+"d".
     */
    static void emitWordBigram(HashConsumer sink, int salt,
                                int[] w1, int w1Len,
                                int[] w2, int w2Len) {
        int h = fnvByte(FNV_BASIS, salt);
        for (int j = 0; j < w1Len; j++) h = fnvInt(h, w1[j]);
        h = fnvByte(h, 0xFF);
        for (int j = 0; j < w2Len; j++) h = fnvInt(h, w2[j]);
        sink.consume(h & 0x7FFFFFFF);
    }

    /**
     * Add per-letter script log-probability contributions (v4).
     * Uses {@link GlmScriptCategory}: unrecognized scripts (return value -1)
     * are silently skipped rather than falling into an OTHER bucket.
     */
    static void addScriptContributionsV4(String pp, byte[] scriptTable,
                                          double[] sum, int[] cnt) {
        // Count letters per script category
        int[] scriptCounts = new int[GlmScriptCategory.COUNT];
        int totalLetters = 0;
        int i = 0, len = pp.length();
        while (i < len) {
            int cp = pp.codePointAt(i);
            i += Character.charCount(cp);
            if (!Character.isLetter(cp)) continue;
            int script = GlmScriptCategory.of(Character.toLowerCase(cp));
            if (script >= 0 && script < scriptTable.length) {
                scriptCounts[script]++;
                totalLetters++;
            }
        }
        if (totalLetters == 0) return;
        // L1-normalize: one weighted contribution regardless of text length,
        // so script signal doesn't swamp n-gram signal on long text.
        double scriptScore = 0.0;
        for (int s = 0; s < scriptTable.length; s++) {
            if (scriptCounts[s] > 0) {
                scriptScore += (double) scriptCounts[s] / totalLetters
                        * dequantize(scriptTable[s]);
            }
        }
        sum[0] += scriptScore;
        cnt[0]++;
    }

    // ---- N-gram extraction: v3 (legacy sentinel-based, for old model loading) ----

    /** @deprecated Use v4 extraction for new models. */
    @Deprecated
    public static void extractCjkNgrams(
            String text, HashConsumer unigramSink, HashConsumer bigramSink) {
        int prevCp = -1;
        int i = 0, len = text.length();
        while (i < len) {
            int cp = text.codePointAt(i);
            i += Character.charCount(cp);
            if (!Character.isLetter(cp)) {
                prevCp = -1;
                continue;
            }
            int lower = Character.toLowerCase(cp);
            if (!ScriptAwareFeatureExtractor.isCjkOrKana(lower)) {
                prevCp = -1;
                continue;
            }
            int script = ScriptCategory.of(lower);
            unigramSink.consume(cjkUnigramHashV3(script, lower));
            if (prevCp >= 0) bigramSink.consume(bigramHashV3(script, prevCp, lower));
            prevCp = lower;
        }
    }

    /** @deprecated Use v4 extraction for new models. */
    @Deprecated
    public static void extractNonCjkNgrams(
            String text,
            HashConsumer unigramSink,
            HashConsumer bigramSink,
            HashConsumer trigramSink) {
        int prevPrev = SENTINEL, prev = SENTINEL, prevScript = -1;
        boolean inWord = false;

        int i = 0, len = text.length();
        while (i < len) {
            int cp = text.codePointAt(i);
            i += Character.charCount(cp);
            if (cp >= 0x0300 && CharSoupFeatureExtractor.isTransparent(cp)) continue;

            if (Character.isLetter(cp)) {
                int lower = Character.toLowerCase(cp);
                if (ScriptAwareFeatureExtractor.isCjkOrKana(lower)) {
                    if (inWord) {
                        emitWordEndV3(prevScript, prevPrev, prev, bigramSink, trigramSink);
                        inWord = false;
                        prevPrev = SENTINEL;
                        prev = SENTINEL;
                        prevScript = -1;
                    }
                    continue;
                }
                int script = ScriptCategory.of(lower);
                if (inWord && script != prevScript) {
                    emitWordEndV3(prevScript, prevPrev, prev, bigramSink, trigramSink);
                    inWord = false;
                    prevPrev = SENTINEL;
                    prev = SENTINEL;
                }
                unigramSink.consume(noncjkUnigramHashV3(script, lower));
                if (!inWord) {
                    bigramSink.consume(bigramHashV3(script, SENTINEL, lower));
                    trigramSink.consume(trigramHashV3(script, SENTINEL, SENTINEL, lower));
                    prevPrev = SENTINEL;
                } else {
                    bigramSink.consume(bigramHashV3(script, prev, lower));
                    trigramSink.consume(trigramHashV3(script, prevPrev, prev, lower));
                    prevPrev = prev;
                }
                prev = lower;
                prevScript = script;
                inWord = true;
            } else {
                if (inWord) {
                    emitWordEndV3(prevScript, prevPrev, prev, bigramSink, trigramSink);
                    inWord = false;
                    prevPrev = SENTINEL;
                    prev = SENTINEL;
                    prevScript = -1;
                }
            }
        }
        if (inWord) emitWordEndV3(prevScript, prevPrev, prev, bigramSink, trigramSink);
    }

    private static void emitWordEndV3(int script, int pp, int p,
                                       HashConsumer biSink, HashConsumer triSink) {
        biSink.consume(bigramHashV3(script, p, SENTINEL));
        triSink.consume(trigramHashV3(script, pp, p, SENTINEL));
        triSink.consume(trigramHashV3(script, p, SENTINEL, SENTINEL));
    }

    /** v3 script contributions using {@link ScriptCategory} (includes OTHER). */
    static void addScriptContributionsV3(String pp, byte[] scriptTable,
                                          double[] sum, int[] cnt) {
        int i = 0, len = pp.length();
        while (i < len) {
            int cp = pp.codePointAt(i);
            i += Character.charCount(cp);
            if (!Character.isLetter(cp)) continue;
            int script = ScriptCategory.of(Character.toLowerCase(cp));
            if (script < scriptTable.length) {
                sum[0] += dequantize(scriptTable[script]);
                cnt[0]++;
            }
        }
    }

    // ---- Hash functions ----

    /** FNV-1a hash: salt byte then one codepoint. */
    static int hashV4(int salt, int cp1) {
        int h = fnvByte(FNV_BASIS, salt);
        h = fnvInt(h, cp1);
        return h & 0x7FFFFFFF;
    }

    /** FNV-1a hash: salt byte then two codepoints. */
    static int hashV4(int salt, int cp1, int cp2) {
        int h = fnvByte(FNV_BASIS, salt);
        h = fnvInt(h, cp1);
        h = fnvInt(h, cp2);
        return h & 0x7FFFFFFF;
    }

    /** FNV-1a hash: salt byte then three codepoints. */
    static int hashV4(int salt, int cp1, int cp2, int cp3) {
        int h = fnvByte(FNV_BASIS, salt);
        h = fnvInt(h, cp1);
        h = fnvInt(h, cp2);
        h = fnvInt(h, cp3);
        return h & 0x7FFFFFFF;
    }

    // ---- v3 hash functions (kept for backward-compatible scoring) ----

    @Deprecated static int cjkUnigramHashV3(int script, int cp) {
        int h = CJK_UNIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp);
        return h & 0x7FFFFFFF;
    }
    @Deprecated static int noncjkUnigramHashV3(int script, int cp) {
        int h = NONCJK_UNIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp);
        return h & 0x7FFFFFFF;
    }
    @Deprecated static int bigramHashV3(int script, int cp1, int cp2) {
        int h = BIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp1);
        h = fnvInt(h, cp2);
        return h & 0x7FFFFFFF;
    }
    @Deprecated static int trigramHashV3(int script, int cp1, int cp2, int cp3) {
        int h = TRIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp1);
        h = fnvInt(h, cp2);
        h = fnvInt(h, cp3);
        return h & 0x7FFFFFFF;
    }

    // ---- FNV-1a primitives ----

    static int fnvByte(int h, int b) {
        return (h ^ (b & 0xFF)) * 0x01000193;
    }

    static int fnvInt(int h, int v) {
        h = (h ^ (v         & 0xFF)) * 0x01000193;
        h = (h ^ ((v >>>  8) & 0xFF)) * 0x01000193;
        h = (h ^ ((v >>> 16) & 0xFF)) * 0x01000193;
        h = (h ^ ((v >>> 24) & 0xFF)) * 0x01000193;
        return h;
    }

    // ---- Quantization ----

    static byte quantize(float logP) {
        float clamped = Math.max(LOGP_MIN, Math.min(0.0f, logP));
        return (byte) Math.round((clamped - LOGP_MIN) / (-LOGP_MIN) * 255.0f);
    }

    static float dequantize(byte b) {
        return (b & 0xFF) / 255.0f * (-LOGP_MIN) + LOGP_MIN;
    }

    // ---- Serialization ----

    public static GenerativeLanguageModel loadFromClasspath(String resourcePath)
            throws IOException {
        try (InputStream is = GenerativeLanguageModel.class.getResourceAsStream(resourcePath)) {
            if (is == null) throw new IOException("Classpath resource not found: " + resourcePath);
            return load(is);
        }
    }

    public static GenerativeLanguageModel load(InputStream is) throws IOException {
        DataInputStream din = new DataInputStream(new BufferedInputStream(is));

        int magic = din.readInt();
        if (magic != MAGIC) throw new IOException("Not a GLM1 file (bad magic)");
        int version = din.readInt();
        if (version < 1 || version > VERSION) {
            throw new IOException("Unsupported GLM version: " + version);
        }

        boolean hasStats      = version >= 2;
        boolean hasScript     = version >= 3;
        boolean hasWordBigram = version >= 4;

        int numLangs    = din.readInt();
        int cjkUni      = din.readInt();
        int cjkBi       = din.readInt();
        int noncjkUni   = din.readInt();
        int noncjkBi    = din.readInt();
        int noncjkTri   = din.readInt();
        int scriptCats  = hasScript     ? din.readInt() : 0;
        int wordBiBkts  = hasWordBigram ? din.readInt() : 0;

        List<String> langIds      = new ArrayList<>(numLangs);
        boolean[]    isCjkArr     = new boolean[numLangs];
        byte[][]     unigramTbls  = new byte[numLangs][];
        byte[][]     bigramTbls   = new byte[numLangs][];
        byte[][]     trigramTbls  = new byte[numLangs][];
        byte[][]     wordBiTbls   = hasWordBigram ? new byte[numLangs][] : null;
        byte[][]     scriptTbls   = hasScript     ? new byte[numLangs][] : null;
        float[]      means        = new float[numLangs];
        float[]      stdDevs      = new float[numLangs];

        for (int i = 0; i < numLangs; i++) {
            int    codeLen   = din.readUnsignedShort();
            byte[] codeBytes = new byte[codeLen];
            din.readFully(codeBytes);
            langIds.add(new String(codeBytes, StandardCharsets.UTF_8));

            isCjkArr[i] = din.readByte() != 0;

            if (hasStats) {
                means[i]   = din.readFloat();
                stdDevs[i] = din.readFloat();
            }

            int uniSize = isCjkArr[i] ? cjkUni : noncjkUni;
            int biSize  = isCjkArr[i] ? cjkBi  : noncjkBi;

            unigramTbls[i] = new byte[uniSize];
            din.readFully(unigramTbls[i]);
            bigramTbls[i] = new byte[biSize];
            din.readFully(bigramTbls[i]);

            if (!isCjkArr[i]) {
                trigramTbls[i] = new byte[noncjkTri];
                din.readFully(trigramTbls[i]);
            }

            if (hasWordBigram && !isCjkArr[i] && wordBiBkts > 0) {
                wordBiTbls[i] = new byte[wordBiBkts];
                din.readFully(wordBiTbls[i]);
            }

            if (hasScript) {
                scriptTbls[i] = new byte[scriptCats];
                din.readFully(scriptTbls[i]);
            }
        }

        return new GenerativeLanguageModel(version, langIds, isCjkArr,
                unigramTbls, bigramTbls, trigramTbls, wordBiTbls,
                scriptTbls, scriptCats, means, stdDevs);
    }

    public void save(OutputStream os) throws IOException {
        DataOutputStream dout = new DataOutputStream(new BufferedOutputStream(os));

        // Compute actual bucket sizes from the stored tables so that round-tripping
        // a model loaded from an older binary preserves the original table dimensions.
        int saveCjkUni = CJK_UNIGRAM_BUCKETS, saveCjkBi = CJK_BIGRAM_BUCKETS;
        int saveNoncjkUni = NONCJK_UNIGRAM_BUCKETS, saveNoncjkBi = NONCJK_BIGRAM_BUCKETS;
        int saveNoncjkTri = NONCJK_TRIGRAM_BUCKETS;
        int saveWordBi  = WORD_BIGRAM_BUCKETS;
        int saveScript  = SCRIPT_CATEGORIES;
        boolean foundCjk = false, foundNoncjk = false;
        for (int i = 0; i < langIds.size(); i++) {
            if (!foundCjk && isCjk[i] && unigramTables[i] != null) {
                saveCjkUni = unigramTables[i].length;
                saveCjkBi  = bigramTables[i].length;
                foundCjk = true;
            }
            if (!foundNoncjk && !isCjk[i] && unigramTables[i] != null) {
                saveNoncjkUni = unigramTables[i].length;
                saveNoncjkBi  = bigramTables[i].length;
                if (trigramTables[i] != null) saveNoncjkTri = trigramTables[i].length;
                if (wordBigramTables != null && wordBigramTables[i] != null)
                    saveWordBi = wordBigramTables[i].length;
                foundNoncjk = true;
            }
            if (scriptTables != null && scriptTables[i] != null && (i == 0 || saveScript == SCRIPT_CATEGORIES))
                saveScript = scriptTables[i].length;
            if (foundCjk && foundNoncjk) break;
        }

        dout.writeInt(MAGIC);
        dout.writeInt(VERSION);
        dout.writeInt(langIds.size());
        dout.writeInt(saveCjkUni);
        dout.writeInt(saveCjkBi);
        dout.writeInt(saveNoncjkUni);
        dout.writeInt(saveNoncjkBi);
        dout.writeInt(saveNoncjkTri);
        dout.writeInt(saveScript);
        dout.writeInt(saveWordBi);

        for (int i = 0; i < langIds.size(); i++) {
            byte[] codeBytes = langIds.get(i).getBytes(StandardCharsets.UTF_8);
            dout.writeShort(codeBytes.length);
            dout.write(codeBytes);
            dout.writeByte(isCjk[i] ? 1 : 0);
            dout.writeFloat(scoreMeans[i]);
            dout.writeFloat(scoreStdDevs[i]);
            dout.write(unigramTables[i]);
            dout.write(bigramTables[i]);
            if (!isCjk[i]) {
                dout.write(trigramTables[i]);
                // Word bigrams: write table or zeros if absent
                if (wordBigramTables != null && wordBigramTables[i] != null) {
                    dout.write(wordBigramTables[i]);
                } else {
                    dout.write(new byte[saveWordBi]);
                }
            }
            if (scriptTables != null && scriptTables[i] != null) {
                dout.write(scriptTables[i]);
            } else {
                dout.write(new byte[saveScript]);
            }
        }
        dout.flush();
    }

    // ---- Builder ----

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Accumulates training samples per language and produces a
     * {@link GenerativeLanguageModel} via add-k smoothing.
     */
    public static class Builder {

        private final Map<String, Boolean> cjkFlags         = new LinkedHashMap<>();
        private final Map<String, long[]>  unigramCounts    = new HashMap<>();
        private final Map<String, long[]>  bigramCounts     = new HashMap<>();
        private final Map<String, long[]>  trigramCounts    = new HashMap<>();
        private final Map<String, long[]>  wordBigramCounts = new HashMap<>();
        private final Map<String, long[]>  scriptCounts     = new HashMap<>();

        public Builder registerLanguage(String langCode, boolean isCjk) {
            cjkFlags.put(langCode, isCjk);
            unigramCounts.put(langCode,
                    new long[isCjk ? CJK_UNIGRAM_BUCKETS : NONCJK_UNIGRAM_BUCKETS]);
            bigramCounts.put(langCode,
                    new long[isCjk ? CJK_BIGRAM_BUCKETS  : NONCJK_BIGRAM_BUCKETS]);
            if (!isCjk) {
                trigramCounts.put(langCode, new long[NONCJK_TRIGRAM_BUCKETS]);
                wordBigramCounts.put(langCode, new long[WORD_BIGRAM_BUCKETS]);
            }
            scriptCounts.put(langCode, new long[SCRIPT_CATEGORIES]);
            return this;
        }

        public Builder addSample(String langCode, String text) {
            Boolean cjk = cjkFlags.get(langCode);
            if (cjk == null) throw new IllegalArgumentException("Unknown language: " + langCode);
            String pp = CharSoupFeatureExtractor.preprocess(text);
            if (pp.isEmpty()) return this;

            long[] ug = unigramCounts.get(langCode);
            long[] bg = bigramCounts.get(langCode);

            if (cjk) {
                extractCjkFeaturesV4(pp,
                        h -> ug[h % CJK_UNIGRAM_BUCKETS]++,
                        h -> bg[h % CJK_BIGRAM_BUCKETS]++);
            } else {
                long[] tg  = trigramCounts.get(langCode);
                long[] wbg = wordBigramCounts.get(langCode);
                extractNonCjkFeaturesV4(pp,
                        h -> ug[h  % NONCJK_UNIGRAM_BUCKETS]++,
                        h -> bg[h  % NONCJK_BIGRAM_BUCKETS]++,
                        h -> tg[h  % NONCJK_TRIGRAM_BUCKETS]++,
                        h -> wbg[h % WORD_BIGRAM_BUCKETS]++);
            }

            accumulateScriptCounts(pp, scriptCounts.get(langCode));
            return this;
        }

        public GenerativeLanguageModel build(float addK) {
            List<String> ids = new ArrayList<>(cjkFlags.keySet());
            int n = ids.size();

            boolean[] cjkArr     = new boolean[n];
            byte[][]  uniTables  = new byte[n][];
            byte[][]  biTables   = new byte[n][];
            byte[][]  triTables  = new byte[n][];
            byte[][]  wbiTables  = new byte[n][];
            byte[][]  scriptTbls = new byte[n][];

            for (int i = 0; i < n; i++) {
                String lang = ids.get(i);
                cjkArr[i]     = cjkFlags.get(lang);
                uniTables[i]  = toLogProbTable(unigramCounts.get(lang),   addK);
                biTables[i]   = toLogProbTable(bigramCounts.get(lang),    addK);
                if (!cjkArr[i]) {
                    triTables[i] = toLogProbTable(trigramCounts.get(lang),    addK);
                    wbiTables[i] = toLogProbTable(wordBigramCounts.get(lang), addK);
                }
                scriptTbls[i] = toLogProbTable(scriptCounts.get(lang), addK);
            }
            return new GenerativeLanguageModel(VERSION, ids, cjkArr,
                    uniTables, biTables, triTables, wbiTables,
                    scriptTbls, SCRIPT_CATEGORIES,
                    new float[n], new float[n]);
        }

        private static byte[] toLogProbTable(long[] counts, float addK) {
            long total = 0;
            for (long c : counts) total += c;
            double denom = total + (double) addK * counts.length;
            byte[] table = new byte[counts.length];
            for (int i = 0; i < counts.length; i++) {
                double p = (counts[i] + addK) / denom;
                table[i] = quantize((float) Math.log(p));
            }
            return table;
        }

        private static void accumulateScriptCounts(String pp, long[] dest) {
            int i = 0, len = pp.length();
            while (i < len) {
                int cp = pp.codePointAt(i);
                i += Character.charCount(cp);
                if (!Character.isLetter(cp)) continue;
                int script = GlmScriptCategory.of(Character.toLowerCase(cp));
                if (script >= 0 && script < dest.length) {
                    dest[script]++;
                }
                // script == -1 (unrecognized): silently skipped ��� no OTHER bucket
            }
        }

    }
}