GlmRerankerEval.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import org.apache.tika.langdetect.charsoup.CharSoupDetectorConfig;
import org.apache.tika.langdetect.charsoup.CharSoupLanguageDetector;
import org.apache.tika.langdetect.charsoup.GenerativeLanguageModel;
import org.apache.tika.language.detect.LanguageResult;
/**
* Evaluates GLM reranking accuracy on top-N discriminative candidates.
* <p>
* For each FLORES-200 dev sentence at each text length, runs the discriminative
* model (STANDARD strategy, no GLM) to get the top-N ranked candidates, then uses
* GLM z-scores to rerank them. Reports the four outcome categories:
* <ul>
* <li><b>DISC_RIGHT / GLM_KEPT</b> ��� disc was right, GLM agreed (no harm done)</li>
* <li><b>DISC_RIGHT / GLM_BROKE</b> ��� disc was right, GLM flipped to wrong answer</li>
* <li><b>DISC_WRONG / GLM_RESCUED</b> ��� disc was wrong, correct answer was in top-N,
* GLM promoted it</li>
* <li><b>DISC_WRONG / GLM_MISSED</b> ��� disc was wrong, correct answer was in top-N,
* GLM failed to promote it</li>
* <li><b>DISC_WRONG / OUT_OF_TOPN</b> ��� disc was wrong and correct answer wasn't in
* top-N; GLM cannot help regardless</li>
* </ul>
* Net lift = RESCUED ��� BROKE (as percentage of total sentences).
* <p>
* Usage:
* <pre>
* GlmRerankerEval [floresDevTsv [topN [lengths]]]
* e.g. GlmRerankerEval ~/datasets/flores-200/flores200_dev.tsv 5 20,50,100,200
* </pre>
*/
public class GlmRerankerEval {
private static final int[] DEFAULT_LENGTHS = {20, 50, 100, 200};
private static final int DEFAULT_TOP_N = 5;
public static void main(String[] args) throws Exception {
String floresPath = args.length > 0
? args[0]
: System.getProperty("user.home") + "/datasets/flores-200/flores200_dev.tsv";
int topN = args.length > 1 ? Integer.parseInt(args[1]) : DEFAULT_TOP_N;
int[] lengths = args.length > 2 ? parseLengths(args[2]) : DEFAULT_LENGTHS;
CharSoupDetectorConfig cfg = CharSoupDetectorConfig.fromMap(
Map.of("strategy", "STANDARD"));
CharSoupLanguageDetector det = new CharSoupLanguageDetector(cfg);
det.loadModels();
GenerativeLanguageModel glm = GenerativeLanguageModel.loadFromClasspath(
GenerativeLanguageModel.DEFAULT_MODEL_RESOURCE);
System.out.println("Loading FLORES-200 dev: " + floresPath);
List<String[]> rows = loadFlores(floresPath); // [lang, text]
System.out.printf(Locale.ROOT, "Loaded %d sentences, topN=%d%n%n", rows.size(), topN);
// Header
System.out.printf(Locale.ROOT, "%-8s %8s %8s %8s %8s %8s %8s %8s %8s %8s %8s %10s%n",
"length", "total", "discAcc%", "glmAcc%", "netLift%",
"kept", "broke", "rescued", "missed", "outTopN",
"adjRate%", "rescued/broke");
System.out.println("-".repeat(120));
for (int len : lengths) {
Counters c = evalAtLength(rows, det, glm, topN, len);
printRow(len, c);
}
System.out.println();
// Per-language breakdown at @20 (most interesting)
System.out.println("=== Per-language breakdown @20 chars ===");
System.out.println("(languages where GLM rescued OR broke >= 3% of sentences)");
System.out.println();
evalPerLang(rows, det, glm, topN, 20);
}
// ---- evaluation ----
private static Counters evalAtLength(List<String[]> rows,
CharSoupLanguageDetector det,
GenerativeLanguageModel glm,
int topN, int len) {
Counters c = new Counters();
for (String[] row : rows) {
String trueLang = row[0];
String text = truncate(row[1], len);
List<LanguageResult> topResults = getTopN(det, text, topN);
if (topResults.isEmpty()) {
c.total++;
c.outTopN++;
continue;
}
String discPick = topResults.get(0).getLanguage();
boolean discRight = trueLang.equals(discPick);
boolean inTopN = topResults.stream()
.anyMatch(r -> trueLang.equals(r.getLanguage()));
String glmPick = rerank(glm, topResults, text);
boolean glmRight = trueLang.equals(glmPick);
c.total++;
if (discRight) {
c.discRight++;
if (glmRight) c.kept++;
else c.broke++;
} else if (inTopN) {
if (glmRight) c.rescued++;
else c.missed++;
} else {
c.outTopN++;
}
}
return c;
}
private static void evalPerLang(List<String[]> rows,
CharSoupLanguageDetector det,
GenerativeLanguageModel glm,
int topN, int len) {
// Per-language counters
Map<String, Counters> perLang = new TreeMap<>();
for (String[] row : rows) {
String trueLang = row[0];
String text = truncate(row[1], len);
List<LanguageResult> topResults = getTopN(det, text, topN);
Counters c = perLang.computeIfAbsent(trueLang, k -> new Counters());
c.total++;
if (topResults.isEmpty()) {
c.outTopN++;
continue;
}
String discPick = topResults.get(0).getLanguage();
boolean discRight = trueLang.equals(discPick);
boolean inTopN = topResults.stream()
.anyMatch(r -> trueLang.equals(r.getLanguage()));
String glmPick = rerank(glm, topResults, text);
boolean glmRight = trueLang.equals(glmPick);
if (discRight) {
c.discRight++;
if (glmRight) {
c.kept++;
} else {
c.broke++;
}
} else if (inTopN) {
if (glmRight) {
c.rescued++;
} else {
c.missed++;
}
} else {
c.outTopN++;
}
}
System.out.printf(Locale.ROOT, "%-12s %8s %8s %8s %8s %8s %8s %8s%n",
"lang", "total", "discAcc%", "glmAcc%", "netLift%",
"rescued", "broke", "outTopN");
System.out.println("-".repeat(90));
perLang.entrySet().stream()
.filter(e -> {
Counters c = e.getValue();
double rescuedPct = 100.0 * c.rescued / c.total;
double brokePct = 100.0 * c.broke / c.total;
return rescuedPct >= 3.0 || brokePct >= 3.0;
})
.sorted((a, b) -> {
double netA = (double)(a.getValue().rescued - a.getValue().broke) / a.getValue().total;
double netB = (double)(b.getValue().rescued - b.getValue().broke) / b.getValue().total;
return Double.compare(netB, netA); // descending by net lift
})
.forEach(e -> {
Counters c = e.getValue();
System.out.printf(Locale.ROOT, "%-12s %8d %8.2f %8.2f %8.2f %8d %8d %8d%n",
e.getKey(), c.total,
100.0 * c.discRight / c.total,
100.0 * (c.discRight - c.broke + c.rescued) / c.total,
100.0 * (c.rescued - c.broke) / c.total,
c.rescued, c.broke, c.outTopN);
});
}
// ---- GLM reranking ----
/**
* Rerank {@code candidates} by GLM z-score and return the top pick.
* Falls back to the discriminative winner if GLM produces no finite scores.
*/
private static String rerank(GenerativeLanguageModel glm,
List<LanguageResult> candidates,
String text) {
String best = candidates.get(0).getLanguage();
float bestZ = Float.NEGATIVE_INFINITY;
for (LanguageResult r : candidates) {
String lang = r.getLanguage();
if (lang.isEmpty()) continue;
float z = glm.zScoreLengthAdjusted(text, lang);
if (!Float.isNaN(z) && z > bestZ) {
bestZ = z;
best = lang;
}
}
return best;
}
// ---- discriminative model ----
private static List<LanguageResult> getTopN(CharSoupLanguageDetector det,
String text, int topN) {
det.reset();
det.addText(text);
List<LanguageResult> all = det.detectAll();
return all.size() <= topN ? all : all.subList(0, topN);
}
// ---- output ----
private static void printRow(int len, Counters c) {
int discAcc = c.discRight;
int glmAcc = c.discRight - c.broke + c.rescued;
int adjudicated = c.rescued + c.missed + c.broke + c.kept - c.discRight; // hmm
// adjudicated = all cases where GLM differed from disc
int glmChanged = c.rescued + c.broke;
double adjRate = 100.0 * glmChanged / c.total;
double rescuedOverBroke = c.broke > 0
? (double) c.rescued / c.broke : Double.POSITIVE_INFINITY;
System.out.printf(Locale.ROOT, "@%-7d %8d %8.2f %8.2f %+8.2f %8d %8d %8d %8d %8d %8.1f %10.2f%n",
len, c.total,
100.0 * discAcc / c.total,
100.0 * glmAcc / c.total,
100.0 * (c.rescued - c.broke) / c.total,
c.kept, c.broke, c.rescued, c.missed, c.outTopN,
adjRate, rescuedOverBroke);
}
// ---- helpers ----
private static String truncate(String text, int len) {
// Truncate at codepoint boundary
if (text.length() <= len) return text;
int cpCount = 0, i = 0;
while (i < text.length() && cpCount < len) {
int cp = text.codePointAt(i);
i += Character.charCount(cp);
cpCount++;
}
return text.substring(0, i);
}
private static int[] parseLengths(String s) {
String[] parts = s.split(",");
int[] arr = new int[parts.length];
for (int i = 0; i < parts.length; i++) arr[i] = Integer.parseInt(parts[i].trim());
return arr;
}
private static List<String[]> loadFlores(String path) throws IOException {
List<String[]> rows = new ArrayList<>();
try (BufferedReader br = new BufferedReader(new FileReader(path, StandardCharsets.UTF_8))) {
String line;
while ((line = br.readLine()) != null) {
String[] parts = line.split("\t", 2);
if (parts.length < 2) continue;
// Normalize FLORES lang codes (strip script suffix, apply remaps)
String lang = EvalGenerativeModel.normalizeLang(parts[0].trim());
rows.add(new String[]{lang, parts[1].trim()});
}
}
return rows;
}
// ---- counters ----
private static class Counters {
int total, discRight, kept, broke, rescued, missed, outTopN;
}
}