DiagnoseUnknownScript.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;

import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.FeatureExtractor;
/**
 * Prints top-N logits and softmax scores for Flores test sentences
 * belonging to languages that had 0% F1 (e.g. bod, dzo) to assess
 * whether mispredictions are confident or marginal.
 *
 * <p>Usage: {@code DiagnoseUnknownScript <flores_tsv> <model_bin> <lang1> [lang2 ...]}
 *
 * <p>Each line of the Flores file is expected as {@code <langCode>\t<text>};
 * lines without a tab or with empty text are skipped. Each {@code langN}
 * argument is matched as a prefix of the language code, so {@code bod}
 * matches {@code bod_Tibt}. At most {@link #MAX_SENTENCES} sentences are
 * reported per prefix.
 */
public class DiagnoseUnknownScript {

    /** Number of highest-logit labels printed per sentence. */
    private static final int TOP_N = 5;

    /** Number of sentences examined per language prefix. */
    private static final int MAX_SENTENCES = 5;

    /** Rule printed around each language-prefix header (ASCII, encoding-safe). */
    private static final String BANNER_RULE = "=".repeat(60);

    /** Rule printed under the per-sentence table header (ASCII, encoding-safe). */
    private static final String TABLE_RULE = "-".repeat(34);

    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            System.err.println(
                    "Usage: DiagnoseUnknownScript <flores_tsv> <model_bin> <lang1> [lang2 ...]");
            System.exit(1);
        }
        String floresPath = args[0];
        String modelPath = args[1];
        String[] targetLangs = Arrays.copyOfRange(args, 2, args.length);

        CharSoupModel model;
        // Buffered because binary model parsing is typically many small reads.
        try (DataInputStream dis = new DataInputStream(
                new BufferedInputStream(new FileInputStream(modelPath)))) {
            model = CharSoupModel.load(dis);
        }
        String[] labels = model.getLabels();
        FeatureExtractor ext = model.createExtractor();
        int[] buf = new int[model.getNumBuckets()];

        // label -> index, used to look up the true language's logit and rank
        Map<String, Integer> labelIndex = new HashMap<>();
        for (int i = 0; i < labels.length; i++) {
            labelIndex.put(labels[i], i);
        }

        List<String> lines = Files.readAllLines(Paths.get(floresPath));
        for (String targetPrefix : targetLangs) {
            printBanner("Language prefix: " + targetPrefix);
            int sentenceNum = 0;
            for (String line : lines) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                String langCode = line.substring(0, tab);
                // prefix match, e.g. "bod" against "bod_Tibt"
                if (!langCode.startsWith(targetPrefix)) {
                    continue;
                }
                String text = line.substring(tab + 1).trim();
                if (text.isEmpty()) {
                    continue;
                }
                if (++sentenceNum > MAX_SENTENCES) {
                    break;
                }
                ext.extractFromPreprocessed(text, buf, true);
                float[] logits = model.predictLogits(buf);
                printSentence(sentenceNum, langCode, text, logits, labels, labelIndex);
            }
            System.out.println();
        }
    }

    /** Prints a title framed above and below by {@link #BANNER_RULE}. */
    private static void printBanner(String title) {
        System.out.println(BANNER_RULE);
        System.out.println(title);
        System.out.println(BANNER_RULE);
    }

    /**
     * Prints the diagnostic report for one sentence.
     *
     * <p>The report contains the max logit, the softmax entropy relative to its
     * maximum, the top-N labels, and (if outside the top N) the rank of the true label.
     *
     * @param sentenceNum 1-based sentence counter within the current prefix
     * @param langCode    language code from the Flores line (e.g. {@code bod_Tibt})
     * @param text        sentence text
     * @param logits      raw model logits, indexed like {@code labels}
     * @param labels      model labels
     * @param labelIndex  label -> index into {@code labels}/{@code logits}
     */
    private static void printSentence(int sentenceNum, String langCode, String text,
                                      float[] logits, String[] labels,
                                      Map<String, Integer> labelIndex) {
        float[] probs = CharSoupModel.softmax(logits.clone());
        int[] ranked = rankByLogitDescending(logits, labels.length);

        float maxLogit = logits[ranked[0]];
        double entropy = 0.0;
        for (float p : probs) {
            if (p > 0f) {
                entropy -= p * Math.log(p);
            }
        }
        double maxEntropy = Math.log(labels.length);

        System.out.printf(Locale.US,
                "%nSentence %d [%s]: %.60s...%n", sentenceNum, langCode, text);
        System.out.printf(Locale.US,
                " max_logit=%.3f entropy=%.3f / %.3f (%.1f%% of max)%n",
                maxLogit, entropy, maxEntropy, 100.0 * entropy / maxEntropy);
        System.out.printf(Locale.US, " %-12s %8s %8s%n", "Label", "Logit", "Softmax%");
        System.out.println(" " + TABLE_RULE);
        // Guard against models with fewer than TOP_N labels.
        int shown = Math.min(TOP_N, ranked.length);
        for (int r = 0; r < shown; r++) {
            int i = ranked[r];
            System.out.printf(Locale.US, " %-12s %8.3f %7.2f%%%n",
                    labels[i], logits[i], probs[i] * 100f);
        }

        // Also report the true language (script suffix stripped) if the model knows it.
        int underscore = langCode.indexOf('_');
        String trueLabel = underscore >= 0 ? langCode.substring(0, underscore) : langCode;
        Integer trueIdx = labelIndex.get(trueLabel);
        if (trueIdx == null) {
            System.out.printf(Locale.US, " [true label '%s' not in model]%n", trueLabel);
            return;
        }
        int trueIndex = trueIdx;
        int rank = rankOf(ranked, trueIndex);
        if (rank > TOP_N) {
            System.out.printf(Locale.US,
                    " ... [true label '%s' is rank %d, "
                            + "logit=%.3f, softmax=%.4f%%]%n",
                    trueLabel, rank, logits[trueIndex], probs[trueIndex] * 100f);
        }
    }

    /**
     * Returns label indices {@code 0..n-1} ordered by descending logit.
     *
     * <p>Ties keep ascending index order, because {@link Arrays#sort(Object[], java.util.Comparator)} is stable.
     */
    private static int[] rankByLogitDescending(float[] logits, int n) {
        Integer[] boxed = new Integer[n];
        for (int i = 0; i < n; i++) {
            boxed[i] = i;
        }
        Arrays.sort(boxed, (a, b) -> Float.compare(logits[b], logits[a]));
        int[] ranked = new int[n];
        for (int i = 0; i < n; i++) {
            ranked[i] = boxed[i];
        }
        return ranked;
    }

    /**
     * Returns the 1-based position of {@code label} in {@code ranked}, or 0 if absent.
     *
     * <p>Compares primitive ints. The previous {@code Integer == Integer} comparison
     * failed for indices outside the boxed-Integer cache (above 127).
     */
    private static int rankOf(int[] ranked, int label) {
        for (int r = 0; r < ranked.length; r++) {
            if (ranked[r] == label) {
                return r + 1;
            }
        }
        return 0;
    }
}