ConfusionDumper.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedInputStream;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;

import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.FeatureExtractor;
/**
 * Dumps a confusion matrix for a specified set of languages.
 * For each target language, shows what the model predicted instead
 * (top-N confusions by count), plus accuracy.
 * <p>
 * Usage: {@code ConfusionDumper <testSplitFile> <modelFile> [lang1,lang2,...]}
 * <p>
 * If no languages are specified, a default set of known weak languages is used.
 * Languages are reported in the order they are given; surrounding whitespace,
 * blank entries and duplicates in the comma-separated list are ignored.
 */
public class ConfusionDumper {

    private static final String[] DEFAULT_LANGUAGES = {
            "sqi", "tat", "ita", "sun", "mad", "pus", "mkd",
            "ban", "sme", "spa"
    };

    /** Show top-N confusions per language. */
    private static final int TOP_N = 10;

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println(
                    "Usage: ConfusionDumper <testSplitFile> <modelFile> [lang1,lang2,...]");
            System.exit(1);
        }
        Path testFile = Paths.get(args[0]);
        Path modelFile = Paths.get(args[1]);
        // Ordered set: report order follows the command line (or DEFAULT_LANGUAGES).
        Set<String> targetLangs = parseTargetLanguages(args);
        if (targetLangs.isEmpty()) {
            System.err.println("No target languages specified");
            System.exit(1);
        }

        // Load model
        System.out.println("Loading model: " + modelFile);
        CharSoupModel model;
        try (InputStream is = new BufferedInputStream(Files.newInputStream(modelFile))) {
            model = CharSoupModel.load(is);
        }
        FeatureExtractor extractor = model.createExtractor();
        System.out.printf(Locale.US, " %d classes, %d buckets%n",
                model.getNumClasses(), model.getNumBuckets());

        // Load test data
        System.out.println("Loading test data: " + testFile);
        List<LabeledSentence> data = TrainLanguageModel.readPreprocessedFile(testFile);
        System.out.printf(Locale.US, " %,d sentences%n%n", data.size());

        // For each target language: map of predicted_label -> error count.
        // TreeMap gives alphabetical order among equal counts when ranking later.
        Map<String, Map<String, Integer>> confusions = new LinkedHashMap<>();
        Map<String, Integer> totals = new LinkedHashMap<>();
        Map<String, Integer> corrects = new LinkedHashMap<>();
        for (String lang : targetLangs) {
            confusions.put(lang, new TreeMap<>());
            totals.put(lang, 0);
            corrects.put(lang, 0);
        }

        // Evaluate only sentences whose true label is a target language.
        for (LabeledSentence s : data) {
            String truth = s.getLanguage();
            if (!targetLangs.contains(truth)) {
                continue;
            }
            int[] features = extractor.extract(s.getText());
            float[] probs = model.predict(features);
            String predLabel = model.getLabel(argmax(probs));
            totals.merge(truth, 1, Integer::sum);
            if (predLabel.equals(truth)) {
                corrects.merge(truth, 1, Integer::sum);
            } else {
                confusions.get(truth).merge(predLabel, 1, Integer::sum);
            }
        }

        // Print results
        for (String lang : targetLangs) {
            printLanguageReport(lang,
                    totals.getOrDefault(lang, 0),
                    corrects.getOrDefault(lang, 0),
                    confusions.get(lang));
        }
    }

    /**
     * Returns the languages to report on, in the order requested: the
     * comma-separated third argument if present, otherwise
     * {@link #DEFAULT_LANGUAGES}. Entries are trimmed; blank entries and
     * duplicates are dropped.
     *
     * @param args the raw command-line arguments
     * @return an insertion-ordered set of language labels (possibly empty)
     */
    private static Set<String> parseTargetLanguages(String[] args) {
        String[] requested = args.length >= 3 ? args[2].split(",") : DEFAULT_LANGUAGES;
        Set<String> langs = new LinkedHashSet<>();
        for (String lang : requested) {
            String trimmed = lang.trim();
            if (!trimmed.isEmpty()) {
                langs.add(trimmed);
            }
        }
        return langs;
    }

    /**
     * Prints the accuracy line for one language, followed by its (at most
     * {@link #TOP_N}) most frequent mispredictions and a summary of the rest.
     *
     * @param lang     the true language label
     * @param total    number of test sentences labeled {@code lang}
     * @param correct  how many of those were predicted as {@code lang}
     * @param confused predicted label -> count, for incorrect predictions only
     */
    private static void printLanguageReport(String lang, int total, int correct,
                                            Map<String, Integer> confused) {
        if (total == 0) {
            System.out.printf(Locale.US, "%s: no test samples found%n%n", lang);
            return;
        }
        System.out.printf(Locale.US, "%s: %,d/%,d correct (%.1f%%)%n",
                lang, correct, total, 100.0 * correct / total);
        if (confused.isEmpty()) {
            // %n%n rather than "\n" + println, so separators are consistent.
            System.out.printf(Locale.US, " (no confusions)%n%n");
            return;
        }

        // confused is non-empty here, so errors > 0 and the division is safe.
        int errors = total - correct;
        // Rank once by count descending. List.sort is stable, so equal counts
        // keep the TreeMap's alphabetical order.
        List<Map.Entry<String, Integer>> ranked = new ArrayList<>(confused.entrySet());
        ranked.sort(Map.Entry.<String, Integer>comparingByValue().reversed());

        int shown = Math.min(TOP_N, ranked.size());
        for (Map.Entry<String, Integer> e : ranked.subList(0, shown)) {
            System.out.printf(Locale.US,
                    " -> %-12s %,6d (%5.1f%% of errors, %5.1f%% of total)%n",
                    e.getKey(), e.getValue(),
                    100.0 * e.getValue() / errors,
                    100.0 * e.getValue() / total);
        }

        // Show how many other distinct confusions there are
        if (ranked.size() > shown) {
            int otherCount = 0;
            for (Map.Entry<String, Integer> e : ranked.subList(shown, ranked.size())) {
                otherCount += e.getValue();
            }
            System.out.printf(Locale.US,
                    " -> (+ %d other languages, %,d errors)%n",
                    ranked.size() - shown, otherCount);
        }
        System.out.println();
    }

    /**
     * Returns the index of the largest value in {@code arr}. On ties the
     * first such index wins; an empty array yields 0.
     */
    private static int argmax(float[] arr) {
        int best = 0;
        for (int i = 1; i < arr.length; i++) {
            if (arr[i] > arr[best]) {
                best = i;
            }
        }
        return best;
    }
}