/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.apache.tika.langdetect.charsoup.CharSoupFeatureExtractor;
import org.apache.tika.langdetect.charsoup.CharSoupLanguageDetector;
import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.FeatureExtractor;
import org.apache.tika.language.detect.LanguageDetector;
/**
* Cross-domain evaluation: test CharSoup model and OpenNLP on external
* datasets (Tatoeba) that were NOT used during training.
* <p>
* Reports accuracy at multiple text-length thresholds to assess
 * short-text performance, the key use case for PDF directionality
* detection.
* <p>
* Also measures heap usage and throughput, and writes out a normalized
* {@code lang\ttext} file for the shared-language subset so that
* external tools (e.g., {@code eval_fasttext.py}) can be run on
* identical data.
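 * For example, one line of that file (contents illustrative):
 * <pre>{@code deu<TAB>Das ist ein Beispiel.}</pre>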
* <p>
* Usage:
* <pre>
 * CrossDomainEval &lt;charSoupModelFile&gt; &lt;dataset&gt; &lt;dataPath&gt;
* [reportFile] [threads]
*
* dataset: "tatoeba"
* dataPath: for tatoeba: path to sentences.csv
* </pre>
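 * <p>
 * Example invocation (file names illustrative):
 * <pre>
 * CrossDomainEval charsoup.bin tatoeba sentences.csv \
 *     reports/tatoeba.txt 8
 * </pre>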
*/
public class CrossDomainEval {
/** Length thresholds for bucketed accuracy reporting. */
private static final int[] THRESHOLDS = {10, 20, 30, 50, 100};
public static void main(String[] args) throws Exception {
if (args.length < 3) {
System.err.println("Usage: CrossDomainEval <charSoupModel>"
+ " tatoeba <dataPath>"
+ " [reportFile] [threads]");
System.exit(1);
}
Path modelFile = Paths.get(args[0]);
String dataset = args[1].toLowerCase(Locale.ROOT);
Path dataPath = Paths.get(args[2]);
Path reportFile = args.length > 3 ? Paths.get(args[3]) : null;
int threads = args.length > 4
? Integer.parseInt(args[4])
: Runtime.getRuntime().availableProcessors();
// ---- Load CharSoup model (with heap measurement) ----
System.out.println("Loading CharSoup model: " + modelFile);
long heapBefore = usedHeap();
CharSoupModel model;
try (InputStream is = new BufferedInputStream(
Files.newInputStream(modelFile))) {
model = CharSoupModel.load(is);
}
FeatureExtractor extractor = model.createExtractor();
long bigramHeapBytes = usedHeap() - heapBefore;
Set<String> modelLangs = new HashSet<>(
Arrays.asList(model.getLabels()));
System.out.printf(Locale.US,
" %d classes, %d buckets, ~%.1f MB heap%n",
model.getNumClasses(), model.getNumBuckets(),
bigramHeapBytes / (1024.0 * 1024.0));
// ---- Load OpenNLP detectors (one per thread) ----
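        // LanguageDetector instances are stateful (reset()/addText()),
        // so each worker thread needs its own copy.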
System.out.println("Loading OpenNLP detector(s)...");
heapBefore = usedHeap();
List<LanguageDetector> opennlpPool = new ArrayList<>();
for (int i = 0; i < threads; i++) {
LanguageDetector d =
CompareDetectors.loadDetector(
"org.apache.tika.langdetect.opennlp.OpenNLPDetector");
if (d == null) {
break;
}
opennlpPool.add(d);
}
LanguageDetector opennlpDetector =
opennlpPool.isEmpty() ? null : opennlpPool.get(0);
long opennlpHeapBytes = opennlpDetector != null
? usedHeap() - heapBefore : 0;
if (opennlpDetector != null) {
System.out.printf(Locale.US,
" OpenNLP: %d instance(s), ~%.1f MB heap%n",
opennlpPool.size(),
opennlpHeapBytes / (1024.0 * 1024.0));
}
// ---- Load Optimaize detectors (one per thread) ----
System.out.println("Loading Optimaize detector(s)...");
heapBefore = usedHeap();
List<LanguageDetector> optimaizePool = new ArrayList<>();
for (int i = 0; i < threads; i++) {
LanguageDetector d =
CompareDetectors.loadDetector(
"org.apache.tika.langdetect.optimaize.OptimaizeLangDetector");
if (d == null) {
break;
}
optimaizePool.add(d);
}
LanguageDetector optimaizeDetector =
optimaizePool.isEmpty() ? null : optimaizePool.get(0);
long optimaizeHeapBytes = optimaizeDetector != null
? usedHeap() - heapBefore : 0;
if (optimaizeDetector != null) {
System.out.printf(Locale.US,
" Optimaize: %d instance(s), ~%.1f MB heap%n",
optimaizePool.size(),
optimaizeHeapBytes / (1024.0 * 1024.0));
}
// ---- Load dataset ----
System.out.println("\nLoading dataset: " + dataset
+ " from " + dataPath);
List<LabeledSentence> data;
if ("tatoeba".equals(dataset)) {
data = loadTatoeba(dataPath, modelLangs);
} else {
System.err.println("Unknown dataset: " + dataset
+ " (only 'tatoeba' is supported)");
System.exit(1);
return;
}
// Count shared languages
Set<String> sharedLangs = new HashSet<>();
for (LabeledSentence s : data) {
sharedLangs.add(s.getLanguage());
}
System.out.printf(Locale.US,
"Loaded %,d sentences across %d shared languages%n",
data.size(), sharedLangs.size());
// ---- Write normalized shared-language file for fastText ----
if (reportFile != null) {
Path sharedFile = reportFile.getParent() != null
? reportFile.getParent().resolve(
dataset + "-shared.txt")
: Paths.get(dataset + "-shared.txt");
System.out.println("Writing shared-language data: "
+ sharedFile);
try (BufferedWriter w = Files.newBufferedWriter(
sharedFile, StandardCharsets.UTF_8)) {
for (LabeledSentence s : data) {
w.write(s.getLanguage());
w.write('\t');
w.write(s.getText());
w.newLine();
}
}
System.out.printf(Locale.US,
" Wrote %,d sentences for external eval%n",
data.size());
}
// ---- Bucket by length ----
Map<String, List<LabeledSentence>> buckets =
bucketByLength(data);
for (Map.Entry<String, List<LabeledSentence>> e
: buckets.entrySet()) {
System.out.printf(Locale.US, " %-14s %,d%n",
e.getKey(), e.getValue().size());
}
// ---- Warmup ----
System.out.println("\nWarming up...");
int warmup = Math.min(500, data.size());
CharSoupLanguageDetector warmupDet = new CharSoupLanguageDetector();
for (int i = 0; i < warmup; i++) {
String wt = data.get(i).getText();
warmupDet.reset();
warmupDet.addText(wt.toCharArray(), 0, wt.length());
warmupDet.detectAll();
if (opennlpDetector != null) {
opennlpDetector.reset();
opennlpDetector.addText(data.get(i).getText());
opennlpDetector.detectAll();
}
if (optimaizeDetector != null) {
optimaizeDetector.reset();
optimaizeDetector.addText(data.get(i).getText());
optimaizeDetector.detectAll();
}
}
// ---- Evaluate all length buckets ----
System.out.println("Evaluating (" + threads
+ " threads)...\n");
Map<String, CompareDetectors.EvalResult> bigramResults =
new LinkedHashMap<>();
Map<String, CompareDetectors.EvalResult> opennlpResults =
new LinkedHashMap<>();
Map<String, CompareDetectors.EvalResult> optimaizeResults =
new LinkedHashMap<>();
for (Map.Entry<String, List<LabeledSentence>> e
: buckets.entrySet()) {
String bucket = e.getKey();
List<LabeledSentence> subset = e.getValue();
bigramResults.put(bucket,
CompareDetectors.evaluateBigramParallel(
subset, "bigram-" + bucket, threads, null,
CharSoupLanguageDetector.Strategy.STANDARD));
opennlpResults.put(bucket,
CompareDetectors.evaluateOpenNLPParallel(
opennlpPool, subset,
"opennlp-" + bucket));
optimaizeResults.put(bucket,
CompareDetectors.evaluateOptimaizeParallel(
optimaizePool, subset,
"optimaize-" + bucket));
}
// ---- Build report ----
StringBuilder report = new StringBuilder();
report.append(String.format(Locale.US,
"=== Cross-Domain Evaluation: %s ===%n%n",
dataset.toUpperCase(Locale.ROOT)));
report.append(String.format(Locale.US,
"Model: %s (%d classes, %d buckets)%n",
modelFile.getFileName(),
model.getNumClasses(), model.getNumBuckets()));
report.append(String.format(Locale.US,
"Dataset: %s (%,d sentences, %d shared langs)%n",
dataset, data.size(), sharedLangs.size()));
report.append(String.format(Locale.US,
"Threads: %d%n%n", threads));
// Model sizes
report.append("Model heap (approx):\n");
report.append(String.format(Locale.US,
" CharSoup: ~%.1f MB%n",
bigramHeapBytes / (1024.0 * 1024.0)));
report.append(String.format(Locale.US,
" OpenNLP: ~%.1f MB%n",
opennlpHeapBytes / (1024.0 * 1024.0)));
report.append(String.format(Locale.US,
" Optimaize: ~%.1f MB%n%n",
optimaizeHeapBytes / (1024.0 * 1024.0)));
// Strict accuracy summary table
report.append(
"Strict accuracy (exact language match):\n");
        report.append(String.format(Locale.US,
                "%-14s %10s %10s %10s %12s %12s%n",
                "Length bucket", "CharSoup", "OpenNLP", "Optimaize",
                "CS Time(ms)", "CS Sent/s"));
report.append("-".repeat(80)).append("\n");
for (String bucket : buckets.keySet()) {
CompareDetectors.EvalResult br =
bigramResults.get(bucket);
CompareDetectors.EvalResult or =
opennlpResults.get(bucket);
CompareDetectors.EvalResult pr =
optimaizeResults.get(bucket);
report.append(String.format(Locale.US,
"%-14s %9s %9s %9s %,10d %,10.0f%n",
bucket,
fmtAcc(br, false),
fmtAcc(or, false),
fmtAcc(pr, false),
br.elapsedMs,
throughput(br)));
}
report.append("\n");
// Group accuracy summary table
report.append(
"Group accuracy (confusable languages "
+ "counted as correct):\n");
report.append(formatConfusableGroups());
report.append(String.format(Locale.US,
"%-14s %10s %10s %10s%n",
"Length bucket", "CharSoup", "OpenNLP", "Optimaize"));
report.append("-".repeat(52)).append("\n");
for (String bucket : buckets.keySet()) {
CompareDetectors.EvalResult br =
bigramResults.get(bucket);
CompareDetectors.EvalResult or =
opennlpResults.get(bucket);
CompareDetectors.EvalResult pr =
optimaizeResults.get(bucket);
report.append(String.format(Locale.US,
"%-14s %9s %9s %9s%n",
bucket,
fmtAcc(br, true),
fmtAcc(or, true),
fmtAcc(pr, true)));
}
report.append("\n");
// CharSoup timing (full pipeline: script gate + group collapse)
CompareDetectors.EvalResult bigramAll =
bigramResults.get("all");
if (bigramAll != null) {
report.append("CharSoup timing (wall-clock, full pipeline):\n");
report.append(String.format(Locale.US,
" Wall-clock total: %,d ms%n%n", bigramAll.elapsedMs));
}
// Per-language detail
report.append(perLanguageReport(bigramResults, opennlpResults, optimaizeResults));
// Detailed CharSoup analysis: macro F1, confusion pairs,
// confidence calibration, entropy-threshold accuracy
System.out.println("Running detailed CharSoup analysis...");
report.append(detailedCharSoupAnalysis(
model, extractor, data, threads));
String reportStr = report.toString();
System.out.println(reportStr);
if (reportFile != null) {
if (reportFile.getParent() != null) {
Files.createDirectories(reportFile.getParent());
}
try (BufferedWriter w = Files.newBufferedWriter(
reportFile, StandardCharsets.UTF_8)) {
w.write(reportStr);
}
System.out.println("Report written to: " + reportFile);
}
}
// ---- Dataset loaders ----
/**
* Load Tatoeba sentences.csv (id\tlang\ttext), filtering to
     * languages our model supports. Tatoeba codes are ISO 639-3
     * and map onto ours via {@link #mapTatoebaLang(String)}.
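     * A sample input line (tab-separated; contents illustrative):
     * <pre>{@code 1234<TAB>deu<TAB>Das ist ein Beispiel.}</pre>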
*/
static List<LabeledSentence> loadTatoeba(
Path sentencesFile, Set<String> modelLangs)
throws Exception {
List<LabeledSentence> sentences = new ArrayList<>();
Map<String, Integer> skippedLangs = new HashMap<>();
int total = 0;
try (BufferedReader reader = Files.newBufferedReader(
sentencesFile, StandardCharsets.UTF_8)) {
String line;
while ((line = reader.readLine()) != null) {
total++;
String[] parts = line.split("\t", 3);
if (parts.length < 3) {
continue;
}
String lang = parts[1].trim();
String text = parts[2].trim();
if (text.isEmpty()) {
continue;
}
// Map Tatoeba codes to our ISO 639-3 codes
String mapped = mapTatoebaLang(lang);
if (mapped != null && modelLangs.contains(mapped)) {
sentences.add(
new LabeledSentence(mapped, text));
} else {
skippedLangs.merge(lang, 1, Integer::sum);
}
}
}
Set<String> foundLangs = new HashSet<>();
for (LabeledSentence s : sentences) {
foundLangs.add(s.getLanguage());
}
System.out.printf(Locale.US,
"Tatoeba: %,d/%,d sentences in %d shared languages"
+ " (skipped %d language codes)%n",
sentences.size(), total, foundLangs.size(),
skippedLangs.size());
return sentences;
}
// ---- Language code mapping ----
/**
* Map Tatoeba language codes to our model's ISO 639-3 codes.
* Tatoeba mostly uses ISO 639-3 but has some exceptions.
*/
    static String mapTatoebaLang(String code) {
        // Tatoeba's "ber" (the Berber macrolanguage code) corresponds
        // to our "kab" (Kabyle) class; every other Tatoeba code is
        // already the ISO 639-3 code the model uses and passes through
        // unchanged.
        if ("ber".equals(code)) {
            return "kab";
        }
        return code;
    }
// ---- Length bucketing ----
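    /**
     * Partition sentences into disjoint length buckets plus an
     * {@code all} bucket. Each sentence lands in the smallest
     * threshold bucket it fits (a 25-char sentence goes to
     * {@code <=30 chars}) or, if it exceeds every threshold, in the
     * overflow bucket.
     */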
static Map<String, List<LabeledSentence>> bucketByLength(
List<LabeledSentence> data) {
Map<String, List<LabeledSentence>> buckets =
new LinkedHashMap<>();
for (int t : THRESHOLDS) {
buckets.put("<=" + t + " chars", new ArrayList<>());
}
buckets.put(">" + THRESHOLDS[THRESHOLDS.length - 1]
+ " chars", new ArrayList<>());
buckets.put("all", new ArrayList<>());
for (LabeledSentence s : data) {
int len = s.getText().length();
buckets.get("all").add(s);
boolean placed = false;
for (int t : THRESHOLDS) {
if (len <= t) {
buckets.get("<=" + t + " chars").add(s);
placed = true;
break;
}
}
if (!placed) {
buckets.get(">" + THRESHOLDS[THRESHOLDS.length - 1]
+ " chars").add(s);
}
}
return buckets;
}
// ---- Per-language report ----
/**
* Build per-language accuracy table from the "all" bucket results.
* Includes strict and group accuracy for both detectors, plus
* confusable-group markers.
*/
static String perLanguageReport(
Map<String, CompareDetectors.EvalResult> bigramResults,
Map<String, CompareDetectors.EvalResult> opennlpResults,
Map<String, CompareDetectors.EvalResult> optimaizeResults) {
CompareDetectors.EvalResult bigramAll =
bigramResults.get("all");
CompareDetectors.EvalResult opennlpAll =
opennlpResults.get("all");
CompareDetectors.EvalResult optimaizeAll =
optimaizeResults.get("all");
StringBuilder sb = new StringBuilder();
sb.append("Per-language accuracy (all sentences):\n");
        sb.append(String.format(Locale.US,
                "%-12s %8s %8s %8s %8s %8s %8s %8s %8s %8s%n",
                "Language", "CS-n/N", "CS-Grp%", "CS-Acc%",
                "ON-n/N", "ON-Grp%", "ON-Acc%",
                "Opt-n/N", "Opt-Grp%", "Opt-Acc%"));
sb.append("-".repeat(118)).append("\n");
// Merge per-lang:
// [0]=bigram strict, [1]=bigram total, [2]=bigram group,
// [3]=opennlp strict, [4]=opennlp total, [5]=opennlp group,
// [6]=optimaize strict, [7]=optimaize total, [8]=optimaize group
Map<String, int[]> merged = new TreeMap<>();
if (bigramAll != null && bigramAll.perLang != null) {
for (var e : bigramAll.perLang.entrySet()) {
int[] row = merged.computeIfAbsent(
e.getKey(), k -> new int[9]);
row[0] = e.getValue()[0];
row[1] = e.getValue()[1];
row[2] = e.getValue()[2];
}
}
if (opennlpAll != null && opennlpAll.perLang != null) {
for (var e : opennlpAll.perLang.entrySet()) {
int[] row = merged.computeIfAbsent(
e.getKey(), k -> new int[9]);
row[3] = e.getValue()[0];
row[4] = e.getValue()[1];
row[5] = e.getValue()[2];
}
}
if (optimaizeAll != null && optimaizeAll.perLang != null) {
for (var e : optimaizeAll.perLang.entrySet()) {
int[] row = merged.computeIfAbsent(
e.getKey(), k -> new int[9]);
row[6] = e.getValue()[0];
row[7] = e.getValue()[1];
row[8] = e.getValue()[2];
}
}
int bigramWins = 0;
int opennlpWins = 0;
int optimaizeWins = 0;
int ties = 0;
for (var e : merged.entrySet()) {
int[] c = e.getValue();
String lang = e.getKey();
boolean confusable = isInConfusableGroup(lang);
String bStrict = c[1] > 0
? String.format(Locale.US, "%6.1f%%",
100.0 * c[0] / c[1]) : " N/A";
String bGroup = confusable && c[1] > 0
? String.format(Locale.US, "%6.1f%%",
100.0 * c[2] / c[1]) : " ";
String oStrict = c[4] > 0
? String.format(Locale.US, "%6.1f%%",
100.0 * c[3] / c[4]) : " N/A";
String oGroup = confusable && c[4] > 0
? String.format(Locale.US, "%6.1f%%",
100.0 * c[5] / c[4]) : " ";
String pStrict = c[7] > 0
? String.format(Locale.US, "%6.1f%%",
100.0 * c[6] / c[7]) : " N/A";
String pGroup = confusable && c[7] > 0
? String.format(Locale.US, "%6.1f%%",
100.0 * c[8] / c[7]) : " ";
String marker = confusable ? " *" : "";
sb.append(String.format(Locale.US,
"%-12s %4d/%-4d %s %s"
+ " %4d/%-4d %s %s"
+ " %4d/%-4d %s %s%s%n",
lang, c[0], c[1], bGroup, bStrict,
c[3], c[4], oGroup, oStrict,
c[6], c[7], pGroup, pStrict, marker));
double bAcc = c[1] > 0 ? (double) c[0] / c[1] : -1;
double oAcc = c[4] > 0 ? (double) c[3] / c[4] : -1;
double pAcc = c[7] > 0 ? (double) c[6] / c[7] : -1;
double best = Math.max(bAcc, Math.max(oAcc, pAcc));
if (best < 0) {
continue;
}
            // A win requires beating both rivals by more than the 0.5%
            // margin; any multi-way tie at the top counts as a tie.
            int atTop = (bAcc >= best - 0.005 ? 1 : 0)
                    + (oAcc >= best - 0.005 ? 1 : 0)
                    + (pAcc >= best - 0.005 ? 1 : 0);
            if (atTop > 1) {
                ties++;
            } else if (bAcc >= best - 0.005) {
                bigramWins++;
            } else if (oAcc >= best - 0.005) {
                opennlpWins++;
            } else {
                optimaizeWins++;
            }
}
sb.append("\n* = member of a confusable group; "
+ "Grp% = group accuracy\n");
sb.append(String.format(Locale.US,
"%nCharSoup wins: %d OpenNLP wins: %d Optimaize wins: %d Ties: %d "
+ "(>0.5%% margin)%n",
bigramWins, opennlpWins, optimaizeWins, ties));
return sb.toString();
}
// ---- Detailed CharSoup analysis ----
    /**
     * Number of entropy histogram bins (0.1 resolution, 0.0 to 12.0,
     * plus one overflow bin for higher entropies).
     */
private static final int ENTROPY_BINS = 120;
private static final float ENTROPY_BIN_WIDTH = 0.1f;
/** Entropy thresholds for "reject low-confidence" analysis. */
private static final float[] ENTROPY_THRESHOLDS =
{0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 5.0f};
/** Top-N confusion pairs to report. */
private static final int TOP_CONFUSIONS = 30;
/**
* Thread-local accumulator for detailed per-prediction stats.
* Designed to be merged across threads after parallel evaluation.
*/
static class DetailedStats {
// Per-language TP/FP/FN for macro F1
// key=lang, value=[TP, FP, FN]
Map<String, int[]> langCounts = new HashMap<>();
// Confusion pairs: "truth\tpred" -> count (errors only)
Map<String, Integer> confusions = new HashMap<>();
// Entropy histogram: [bin][0]=total, [bin][1]=correct
int[][] entropyHist = new int[ENTROPY_BINS + 1][2];
// Running sums for mean entropy
double entropyCorrectSum = 0;
int entropyCorrectCount = 0;
double entropyIncorrectSum = 0;
int entropyIncorrectCount = 0;
void record(String truth, String predicted, float entropy) {
boolean correct = truth.equals(predicted);
// TP/FP/FN
int[] tc = langCounts.computeIfAbsent(
truth, k -> new int[3]);
int[] pc = langCounts.computeIfAbsent(
predicted, k -> new int[3]);
if (correct) {
tc[0]++; // TP for truth
} else {
tc[2]++; // FN for truth
pc[1]++; // FP for predicted
String key = truth + "\t" + predicted;
confusions.merge(key, 1, Integer::sum);
}
// Entropy histogram
int bin = Math.min(
(int) (entropy / ENTROPY_BIN_WIDTH), ENTROPY_BINS);
entropyHist[bin][0]++;
if (correct) {
entropyHist[bin][1]++;
}
// Running entropy sums
if (correct) {
entropyCorrectSum += entropy;
entropyCorrectCount++;
} else {
entropyIncorrectSum += entropy;
entropyIncorrectCount++;
}
}
void merge(DetailedStats other) {
for (var e : other.langCounts.entrySet()) {
int[] mine = langCounts.computeIfAbsent(
e.getKey(), k -> new int[3]);
int[] theirs = e.getValue();
mine[0] += theirs[0];
mine[1] += theirs[1];
mine[2] += theirs[2];
}
for (var e : other.confusions.entrySet()) {
confusions.merge(
e.getKey(), e.getValue(), Integer::sum);
}
for (int i = 0; i <= ENTROPY_BINS; i++) {
entropyHist[i][0] += other.entropyHist[i][0];
entropyHist[i][1] += other.entropyHist[i][1];
}
entropyCorrectSum += other.entropyCorrectSum;
entropyCorrectCount += other.entropyCorrectCount;
entropyIncorrectSum += other.entropyIncorrectSum;
entropyIncorrectCount += other.entropyIncorrectCount;
}
}
/**
* Run a detailed single-pass evaluation collecting macro F1,
* confusion pairs, confidence calibration, and entropy-threshold
* accuracy.
*/
static String detailedCharSoupAnalysis(
CharSoupModel model, FeatureExtractor extractor,
List<LabeledSentence> data, int threads)
throws Exception {
// Parallel evaluation
DetailedStats merged;
if (threads <= 1) {
merged = detailedChunk(model, extractor, data);
} else {
List<List<LabeledSentence>> chunks =
CompareDetectors.partition(data, threads);
ExecutorService pool =
Executors.newFixedThreadPool(chunks.size());
try {
List<Future<DetailedStats>> futures =
new ArrayList<>();
for (List<LabeledSentence> chunk : chunks) {
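                    // Each worker gets its own extractor; extractor
                    // instances are not assumed to be thread-safe.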
FeatureExtractor te =
model.createExtractor();
futures.add(pool.submit(
() -> detailedChunk(
model, te, chunk)));
}
merged = new DetailedStats();
for (Future<DetailedStats> f : futures) {
merged.merge(f.get());
}
} finally {
pool.shutdown();
}
}
// Format report
StringBuilder sb = new StringBuilder();
// 1. Macro F1
sb.append("Macro F1 (CharSoup model):\n");
double macroF1Sum = 0;
int macroF1Count = 0;
for (var e : merged.langCounts.entrySet()) {
int[] c = e.getValue();
int tp = c[0];
int fp = c[1];
int fn = c[2];
double precision = tp + fp > 0
? (double) tp / (tp + fp) : 0;
double recall = tp + fn > 0
? (double) tp / (tp + fn) : 0;
double f1 = precision + recall > 0
? 2 * precision * recall / (precision + recall)
: 0;
if (tp + fn > 0) {
macroF1Sum += f1;
macroF1Count++;
}
}
double macroF1 = macroF1Count > 0
? macroF1Sum / macroF1Count : 0;
sb.append(String.format(Locale.US,
" Macro F1: %.4f (%d languages)%n",
macroF1, macroF1Count));
// Also compute micro F1 (= accuracy for multi-class)
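        // With single-label predictions every false positive for one
        // class is a false negative for another, so micro precision ==
        // micro recall == overall accuracy.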
int totalTP = 0;
int totalSamples = 0;
for (int[] c : merged.langCounts.values()) {
totalTP += c[0];
totalSamples += c[0] + c[2]; // TP + FN = total per lang
}
sb.append(String.format(Locale.US,
" Micro F1: %.4f (= accuracy)%n%n",
totalSamples > 0
? (double) totalTP / totalSamples : 0));
        // Bottom-15 languages by F1
List<Map.Entry<String, double[]>> langF1 = new ArrayList<>();
for (var e : merged.langCounts.entrySet()) {
int[] c = e.getValue();
int tp = c[0], fp = c[1], fn = c[2];
double p = tp + fp > 0 ? (double) tp / (tp + fp) : 0;
double r = tp + fn > 0 ? (double) tp / (tp + fn) : 0;
double f1 = p + r > 0 ? 2 * p * r / (p + r) : 0;
langF1.add(Map.entry(e.getKey(),
new double[]{f1, p, r, tp + fn}));
}
langF1.sort((a, b) -> Double.compare(
a.getValue()[0], b.getValue()[0]));
sb.append("Bottom 15 languages by F1:\n");
sb.append(String.format(Locale.US,
" %-12s %8s %8s %8s %6s%n",
"Language", "F1", "Prec", "Recall", "Count"));
sb.append(" ").append("-".repeat(50)).append("\n");
int show = Math.min(15, langF1.size());
for (int i = 0; i < show; i++) {
var e = langF1.get(i);
double[] v = e.getValue();
sb.append(String.format(Locale.US,
" %-12s %7.4f %7.4f %7.4f %5.0f%n",
e.getKey(), v[0], v[1], v[2], v[3]));
}
sb.append("\n");
// 2. Top confusion pairs
sb.append(String.format(Locale.US,
"Top %d confusion pairs (truth -> predicted):%n",
TOP_CONFUSIONS));
sb.append(String.format(Locale.US,
" %-12s %-12s %8s%n",
"Truth", "Predicted", "Count"));
sb.append(" ").append("-".repeat(36)).append("\n");
merged.confusions.entrySet().stream()
.sorted((a, b) -> Integer.compare(
b.getValue(), a.getValue()))
.limit(TOP_CONFUSIONS)
.forEach(e -> {
String[] parts = e.getKey().split("\t", 2);
sb.append(String.format(Locale.US,
" %-12s %-12s %,8d%n",
parts[0], parts[1], e.getValue()));
});
sb.append("\n");
// 3. Confidence calibration (mean entropy)
sb.append("Confidence calibration (entropy, bits):\n");
double meanCorrect = merged.entropyCorrectCount > 0
? merged.entropyCorrectSum
/ merged.entropyCorrectCount : 0;
double meanIncorrect = merged.entropyIncorrectCount > 0
? merged.entropyIncorrectSum
/ merged.entropyIncorrectCount : 0;
sb.append(String.format(Locale.US,
" Correct predictions: mean entropy = "
+ "%.3f bits (%,d samples)%n",
meanCorrect, merged.entropyCorrectCount));
sb.append(String.format(Locale.US,
" Incorrect predictions: mean entropy = "
+ "%.3f bits (%,d samples)%n",
meanIncorrect, merged.entropyIncorrectCount));
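        // A ratio well above 1 means entropy cleanly separates right
        // from wrong answers and can back a rejection threshold.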
sb.append(String.format(Locale.US,
" Separation ratio: %.1fx%n%n",
meanCorrect > 0
? meanIncorrect / meanCorrect : 0));
// 4. Entropy-threshold accuracy ("reject uncertain")
sb.append("Entropy-threshold accuracy "
+ "(reject predictions above threshold):\n");
sb.append(String.format(Locale.US,
" %-14s %10s %10s %10s%n",
"Max entropy", "Accuracy", "Accepted",
"Rejected%"));
sb.append(" ").append("-".repeat(50)).append("\n");
        // Cumulate the histogram from low to high, emitting a row each
        // time a threshold is crossed. The bin index carries over into
        // the fill loop below so no bin is skipped or double-counted.
        int cumTotal = 0;
        int cumCorrect = 0;
        int grandTotal = merged.entropyCorrectCount
                + merged.entropyIncorrectCount;
        int threshIdx = 0;
        int bin = 0;
        for (; bin <= ENTROPY_BINS
                && threshIdx < ENTROPY_THRESHOLDS.length; bin++) {
            cumTotal += merged.entropyHist[bin][0];
            cumCorrect += merged.entropyHist[bin][1];
            float binEnd = (bin + 1) * ENTROPY_BIN_WIDTH;
            while (threshIdx < ENTROPY_THRESHOLDS.length
                    && ENTROPY_THRESHOLDS[threshIdx] <= binEnd) {
                double acc = cumTotal > 0
                        ? 100.0 * cumCorrect / cumTotal : 0;
                double rejPct = grandTotal > 0
                        ? 100.0 * (grandTotal - cumTotal)
                        / grandTotal : 0;
                sb.append(String.format(Locale.US,
                        " <= %-9.1f %9.2f%% %,10d %9.1f%%%n",
                        ENTROPY_THRESHOLDS[threshIdx],
                        acc, cumTotal, rejPct));
                threshIdx++;
            }
        }
        // Accumulate any bins left over so thresholds beyond the
        // histogram range report against the full totals.
        for (; bin <= ENTROPY_BINS; bin++) {
            cumTotal += merged.entropyHist[bin][0];
            cumCorrect += merged.entropyHist[bin][1];
        }
while (threshIdx < ENTROPY_THRESHOLDS.length) {
double acc = cumTotal > 0
? 100.0 * cumCorrect / cumTotal : 0;
double rejPct = grandTotal > 0
? 100.0 * (grandTotal - cumTotal)
/ grandTotal : 0;
sb.append(String.format(Locale.US,
" <= %-9.1f %9.2f%% %,10d %9.1f%%%n",
ENTROPY_THRESHOLDS[threshIdx],
acc, cumTotal, rejPct));
threshIdx++;
}
// "no threshold" = accept all
sb.append(String.format(Locale.US,
" %-14s %9.2f%% %,10d %9.1f%%%n",
"(no threshold)",
grandTotal > 0
? 100.0 * (merged.entropyCorrectCount)
/ grandTotal : 0,
grandTotal, 0.0));
sb.append("\n");
return sb.toString();
}
/**
* Evaluate a chunk of data collecting detailed per-prediction stats.
*/
private static DetailedStats detailedChunk(
CharSoupModel model, FeatureExtractor extractor,
List<LabeledSentence> data) {
DetailedStats stats = new DetailedStats();
for (LabeledSentence s : data) {
String cleaned = CharSoupFeatureExtractor.preprocess(
s.getText());
int[] features =
extractor.extractFromPreprocessed(cleaned);
float[] probs = model.predict(features);
float entropy = CharSoupModel.entropy(probs);
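            // argmax over the probability vector gives the predicted class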
int predIdx = 0;
for (int c = 1; c < probs.length; c++) {
if (probs[c] > probs[predIdx]) {
predIdx = c;
}
}
String predicted = model.getLabel(predIdx);
stats.record(s.getLanguage(), predicted, entropy);
}
return stats;
}
// ---- Helpers ----
/**
* Check if a language is a member of any confusable group.
*/
private static boolean isInConfusableGroup(String lang) {
for (String[] group : CompareDetectors.CONFUSABLE_GROUPS) {
for (String member : group) {
if (member.equals(lang)) {
return true;
}
}
}
return false;
}
private static String fmtAcc(
CompareDetectors.EvalResult r, boolean group) {
if (r == null || r.total == 0) {
return "N/A";
}
int num = group ? r.correctGroup : r.correct;
return String.format(Locale.US,
"%.2f%%", 100.0 * num / r.total);
}
private static double throughput(
CompareDetectors.EvalResult r) {
if (r == null || r.elapsedMs <= 0) {
return 0;
}
return r.total / (r.elapsedMs / 1000.0);
}
private static String formatConfusableGroups() {
StringBuilder sb = new StringBuilder();
sb.append(" Groups: ");
for (int i = 0;
i < CompareDetectors.CONFUSABLE_GROUPS.length; i++) {
if (i > 0) {
sb.append(", ");
}
sb.append("{").append(String.join("/",
CompareDetectors.CONFUSABLE_GROUPS[i]))
.append("}");
}
sb.append("\n");
return sb.toString();
}
/** Force GC and return approximate used heap in bytes. */
private static long usedHeap() {
Runtime rt = Runtime.getRuntime();
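        // gc() is only a hint to the JVM; a few calls plus a short
        // pause make a full collection likely, but the reading remains
        // approximate.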
for (int i = 0; i < 3; i++) {
rt.gc();
}
try {
Thread.sleep(100);
} catch (InterruptedException ignored) {
Thread.currentThread().interrupt();
}
return rt.totalMemory() - rt.freeMemory();
}
}