TrainGenerativeLanguageModel.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.DirectoryStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import org.apache.tika.langdetect.charsoup.GenerativeLanguageModel;
import org.apache.tika.langdetect.charsoup.ScriptAwareFeatureExtractor;
/**
* Trains a {@link GenerativeLanguageModel} from a Leipzig-format corpus.
*
* <h3>Corpus format</h3>
* <pre>
* corpusDir/
* eng/
* sentences.txt (lineNum TAB sentence)
* zho/
* sentences.txt
* jpn/
* sentences.txt
* ...
* </pre>
* Each directory name is used as the language code. Any {@code .txt} file
* directly under a language directory is read; each line must contain at
* least one tab, and the text after the first tab is the sentence.
*
* <h3>CJK detection</h3>
* A language is treated as CJK if at least 60% of the letter codepoints
* in a random sample of sentences are CJK/kana characters. You can
* override this with an explicit {@code --cjk} list on the command line.
*
* <h3>Usage</h3>
* <pre>
* java TrainGenerativeLanguageModel \
* --corpus /path/to/Leipzig-corpus \
* --output generative.bin \
* [--max-per-lang 500000] \
* [--add-k 0.01] \
* [--cjk zho,jpn,cmn]
* </pre>
*/
public class TrainGenerativeLanguageModel {
private static final int DEFAULT_MAX_PER_LANG = 500_000;
private static final float DEFAULT_ADD_K = 0.01f;
/** Fraction of letter codepoints that must be CJK to classify a language as CJK. */
private static final float CJK_LETTER_THRESHOLD = 0.60f;
/** Number of sentences used to probe the script of an unknown language. */
private static final int CJK_PROBE_SENTENCES = 500;
public static void main(String[] args) throws Exception {
Path corpus = null;
Path output = null;
int maxPerLang = DEFAULT_MAX_PER_LANG;
float addK = DEFAULT_ADD_K;
List<String> forceCjk = new ArrayList<>();
for (int i = 0; i < args.length; i++) {
switch (args[i]) {
case "--corpus":
corpus = Paths.get(args[++i]);
break;
case "--output":
output = Paths.get(args[++i]);
break;
case "--max-per-lang":
maxPerLang = Integer.parseInt(args[++i]);
break;
case "--add-k":
addK = Float.parseFloat(args[++i]);
break;
case "--cjk": {
for (String code : args[++i].split(",")) {
forceCjk.add(code.trim());
}
break;
}
default:
System.err.println("Unknown option: " + args[i]);
printUsage();
System.exit(1);
}
}
if (corpus == null || output == null) {
printUsage();
System.exit(1);
}
new TrainGenerativeLanguageModel().run(corpus, output, maxPerLang, addK, forceCjk);
}
private void run(Path corpusDir, Path outputPath,
int maxPerLang, float addK,
List<String> forceCjkList) throws IOException {
// Support two corpus layouts:
// flat: corpusDir/{langCode} (one sentence per line, no tab prefix)
// Leipzig: corpusDir/{langCode}/*.txt (lineNum TAB sentence)
boolean flatLayout = isFlatLayout(corpusDir);
System.out.printf(Locale.US, "Corpus layout: %s%n", flatLayout ? "flat" : "Leipzig");
List<Path> langPaths = listLangPaths(corpusDir, flatLayout);
System.out.printf(Locale.US, "Found %d languages in %s%n", langPaths.size(), corpusDir);
GenerativeLanguageModel.Builder builder = GenerativeLanguageModel.builder();
for (Path langPath : langPaths) {
String lang = langPath.getFileName().toString();
boolean cjk = forceCjkList.contains(lang)
|| probeCjk(langPath, flatLayout, CJK_PROBE_SENTENCES);
System.out.printf(Locale.US, " %-12s %s%n", lang, cjk ? "CJK" : "non-CJK");
builder.registerLanguage(lang, cjk);
}
System.out.println("Accumulating n-gram counts ���");
long totalSentences = 0;
for (Path langPath : langPaths) {
String lang = langPath.getFileName().toString();
long counted = feedLanguage(builder, lang, langPath, flatLayout, maxPerLang);
totalSentences += counted;
System.out.printf(Locale.US, " %-12s %,d sentences%n", lang, counted);
}
System.out.printf(Locale.US, "Total sentences: %,d%n", totalSentences);
System.out.printf(Locale.US, "Building model (add-k=%.4f) ���%n", addK);
GenerativeLanguageModel model = builder.build(addK);
// Second pass: score training data to compute per-language �� and ��
System.out.println("Calibrating z-scores (second pass) ���");
for (Path langPath : langPaths) {
String lang = langPath.getFileName().toString();
double[] stats = calibrateLanguage(model, lang, langPath, flatLayout, maxPerLang);
model.setStats(lang, (float) stats[0], (float) stats[1]);
System.out.printf(Locale.US,
" %-12s ��=%8.4f ��=%6.4f (n=%d)%n",
lang, stats[0], stats[1], (long) stats[2]);
}
System.out.printf(Locale.US, "Writing model to %s ���%n", outputPath);
try (OutputStream os = new FileOutputStream(outputPath.toFile())) {
model.save(os);
}
long bytes = Files.size(outputPath);
System.out.printf(Locale.US, "Done. Model size: %,.0f KB%n", bytes / 1024.0);
}
// ---- Corpus helpers ----
/**
* Returns true if the corpus uses the flat layout (files named by language
* code, one sentence per line) rather than the Leipzig layout (subdirectories
* containing {@code *.txt} files with {@code lineNum TAB sentence} lines).
*/
private static boolean isFlatLayout(Path corpusDir) throws IOException {
try (DirectoryStream<Path> stream = Files.newDirectoryStream(corpusDir)) {
for (Path p : stream) {
return Files.isRegularFile(p);
}
}
return true;
}
/**
* List all language paths in the corpus directory, sorted.
* For flat layout: regular files. For Leipzig layout: subdirectories.
*/
private static List<Path> listLangPaths(Path corpusDir,
boolean flat) throws IOException {
List<Path> paths = new ArrayList<>();
try (DirectoryStream<Path> stream = Files.newDirectoryStream(corpusDir,
p -> flat ? Files.isRegularFile(p) : Files.isDirectory(p))) {
for (Path p : stream) {
paths.add(p);
}
}
Collections.sort(paths);
return paths;
}
/**
* Feed up to {@code maxPerLang} sentences from {@code langPath} into the builder.
*
* @return number of sentences consumed
*/
private static long feedLanguage(GenerativeLanguageModel.Builder builder,
String lang, Path langPath,
boolean flat,
int maxPerLang) throws IOException {
long count = 0;
if (flat) {
try (BufferedReader reader = Files.newBufferedReader(langPath,
StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
String text = line.trim();
if (text.isEmpty()) {
continue;
}
builder.addSample(lang, text);
count++;
if (maxPerLang > 0 && count >= maxPerLang) {
break;
}
}
}
} else {
List<Path> files = listTxtFiles(langPath);
outer:
for (Path file : files) {
try (BufferedReader reader = Files.newBufferedReader(file,
StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
int tab = line.indexOf('\t');
if (tab < 0) {
continue;
}
String text = line.substring(tab + 1).trim();
if (text.isEmpty()) {
continue;
}
builder.addSample(lang, text);
count++;
if (maxPerLang > 0 && count >= maxPerLang) {
break outer;
}
}
}
}
}
return count;
}
/**
* Score every training sentence for {@code lang} against the built model
* and return {@code [mean, stdDev, count]} using Welford's online algorithm.
*/
private static double[] calibrateLanguage(
GenerativeLanguageModel model, String lang,
Path langPath, boolean flat, int maxPerLang) throws IOException {
long n = 0;
double mean = 0.0;
double m2 = 0.0;
if (flat) {
try (BufferedReader reader = Files.newBufferedReader(
langPath, StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
String text = line.trim();
if (text.isEmpty()) {
continue;
}
float s = model.score(text, lang);
if (Float.isNaN(s)) {
continue;
}
n++;
double delta = s - mean;
mean += delta / n;
m2 += delta * (s - mean);
if (maxPerLang > 0 && n >= maxPerLang) {
break;
}
}
}
} else {
List<Path> files = listTxtFiles(langPath);
outer:
for (Path file : files) {
try (BufferedReader reader = Files.newBufferedReader(
file, StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
int tab = line.indexOf('\t');
if (tab < 0) {
continue;
}
String text = line.substring(tab + 1).trim();
if (text.isEmpty()) {
continue;
}
float s = model.score(text, lang);
if (Float.isNaN(s)) {
continue;
}
n++;
double delta = s - mean;
mean += delta / n;
m2 += delta * (s - mean);
if (maxPerLang > 0 && n >= maxPerLang) {
break outer;
}
}
}
}
}
double stdDev = n > 1 ? Math.sqrt(m2 / (n - 1)) : 0.0;
return new double[]{mean, stdDev, n};
}
/**
* Probe a language path to decide whether it is CJK.
*/
private static boolean probeCjk(Path langPath, boolean flat,
int maxSentences) throws IOException {
long cjkLetters = 0;
long totalLetters = 0;
int sentences = 0;
List<Path> files = flat
? Collections.singletonList(langPath) : listTxtFiles(langPath);
outer:
for (Path file : files) {
try (BufferedReader reader = Files.newBufferedReader(file,
StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
String text;
if (flat) {
text = line.trim();
} else {
int tab = line.indexOf('\t');
if (tab < 0) continue;
text = line.substring(tab + 1);
}
if (text.isEmpty()) continue;
int i = 0;
while (i < text.length()) {
int cp = text.codePointAt(i);
i += Character.charCount(cp);
if (Character.isLetter(cp)) {
totalLetters++;
if (ScriptAwareFeatureExtractor.isCjkOrKana(
Character.toLowerCase(cp))) {
cjkLetters++;
}
}
}
sentences++;
if (sentences >= maxSentences) {
break outer;
}
}
}
}
if (totalLetters == 0) {
return false;
}
return (double) cjkLetters / totalLetters >= CJK_LETTER_THRESHOLD;
}
private static List<Path> listTxtFiles(Path dir) throws IOException {
List<Path> files = new ArrayList<>();
try (DirectoryStream<Path> stream = Files.newDirectoryStream(dir, "*.txt")) {
for (Path p : stream) {
files.add(p);
}
}
Collections.sort(files);
return files;
}
private static void printUsage() {
System.err.println("Usage: TrainGenerativeLanguageModel");
System.err.println(" --corpus <corpusDir>");
System.err.println(" --output <outputFile>");
System.err.println(" [--max-per-lang <N>] (default 500000)");
System.err.println(" [--add-k <k>] (default 0.01)");
System.err.println(" [--cjk lang1,lang2,...] (override auto-detection)");
}
}