EvalGenerativeModel.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.tika.langdetect.charsoup.GenerativeLanguageModel;
/**
* Self-consistency evaluation for {@link GenerativeLanguageModel}.
*
* <p>For each sentence in the test file, computes {@code score(text, L)}
* for every language in the model and checks whether the argmax equals
* the true label. Reports overall accuracy and per-language accuracy
* sorted from worst to best.
*
* <p>Accepts either:
* <ul>
* <li>Flores-200 TSV: {@code lang_Script TAB text} &mdash; script suffixes are
* stripped (except for a small set of codes that are kept verbatim) and
* FLORES-specific codes are remapped to model codes.</li>
* <li>Standard corpus format: {@code lang TAB text}</li>
* </ul>
*
* <h3>Usage</h3>
* <pre>
* java EvalGenerativeModel \
* --model generative.bin \
* --test /path/to/flores200_dev.tsv \
* [--max-per-lang 997] \
* [--lengths 50,100,200] \
* [--show-confusions]
* </pre>
*/
public class EvalGenerativeModel {

    private static final int DEFAULT_MAX_PER_LANG = 0; // 0 = unlimited
    private static final int DEFAULT_MAX_CHARS = 0; // 0 = full sentence

    /** Accuracy cut-offs (percent) reported in the per-run summary. */
    private static final int[] ACCURACY_THRESHOLDS = {100, 95, 90, 80, 50};

    /** Maximum number of wrongly predicted languages listed per true language. */
    private static final int MAX_CONFUSIONS_SHOWN = 10;

    /** Width of the per-language table: 16 + 2 + 8 + 2 + 8 + 2 + 8 characters. */
    private static final int TABLE_WIDTH = 46;

    // ---- Flores-200 normalisation (mirrors CompareDetectors) ----

    /** Flores labels that are passed through unchanged, script suffix included. */
    private static final Set<String> FLORES_KEEP_SCRIPT_SUFFIX = Set.of(
            "ace_Arab", "arb_Latn", "bjn_Arab",
            "kas_Deva", "knc_Latn", "min_Arab", "taq_Tfng"
    );

    /** Maps a Flores base code (script suffix already removed) to its replacement code. */
    private static final Map<String, String> FLORES_CODE_REMAP = Map.ofEntries(
            Map.entry("arb", "ara"),
            Map.entry("pes", "fas"),
            Map.entry("zsm", "msa"),
            Map.entry("lvs", "lav"),
            Map.entry("azj", "aze"),
            Map.entry("ekk", "est"),
            Map.entry("npi", "nep"),
            Map.entry("als", "sqi"),
            Map.entry("ory", "ori"),
            Map.entry("nor", "nob"),
            Map.entry("cmn", "zho"),
            Map.entry("swa", "swh"),
            Map.entry("yid", "ydd"),
            Map.entry("gug", "grn"),
            Map.entry("quz", "que"),
            Map.entry("plt", "mlg"),
            Map.entry("pbt", "pus"),
            Map.entry("uzn", "uzb"),
            Map.entry("kmr", "kur"),
            Map.entry("khk", "mon")
    );

    /**
     * Normalises a test-file label.
     *
     * <p>Labels listed in {@link #FLORES_KEEP_SCRIPT_SUFFIX} are returned unchanged.
     * Otherwise everything from the first underscore onwards is dropped and the
     * remaining base code is looked up in {@link #FLORES_CODE_REMAP}; codes without
     * a mapping are returned as-is.
     *
     * @param raw label as read from the test file, e.g. {@code arb_Arab}
     * @return the normalised label, e.g. {@code ara}
     */
    static String normalizeLang(String raw) {
        if (FLORES_KEEP_SCRIPT_SUFFIX.contains(raw)) {
            return raw;
        }
        int underscore = raw.indexOf('_');
        String base = underscore >= 0 ? raw.substring(0, underscore) : raw;
        return FLORES_CODE_REMAP.getOrDefault(base, base);
    }

    // ---- Entry point ----

    /**
     * Parses the command line, loads the model and the test file, and prints one
     * accuracy report per requested truncation length (or a single full-sentence
     * report when {@code --lengths} is absent).
     *
     * <p>Exits with status 1 after printing usage when an option is unknown, an
     * option value is missing or not an integer, or {@code --model}/{@code --test}
     * is not supplied.
     *
     * @param args command-line arguments, see {@link #printUsage()}
     * @throws Exception if the model or the test file cannot be read
     */
    public static void main(String[] args) throws Exception {
        Path modelPath = null;
        Path testPath = null;
        int maxPerLang = DEFAULT_MAX_PER_LANG;
        int[] maxCharsSet = null; // null = full sentence only
        boolean showConfusions = false;

        for (int i = 0; i < args.length; i++) {
            String option = args[i];
            switch (option) {
                case "--model":
                    modelPath = Paths.get(optionValue(args, i++));
                    break;
                case "--test":
                    testPath = Paths.get(optionValue(args, i++));
                    break;
                case "--max-per-lang":
                    maxPerLang = parseIntOption(option, optionValue(args, i++));
                    break;
                case "--show-confusions":
                    showConfusions = true;
                    break;
                case "--lengths": {
                    String[] parts = optionValue(args, i++).split(",");
                    maxCharsSet = new int[parts.length];
                    for (int j = 0; j < parts.length; j++) {
                        maxCharsSet[j] = parseIntOption(option, parts[j].trim());
                    }
                    break;
                }
                default:
                    usageError("Unknown option: " + option);
                    break;
            }
        }
        if (modelPath == null || testPath == null) {
            printUsage();
            System.exit(1);
        }

        System.out.println("Loading model: " + modelPath);
        GenerativeLanguageModel model;
        try (InputStream is = new FileInputStream(modelPath.toFile())) {
            model = GenerativeLanguageModel.load(is);
        }
        List<String> languages = model.getLanguages();
        long cjkCount = languages.stream().filter(model::isCjk).count();
        System.out.printf(Locale.US, " %d languages (%d CJK, %d non-CJK)%n",
                languages.size(), cjkCount, languages.size() - cjkCount);

        System.out.println("Loading test data: " + testPath);
        List<LabeledSentence> data = loadTestFile(testPath);
        // Any label containing an underscore marks the file as Flores-200 style.
        boolean floresMode = data.stream().anyMatch(s -> s.getLanguage().contains("_"));
        if (floresMode) {
            System.out.println(" Flores-200 mode: normalizing lang codes");
            data = normalizeAll(data);
        }

        // Cap per language if requested
        if (maxPerLang > 0) {
            data = samplePerLang(data, maxPerLang);
        }

        // Split into scorable (true lang is in model) and unscorable
        Set<String> modelLangs = new HashSet<>(languages);
        List<LabeledSentence> scorable = new ArrayList<>();
        Map<String, Integer> skipped = new TreeMap<>(); // sorted for stable output
        for (LabeledSentence s : data) {
            if (modelLangs.contains(s.getLanguage())) {
                scorable.add(s);
            } else {
                skipped.merge(s.getLanguage(), 1, Integer::sum);
            }
        }
        System.out.printf(Locale.US,
                " %,d sentences; %,d scorable, %,d skipped (%d langs not in model)%n",
                data.size(), scorable.size(),
                data.size() - scorable.size(), skipped.size());
        if (!skipped.isEmpty()) {
            System.out.println(" Skipped langs: " + skipped.keySet());
        }
        if (scorable.isEmpty()) {
            // Without this guard the report would divide by zero and print NaN.
            System.err.println("No scorable sentences: nothing to evaluate.");
            return;
        }

        // Build the set of lengths to evaluate
        int[] lengths = maxCharsSet != null ? maxCharsSet : new int[] {DEFAULT_MAX_CHARS};
        for (int maxChars : lengths) {
            runEvaluation(model, scorable, maxChars, showConfusions);
        }
    }

    /**
     * Returns the value following the option at {@code index}, or prints a usage
     * error and exits when the option is the last argument.
     */
    private static String optionValue(String[] args, int index) {
        if (index + 1 >= args.length) {
            usageError("Missing value for option: " + args[index]);
        }
        return args[index + 1];
    }

    /** Parses an integer option value, or prints a usage error and exits. */
    private static int parseIntOption(String option, String value) {
        try {
            return Integer.parseInt(value);
        } catch (NumberFormatException e) {
            usageError("Invalid integer for " + option + ": " + value);
            return 0; // not reached: usageError terminates the JVM
        }
    }

    /** Prints {@code message} and the usage text to stderr, then exits with status 1. */
    private static void usageError(String message) {
        System.err.println(message);
        printUsage();
        System.exit(1);
    }

    /** Returns a copy of {@code data} with every label passed through {@link #normalizeLang}. */
    private static List<LabeledSentence> normalizeAll(List<LabeledSentence> data) {
        List<LabeledSentence> normalized = new ArrayList<>(data.size());
        for (LabeledSentence s : data) {
            normalized.add(new LabeledSentence(normalizeLang(s.getLanguage()), s.getText()));
        }
        return normalized;
    }

    // ---- Reporting ----

    /**
     * Scores {@code scorable} at one truncation length and prints the overall
     * accuracy, the per-language table (worst first), the threshold summary and,
     * if requested, the confusion distributions.
     *
     * @param model          model used for scoring
     * @param scorable       sentences whose true language is known to the model
     * @param maxChars       truncation length in chars; values &lt;= 0 mean full sentence
     * @param showConfusions whether to collect and print wrong predictions
     */
    private static void runEvaluation(GenerativeLanguageModel model,
                                      List<LabeledSentence> scorable,
                                      int maxChars, boolean showConfusions) {
        String label = maxChars > 0 ? "@" + maxChars + " chars" : "full";
        List<LabeledSentence> run = maxChars > 0 ? truncate(scorable, maxChars) : scorable;
        System.out.printf(Locale.US, "%nScoring [%s]\u2026%n", label);

        long wallStart = System.nanoTime();
        // confusions: trueLang -> (predictedLang -> count)
        Map<String, Map<String, Integer>> confusions = showConfusions ? new TreeMap<>() : null;
        Map<String, int[]> perLang = evalAll(model, run, confusions);
        long elapsedNanos = System.nanoTime() - wallStart;

        int totalCorrect = 0;
        int totalCount = 0;
        for (int[] v : perLang.values()) {
            totalCorrect += v[0];
            totalCount += v[1];
        }
        // Floor at 1 ns so a very fast run reports a finite throughput.
        double sentPerSec = totalCount * 1e9 / Math.max(1L, elapsedNanos);
        System.out.printf(Locale.US,
                "Overall [%s]: %.2f%% (%,d / %,d) in %,dms (%.0f sent/s)%n",
                label, percent(totalCorrect, totalCount),
                totalCorrect, totalCount,
                elapsedNanos / 1_000_000, sentPerSec);

        // Worst accuracy first; ties broken by language code so output is reproducible.
        List<Map.Entry<String, int[]>> rows = new ArrayList<>(perLang.entrySet());
        rows.sort(Comparator
                .comparingDouble((Map.Entry<String, int[]> e) ->
                        (double) e.getValue()[0] / e.getValue()[1])
                .thenComparing(Map.Entry::getKey));

        // Columns are separated by two spaces so the table matches TABLE_WIDTH.
        System.out.printf(Locale.US, "%n%-16s  %8s  %8s  %8s%n",
                "Language", "Correct", "Total", "Acc%");
        System.out.println("-".repeat(TABLE_WIDTH));
        for (Map.Entry<String, int[]> e : rows) {
            int[] v = e.getValue();
            System.out.printf(Locale.US, "%-16s  %8d  %8d  %7.2f%%%n",
                    e.getKey(), v[0], v[1], percent(v[0], v[1]));
        }
        System.out.println();

        for (int t : ACCURACY_THRESHOLDS) {
            long above = rows.stream()
                    .filter(e -> percent(e.getValue()[0], e.getValue()[1]) >= t)
                    .count();
            System.out.printf(Locale.US, " >= %3d%% accuracy: %3d / %d languages%n",
                    t, above, rows.size());
        }

        if (confusions != null) {
            printConfusions(confusions, perLang);
        }
    }

    /**
     * Prints, for each true language with at least one wrong prediction, the most
     * frequent wrongly predicted languages (at most {@link #MAX_CONFUSIONS_SHOWN}).
     */
    private static void printConfusions(Map<String, Map<String, Integer>> confusions,
                                        Map<String, int[]> perLang) {
        System.out.println("\n=== Confusion distributions (wrong predictions only) ===");
        for (Map.Entry<String, Map<String, Integer>> langEntry : confusions.entrySet()) {
            String trueLang = langEntry.getKey();
            int[] counts = perLang.get(trueLang);
            int total = counts[1];
            int wrong = total - counts[0];
            if (wrong == 0) {
                continue;
            }
            System.out.printf(Locale.US, "%n %s (%d wrong / %d total): ",
                    trueLang, wrong, total);
            langEntry.getValue().entrySet().stream()
                    .sorted(Map.Entry.<String, Integer>comparingByValue().reversed()
                            .thenComparing(Map.Entry.<String, Integer>comparingByKey()))
                    .limit(MAX_CONFUSIONS_SHOWN)
                    .forEach(e -> System.out.printf(Locale.US,
                            "%s=%d ", e.getKey(), e.getValue()));
            System.out.println();
        }
    }

    /** Returns {@code correct / total} as a percentage, or 0 when {@code total} is 0. */
    private static double percent(int correct, int total) {
        return total == 0 ? 0.0 : 100.0 * correct / total;
    }

    // ---- Scoring ----

    /**
     * Predicts a language for every sentence and tallies the results.
     *
     * @param model      model used for scoring
     * @param data       sentences to score
     * @param confusions if non-null, receives {@code trueLang -> (predictedLang -> count)}
     *                   for wrong predictions; sentences for which no language could
     *                   be predicted are counted as wrong but not recorded here
     * @return map from true language to {@code {correct, total}}
     */
    private static Map<String, int[]> evalAll(
            GenerativeLanguageModel model,
            List<LabeledSentence> data,
            Map<String, Map<String, Integer>> confusions) {
        Map<String, int[]> perLang = new HashMap<>();
        List<String> allLangs = model.getLanguages();
        for (LabeledSentence s : data) {
            String trueLang = s.getLanguage();
            String predicted = argmax(model, allLangs, s.getText());
            int[] counts = perLang.computeIfAbsent(trueLang, k -> new int[2]);
            counts[1]++;
            if (trueLang.equals(predicted)) {
                counts[0]++;
            } else if (confusions != null && predicted != null) {
                confusions.computeIfAbsent(trueLang, k -> new HashMap<>())
                        .merge(predicted, 1, Integer::sum);
            }
        }
        return perLang;
    }

    /**
     * Returns the language with the highest {@code model.score(text, lang)}.
     * NaN scores are ignored and the first language wins a tie.
     *
     * @return the best-scoring language, or {@code null} when no language yields
     *         a score greater than negative infinity
     */
    private static String argmax(GenerativeLanguageModel model,
                                 List<String> langs, String text) {
        String best = null;
        float bestS = Float.NEGATIVE_INFINITY;
        for (String lang : langs) {
            float s = model.score(text, lang);
            if (!Float.isNaN(s) && s > bestS) {
                bestS = s;
                best = lang;
            }
        }
        return best;
    }

    // ---- I/O helpers ----

    /**
     * Reads a UTF-8 {@code label TAB text} file.
     *
     * <p>Lines without a tab, and lines whose label or text is empty after
     * trimming, are skipped. A byte-order mark (U+FEFF) at the start of the file
     * is removed so it does not become part of the first label.
     *
     * @param path file to read
     * @return the labelled sentences in file order
     * @throws Exception if the file cannot be read
     */
    static List<LabeledSentence> loadTestFile(Path path) throws Exception {
        List<LabeledSentence> sentences = new ArrayList<>();
        try (BufferedReader reader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) {
            String line;
            boolean firstLine = true;
            while ((line = reader.readLine()) != null) {
                if (firstLine) {
                    firstLine = false;
                    // trim() only removes chars <= U+0020, so strip the BOM explicitly.
                    if (line.startsWith("\uFEFF")) {
                        line = line.substring(1);
                    }
                }
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                String lang = line.substring(0, tab).trim();
                String text = line.substring(tab + 1).trim();
                if (!lang.isEmpty() && !text.isEmpty()) {
                    sentences.add(new LabeledSentence(lang, text));
                }
            }
        }
        return sentences;
    }

    /**
     * Returns a copy of {@code data} in which every text is cut to roughly
     * {@code maxChars} UTF-16 chars (see {@link #prefix}).
     */
    private static List<LabeledSentence> truncate(
            List<LabeledSentence> data, int maxChars) {
        List<LabeledSentence> result = new ArrayList<>(data.size());
        for (LabeledSentence s : data) {
            result.add(new LabeledSentence(s.getLanguage(), prefix(s.getText(), maxChars)));
        }
        return result;
    }

    /**
     * Returns the leading {@code maxChars} chars of {@code text} without splitting
     * a surrogate pair. Texts that already fit, and non-positive limits, return
     * {@code text} unchanged.
     */
    private static String prefix(String text, int maxChars) {
        if (maxChars <= 0 || text.length() <= maxChars) {
            return text;
        }
        int end = maxChars;
        if (Character.isHighSurrogate(text.charAt(end - 1))
                && Character.isLowSurrogate(text.charAt(end))) {
            // The cut falls inside a surrogate pair: drop the pair, unless it is the
            // only code point, in which case keep it whole rather than return "".
            end = end > 1 ? end - 1 : end + 1;
        }
        return text.substring(0, end);
    }

    /**
     * Keeps the first {@code max} sentences of each language, in input order.
     * Despite the name, no random sampling is performed.
     */
    private static List<LabeledSentence> samplePerLang(
            List<LabeledSentence> data, int max) {
        Map<String, Integer> counts = new HashMap<>();
        List<LabeledSentence> result = new ArrayList<>();
        for (LabeledSentence s : data) {
            int n = counts.merge(s.getLanguage(), 1, Integer::sum);
            if (n <= max) {
                result.add(s);
            }
        }
        return result;
    }

    /** Prints the command-line synopsis to stderr. */
    private static void printUsage() {
        System.err.println("Usage: EvalGenerativeModel");
        System.err.println(" --model <generative.bin>");
        System.err.println(" --test <testFile.tsv>");
        System.err.println(" [--max-per-lang <N>]");
        System.err.println(" [--lengths 50,100,200] (truncate sentences to N chars)");
        System.err.println(" [--show-confusions] (list most frequent wrong predictions)");
    }
}