// GenerativeLanguageModel.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
* Dense INT8 generative character n-gram model for languageness scoring.
*
* <p>Computes an approximate per-n-gram average log P(text | language).
* Higher scores indicate the decoded text is more consistent with the named
* language. The score is used to arbitrate between candidate charsets when
* statistical decoders disagree on script or language.
*
* <h3>Feature types</h3>
* <ul>
* <li><b>CJK languages</b> (Han, Hiragana, Katakana): character unigrams
* and bigrams extracted from CJK/kana codepoints.</li>
* <li><b>Non-CJK languages</b>: character unigrams, bigrams (with
* word-boundary sentinels), and trigrams (with sentinels).</li>
* </ul>
*
* <p>Log-probabilities are quantized to unsigned INT8 over the range
* [{@link #LOGP_MIN}, 0] and stored in dense byte arrays.
*
* <h3>Binary format ({@code GLM1} v2)</h3>
* <pre>
* INT magic = 0x474C4D31
* INT version = 2
* INT numLangs
* INT cjkUnigramBuckets
* INT cjkBigramBuckets
* INT noncjkUnigramBuckets
* INT noncjkBigramBuckets
* INT noncjkTrigramBuckets
* For each language:
* SHORT codeLen
* BYTES langCode (UTF-8)
* BYTE isCjk (0|1)
* FLOAT scoreMean (mean of score distribution on training data)
* FLOAT scoreStdDev (standard deviation of score distribution on training data)
* BYTES unigramTable [cjkUnigramBuckets | noncjkUnigramBuckets]
* BYTES bigramTable [cjkBigramBuckets | noncjkBigramBuckets]
* BYTES trigramTable [noncjkTrigramBuckets] (absent for CJK)
* </pre>
*/
public class GenerativeLanguageModel {

    // ---- Bucket counts ----
    public static final int CJK_UNIGRAM_BUCKETS = 8_192;
    public static final int CJK_BIGRAM_BUCKETS = 32_768;
    public static final int NONCJK_UNIGRAM_BUCKETS = 8_192;
    public static final int NONCJK_BIGRAM_BUCKETS = 8_192;
    public static final int NONCJK_TRIGRAM_BUCKETS = 16_384;

    /** Default classpath resource path for the bundled generative model. */
    public static final String DEFAULT_MODEL_RESOURCE =
            "/org/apache/tika/langdetect/charsoup/langdetect-generative-v1-20260310.bin";

    /**
     * Quantization floor. Log-probabilities below this value are clamped
     * before quantizing; values stored in the table never go lower.
     */
    public static final float LOGP_MIN = -18.0f;

    private static final int MAGIC = 0x474C4D31; // "GLM1"
    private static final int VERSION = 2;

    // ---- FNV-1a basis constants ----
    /**
     * Bigram basis shared with {@link ScriptAwareFeatureExtractor} so that
     * identical text produces the same bucket indices for both models.
     */
    static final int BIGRAM_BASIS = ScriptAwareFeatureExtractor.BIGRAM_BASIS;

    /**
     * CJK unigram basis shared with {@link ScriptAwareFeatureExtractor}.
     */
    static final int CJK_UNIGRAM_BASIS = ScriptAwareFeatureExtractor.UNIGRAM_BASIS;

    /** Distinct salt for non-CJK character unigrams (not in discriminative model). */
    static final int NONCJK_UNIGRAM_BASIS = 0x1a3f7c4e;

    /** Distinct salt for character trigrams (not in discriminative model). */
    static final int TRIGRAM_BASIS = 0x7e3d9b21;

    /** Word-boundary sentinel codepoint, matching the discriminative model. */
    static final int SENTINEL = '_';

    // ---- Model state ----
    private final List<String> langIds;
    private final Map<String, Integer> langIndex;
    private final boolean[] isCjk;
    private final byte[][] unigramTables;  // [langIdx][bucket]
    private final byte[][] bigramTables;   // [langIdx][bucket]
    private final byte[][] trigramTables;  // [langIdx][bucket]; null entry for CJK langs
    private final float[] scoreMeans;      // mean score per language (from training data)
    private final float[] scoreStdDevs;    // score std dev per language (from training data)

    private GenerativeLanguageModel(
            List<String> langIds,
            boolean[] isCjk,
            byte[][] unigramTables,
            byte[][] bigramTables,
            byte[][] trigramTables,
            float[] scoreMeans,
            float[] scoreStdDevs) {
        this.langIds = Collections.unmodifiableList(new ArrayList<>(langIds));
        this.isCjk = isCjk;
        this.unigramTables = unigramTables;
        this.bigramTables = bigramTables;
        this.trigramTables = trigramTables;
        this.scoreMeans = scoreMeans;
        this.scoreStdDevs = scoreStdDevs;
        Map<String, Integer> idx = new HashMap<>(langIds.size() * 2);
        for (int i = 0; i < langIds.size(); i++) {
            idx.put(langIds.get(i), i);
        }
        this.langIndex = Collections.unmodifiableMap(idx);
    }

    // ---- Public API ----

    public List<String> getLanguages() {
        return langIds;
    }

    public boolean isCjk(String language) {
        Integer i = langIndex.get(language);
        return i != null && isCjk[i];
    }

    /**
     * Per-n-gram average log-probability of {@code text} under {@code language}.
     *
     * @return a value in [{@link #LOGP_MIN}, 0], or {@link Float#NaN} if the
     *         language is unknown or the text yields no scorable n-grams.
     */
    public float score(String text, String language) {
        if (text == null || text.isEmpty()) {
            return Float.NaN;
        }
        Integer li = langIndex.get(language);
        if (li == null) {
            return Float.NaN;
        }
        String preprocessed = CharSoupFeatureExtractor.preprocess(text);
        if (preprocessed.isEmpty()) {
            return Float.NaN;
        }
        // Single-element arrays let the lambdas below accumulate into
        // effectively-final locals.
        double[] sum = {0.0};
        int[] cnt = {0};
        if (isCjk[li]) {
            byte[] uniT = unigramTables[li];
            byte[] biT = bigramTables[li];
            extractCjkNgrams(preprocessed,
                    h -> {
                        sum[0] += dequantize(uniT[h % CJK_UNIGRAM_BUCKETS]);
                        cnt[0]++;
                    },
                    h -> {
                        sum[0] += dequantize(biT[h % CJK_BIGRAM_BUCKETS]);
                        cnt[0]++;
                    });
        } else {
            byte[] uniT = unigramTables[li];
            byte[] biT = bigramTables[li];
            byte[] triT = trigramTables[li];
            extractNonCjkNgrams(preprocessed,
                    h -> {
                        sum[0] += dequantize(uniT[h % NONCJK_UNIGRAM_BUCKETS]);
                        cnt[0]++;
                    },
                    h -> {
                        sum[0] += dequantize(biT[h % NONCJK_BIGRAM_BUCKETS]);
                        cnt[0]++;
                    },
                    h -> {
                        sum[0] += dequantize(triT[h % NONCJK_TRIGRAM_BUCKETS]);
                        cnt[0]++;
                    });
        }
        return cnt[0] == 0 ? Float.NaN : (float) (sum[0] / cnt[0]);
    }

    /**
     * Score {@code text} against all languages and return the best match.
     *
     * @return an entry {@code (languageCode, score)}, or {@code null} if no
     *         language yields a finite score.
     */
    public Map.Entry<String, Float> bestMatch(String text) {
        String best = null;
        float bestScore = Float.NEGATIVE_INFINITY;
        for (String lang : langIds) {
            float s = score(text, lang);
            if (!Float.isNaN(s) && s > bestScore) {
                bestScore = s;
                best = lang;
            }
        }
        return best == null ? null : Map.entry(best, bestScore);
    }

    /**
     * Z-score of {@code text} under {@code language}:
     * {@code (score(text, language) - mean) / stdDev}, where the mean and
     * standard deviation were computed from the language's training corpus.
     *
     * <p>Appropriate when the input text is roughly the same length as
     * training sentences. For short or variable-length text, prefer
     * {@link #zScoreLengthAdjusted}.
     *
     * @return the z-score, or {@link Float#NaN} if the language is unknown,
     *         the text yields no scorable n-grams, or the standard deviation
     *         is zero/uncalibrated.
     */
    public float zScore(String text, String language) {
        Integer li = langIndex.get(language);
        if (li == null || scoreStdDevs[li] <= 0.0f) {
            return Float.NaN;
        }
        float s = score(text, language);
        if (Float.isNaN(s)) {
            return Float.NaN;
        }
        return (s - scoreMeans[li]) / scoreStdDevs[li];
    }

    /**
     * Approximate character length of a typical training sentence.
     * Used by {@link #zScoreLengthAdjusted} to inflate the standard deviation
     * for short text. Empirically derived from the calibration data: score
     * standard deviation scales as roughly 1/sqrt(charLen) and stabilises
     * around this length.
     */
    static final int CALIBRATION_CHAR_LENGTH = 120;

    /** Floor on text length to avoid extreme standard-deviation inflation. */
    static final int MIN_ADJUSTED_CHAR_LENGTH = 10;

    /**
     * Length-adjusted z-score of {@code text} under {@code language}.
     *
     * <p>The score's standard deviation scales as approximately
     * 1/sqrt(textLength). The stored standard deviation was calibrated on
     * full training sentences (typically
     * ~{@value #CALIBRATION_CHAR_LENGTH} characters). For shorter text
     * this method inflates it proportionally, preventing spurious low
     * z-scores on short snippets. For text at or above the calibration
     * length, the result equals {@link #zScore}.
     *
     * @return the adjusted z-score, or {@link Float#NaN} if the language
     *         is unknown, the text yields no scorable n-grams, or the
     *         standard deviation is zero/uncalibrated.
     */
    public float zScoreLengthAdjusted(String text, String language) {
        Integer li = langIndex.get(language);
        if (li == null || scoreStdDevs[li] <= 0.0f) {
            return Float.NaN;
        }
        float s = score(text, language);
        if (Float.isNaN(s)) {
            return Float.NaN;
        }
        int textLen = text.length();
        float adjustment = (float) Math.sqrt(
                (double) CALIBRATION_CHAR_LENGTH
                        / Math.max(textLen, MIN_ADJUSTED_CHAR_LENGTH));
        // Never deflate sigma for long text; only inflate for short text.
        float adjustedSigma = scoreStdDevs[li] * Math.max(1.0f, adjustment);
        return (s - scoreMeans[li]) / adjustedSigma;
    }

    /**
     * Set the calibration statistics for a language. Typically called by
     * the training tool after a second pass over the training corpus.
     */
    public void setStats(String language, float mean, float stdDev) {
        Integer li = langIndex.get(language);
        if (li == null) {
            throw new IllegalArgumentException("Unknown language: " + language);
        }
        scoreMeans[li] = mean;
        scoreStdDevs[li] = stdDev;
    }

    // ---- N-gram extraction (shared by scoring and training) ----

    /**
     * Callback receiving a non-negative raw FNV hash for a single n-gram.
     * The caller is responsible for reducing it modulo a table size.
     */
    @FunctionalInterface
    public interface HashConsumer {
        void consume(int hash);
    }

    /**
     * Extract CJK character unigrams and bigrams from preprocessed text,
     * delivering raw (positive) hashes to the supplied sinks.
     */
    public static void extractCjkNgrams(
            String text,
            HashConsumer unigramSink,
            HashConsumer bigramSink) {
        int prevCp = -1;
        int i = 0;
        int len = text.length();
        while (i < len) {
            int cp = text.codePointAt(i);
            i += Character.charCount(cp);
            if (!Character.isLetter(cp)) {
                prevCp = -1;  // non-letter breaks the bigram chain
                continue;
            }
            int lower = Character.toLowerCase(cp);
            if (!ScriptAwareFeatureExtractor.isCjkOrKana(lower)) {
                prevCp = -1;  // non-CJK letter also breaks the chain
                continue;
            }
            int script = ScriptCategory.of(lower);
            unigramSink.consume(cjkUnigramHash(script, lower));
            if (prevCp >= 0) {
                bigramSink.consume(bigramHash(script, prevCp, lower));
            }
            prevCp = lower;
        }
    }

    /**
     * Extract non-CJK character unigrams, sentinel-padded bigrams, and
     * sentinel-padded trigrams from preprocessed text.
     *
     * <p>A "word" is a maximal run of non-CJK letter codepoints within the
     * same script family. Sentinels ({@link #SENTINEL}) pad each word on
     * both sides, so a word of length L yields L+1 bigrams and L+2 trigrams.
     */
    public static void extractNonCjkNgrams(
            String text,
            HashConsumer unigramSink,
            HashConsumer bigramSink,
            HashConsumer trigramSink) {
        int prevPrev = SENTINEL;
        int prev = SENTINEL;
        int prevScript = -1;
        boolean inWord = false;
        int i = 0;
        int len = text.length();
        while (i < len) {
            int cp = text.codePointAt(i);
            i += Character.charCount(cp);
            // Skip combining marks and similar "transparent" codepoints
            // without breaking the current word.
            if (cp >= 0x0300 && CharSoupFeatureExtractor.isTransparent(cp)) {
                continue;
            }
            if (Character.isLetter(cp)) {
                int lower = Character.toLowerCase(cp);
                if (ScriptAwareFeatureExtractor.isCjkOrKana(lower)) {
                    // CJK letters belong to the other extractor; treat as a
                    // hard word boundary here.
                    if (inWord) {
                        emitWordEnd(prevScript, prevPrev, prev, bigramSink, trigramSink);
                        inWord = false;
                        prevPrev = SENTINEL;
                        prev = SENTINEL;
                        prevScript = -1;
                    }
                    continue;
                }
                int script = ScriptCategory.of(lower);
                if (inWord && script != prevScript) {
                    // Script change is a word boundary
                    emitWordEnd(prevScript, prevPrev, prev, bigramSink, trigramSink);
                    inWord = false;
                    prevPrev = SENTINEL;
                    prev = SENTINEL;
                }
                unigramSink.consume(noncjkUnigramHash(script, lower));
                if (!inWord) {
                    // Leading sentinels
                    bigramSink.consume(bigramHash(script, SENTINEL, lower));
                    trigramSink.consume(trigramHash(script, SENTINEL, SENTINEL, lower));
                    prevPrev = SENTINEL;
                } else {
                    bigramSink.consume(bigramHash(script, prev, lower));
                    trigramSink.consume(trigramHash(script, prevPrev, prev, lower));
                    prevPrev = prev;
                }
                prev = lower;
                prevScript = script;
                inWord = true;
            } else {
                if (inWord) {
                    emitWordEnd(prevScript, prevPrev, prev, bigramSink, trigramSink);
                    inWord = false;
                    prevPrev = SENTINEL;
                    prev = SENTINEL;
                    prevScript = -1;
                }
            }
        }
        if (inWord) {
            emitWordEnd(prevScript, prevPrev, prev, bigramSink, trigramSink);
        }
    }

    /** Emit the trailing-sentinel bigram and the two trailing-sentinel trigrams. */
    private static void emitWordEnd(
            int script, int pp, int p,
            HashConsumer bigramSink, HashConsumer trigramSink) {
        bigramSink.consume(bigramHash(script, p, SENTINEL));
        trigramSink.consume(trigramHash(script, pp, p, SENTINEL));
        trigramSink.consume(trigramHash(script, p, SENTINEL, SENTINEL));
    }

    // ---- Hash functions (FNV-1a) ----

    static int cjkUnigramHash(int script, int cp) {
        int h = CJK_UNIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp);
        return h & 0x7FFFFFFF;
    }

    static int noncjkUnigramHash(int script, int cp) {
        int h = NONCJK_UNIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp);
        return h & 0x7FFFFFFF;
    }

    static int bigramHash(int script, int cp1, int cp2) {
        int h = BIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp1);
        h = fnvInt(h, cp2);
        return h & 0x7FFFFFFF;
    }

    static int trigramHash(int script, int cp1, int cp2, int cp3) {
        int h = TRIGRAM_BASIS;
        h = fnvByte(h, script);
        h = fnvInt(h, cp1);
        h = fnvInt(h, cp2);
        h = fnvInt(h, cp3);
        return h & 0x7FFFFFFF;
    }

    private static int fnvByte(int h, int b) {
        return (h ^ (b & 0xFF)) * 0x01000193;
    }

    private static int fnvInt(int h, int v) {
        h = (h ^ (v & 0xFF)) * 0x01000193;
        h = (h ^ ((v >>> 8) & 0xFF)) * 0x01000193;
        h = (h ^ ((v >>> 16) & 0xFF)) * 0x01000193;
        h = (h ^ ((v >>> 24) & 0xFF)) * 0x01000193;
        return h;
    }

    // ---- Quantization ----

    /**
     * Quantize a log-probability in [{@link #LOGP_MIN}, 0] to an unsigned byte
     * value: 0 maps to {@code LOGP_MIN}, 255 maps to 0.
     */
    static byte quantize(float logP) {
        float clamped = Math.max(LOGP_MIN, Math.min(0.0f, logP));
        return (byte) Math.round((clamped - LOGP_MIN) / (-LOGP_MIN) * 255.0f);
    }

    /** Inverse of {@link #quantize}. */
    static float dequantize(byte b) {
        return (b & 0xFF) / 255.0f * (-LOGP_MIN) + LOGP_MIN;
    }

    // ---- Serialization ----

    /**
     * Load a model from a classpath resource.
     *
     * @param resourcePath absolute classpath path, e.g.
     *        {@code "/org/apache/tika/langdetect/charsoup/langdetect-generative-v1-20260310.bin"}
     * @return the loaded model
     * @throws IOException if the resource is missing or malformed
     */
    public static GenerativeLanguageModel loadFromClasspath(String resourcePath)
            throws IOException {
        try (InputStream is = GenerativeLanguageModel.class.getResourceAsStream(resourcePath)) {
            if (is == null) {
                throw new IOException("Classpath resource not found: " + resourcePath);
            }
            return load(is);
        }
    }

    /**
     * Deserialize a model from the GLM1 binary format.
     *
     * @throws IOException if the stream is not a GLM1 file, is truncated or
     *         corrupt, or was written with bucket counts different from the
     *         compile-time constants of this build
     */
    public static GenerativeLanguageModel load(InputStream is) throws IOException {
        DataInputStream din = new DataInputStream(new BufferedInputStream(is));
        int magic = din.readInt();
        if (magic != MAGIC) {
            throw new IOException("Not a GLM1 file (bad magic)");
        }
        int version = din.readInt();
        if (version != 1 && version != VERSION) {
            throw new IOException("Unsupported GLM version: " + version);
        }
        boolean hasStats = version >= 2;  // v1 files predate calibration stats
        int numLangs = din.readInt();
        if (numLangs < 0) {
            throw new IOException("Corrupt GLM file: negative language count: " + numLangs);
        }
        int cjkUni = din.readInt();
        int cjkBi = din.readInt();
        int noncjkUni = din.readInt();
        int noncjkBi = din.readInt();
        int noncjkTri = din.readInt();
        // Scoring reduces hashes modulo the compile-time bucket constants, so
        // a file with different table sizes would silently hit wrong buckets
        // (or index out of bounds). Fail fast with a clear error instead.
        if (cjkUni != CJK_UNIGRAM_BUCKETS
                || cjkBi != CJK_BIGRAM_BUCKETS
                || noncjkUni != NONCJK_UNIGRAM_BUCKETS
                || noncjkBi != NONCJK_BIGRAM_BUCKETS
                || noncjkTri != NONCJK_TRIGRAM_BUCKETS) {
            throw new IOException("GLM bucket counts do not match this build: ["
                    + cjkUni + ", " + cjkBi + ", " + noncjkUni + ", "
                    + noncjkBi + ", " + noncjkTri + "]");
        }
        List<String> langIds = new ArrayList<>(numLangs);
        boolean[] isCjk = new boolean[numLangs];
        byte[][] unigramTables = new byte[numLangs][];
        byte[][] bigramTables = new byte[numLangs][];
        byte[][] trigramTables = new byte[numLangs][];
        float[] means = new float[numLangs];
        float[] stdDevs = new float[numLangs];
        for (int i = 0; i < numLangs; i++) {
            int codeLen = din.readUnsignedShort();
            byte[] codeBytes = new byte[codeLen];
            din.readFully(codeBytes);
            langIds.add(new String(codeBytes, StandardCharsets.UTF_8));
            isCjk[i] = din.readByte() != 0;
            if (hasStats) {
                means[i] = din.readFloat();
                stdDevs[i] = din.readFloat();
            }
            int uniSize = isCjk[i] ? cjkUni : noncjkUni;
            int biSize = isCjk[i] ? cjkBi : noncjkBi;
            unigramTables[i] = new byte[uniSize];
            din.readFully(unigramTables[i]);
            bigramTables[i] = new byte[biSize];
            din.readFully(bigramTables[i]);
            if (!isCjk[i]) {
                trigramTables[i] = new byte[noncjkTri];
                din.readFully(trigramTables[i]);
            }
        }
        return new GenerativeLanguageModel(langIds, isCjk,
                unigramTables, bigramTables, trigramTables,
                means, stdDevs);
    }

    /**
     * Serialize this model to the GLM1 binary format (always written at
     * {@link #VERSION}, including per-language calibration stats).
     */
    public void save(OutputStream os) throws IOException {
        DataOutputStream dout = new DataOutputStream(new BufferedOutputStream(os));
        dout.writeInt(MAGIC);
        dout.writeInt(VERSION);
        dout.writeInt(langIds.size());
        dout.writeInt(CJK_UNIGRAM_BUCKETS);
        dout.writeInt(CJK_BIGRAM_BUCKETS);
        dout.writeInt(NONCJK_UNIGRAM_BUCKETS);
        dout.writeInt(NONCJK_BIGRAM_BUCKETS);
        dout.writeInt(NONCJK_TRIGRAM_BUCKETS);
        for (int i = 0; i < langIds.size(); i++) {
            byte[] codeBytes = langIds.get(i).getBytes(StandardCharsets.UTF_8);
            // codeLen is stored as an unsigned short; guard against silent
            // truncation of a pathological language code.
            if (codeBytes.length > 0xFFFF) {
                throw new IOException("Language code too long: " + langIds.get(i));
            }
            dout.writeShort(codeBytes.length);
            dout.write(codeBytes);
            dout.writeByte(isCjk[i] ? 1 : 0);
            dout.writeFloat(scoreMeans[i]);
            dout.writeFloat(scoreStdDevs[i]);
            dout.write(unigramTables[i]);
            dout.write(bigramTables[i]);
            if (!isCjk[i]) {
                dout.write(trigramTables[i]);
            }
        }
        dout.flush();
    }

    // ---- Builder ----

    public static Builder builder() {
        return new Builder();
    }

    /**
     * Accumulates training samples per language and produces a
     * {@link GenerativeLanguageModel} via add-k smoothing.
     */
    public static class Builder {
        // LinkedHashMap preserves registration order for the language index
        private final Map<String, Boolean> cjkFlags = new LinkedHashMap<>();
        private final Map<String, long[]> unigramCounts = new HashMap<>();
        private final Map<String, long[]> bigramCounts = new HashMap<>();
        private final Map<String, long[]> trigramCounts = new HashMap<>();

        /**
         * Register a language before feeding it samples. Must be called
         * before {@link #addSample(String, String)}.
         */
        public Builder registerLanguage(String langCode, boolean isCjk) {
            cjkFlags.put(langCode, isCjk);
            unigramCounts.put(langCode,
                    new long[isCjk ? CJK_UNIGRAM_BUCKETS : NONCJK_UNIGRAM_BUCKETS]);
            bigramCounts.put(langCode,
                    new long[isCjk ? CJK_BIGRAM_BUCKETS : NONCJK_BIGRAM_BUCKETS]);
            if (!isCjk) {
                trigramCounts.put(langCode, new long[NONCJK_TRIGRAM_BUCKETS]);
            }
            return this;
        }

        /**
         * Add a text sample for the named language. The language must have
         * been registered via {@link #registerLanguage} first.
         *
         * @throws IllegalArgumentException if the language is unregistered
         */
        public Builder addSample(String langCode, String text) {
            Boolean cjk = cjkFlags.get(langCode);
            if (cjk == null) {
                throw new IllegalArgumentException("Unknown language: " + langCode);
            }
            String pp = CharSoupFeatureExtractor.preprocess(text);
            if (pp.isEmpty()) {
                return this;
            }
            long[] ug = unigramCounts.get(langCode);
            long[] bg = bigramCounts.get(langCode);
            if (cjk) {
                extractCjkNgrams(pp,
                        h -> ug[h % CJK_UNIGRAM_BUCKETS]++,
                        h -> bg[h % CJK_BIGRAM_BUCKETS]++);
            } else {
                long[] tg = trigramCounts.get(langCode);
                extractNonCjkNgrams(pp,
                        h -> ug[h % NONCJK_UNIGRAM_BUCKETS]++,
                        h -> bg[h % NONCJK_BIGRAM_BUCKETS]++,
                        h -> tg[h % NONCJK_TRIGRAM_BUCKETS]++);
            }
            return this;
        }

        /**
         * Finalize training with add-{@code k} smoothing and return the model.
         * Calibration stats are left at zero; populate them afterwards with
         * {@link GenerativeLanguageModel#setStats}.
         *
         * @param addK smoothing constant; 0.01 is a reasonable default
         */
        public GenerativeLanguageModel build(float addK) {
            List<String> ids = new ArrayList<>(cjkFlags.keySet());
            int n = ids.size();
            boolean[] cjkArr = new boolean[n];
            byte[][] uniTables = new byte[n][];
            byte[][] biTables = new byte[n][];
            byte[][] triTables = new byte[n][];
            for (int i = 0; i < n; i++) {
                String lang = ids.get(i);
                cjkArr[i] = cjkFlags.get(lang);
                uniTables[i] = toLogProbTable(unigramCounts.get(lang), addK);
                biTables[i] = toLogProbTable(bigramCounts.get(lang), addK);
                if (!cjkArr[i]) {
                    triTables[i] = toLogProbTable(trigramCounts.get(lang), addK);
                }
            }
            return new GenerativeLanguageModel(ids, cjkArr, uniTables, biTables, triTables,
                    new float[n], new float[n]);
        }

        /** Convert raw counts to quantized add-k smoothed log-probabilities. */
        private static byte[] toLogProbTable(long[] counts, float addK) {
            long total = 0;
            for (long c : counts) {
                total += c;
            }
            double denom = total + (double) addK * counts.length;
            byte[] table = new byte[counts.length];
            for (int i = 0; i < counts.length; i++) {
                double p = (counts[i] + addK) / denom;
                table[i] = quantize((float) Math.log(p));
            }
            return table;
        }
    }
}