LengthCalibrationReport.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.LinkedHashMap;
import java.util.Locale;
import java.util.Map;
import org.apache.tika.langdetect.charsoup.GenerativeLanguageModel;
/**
* Measures how score mean and stddev vary with text length for selected
* languages. Used to decide whether z-scores need length normalization
* at runtime.
*
* <p>For each language, truncates training sentences to various character
* lengths, scores them, and reports per-bucket (��, ��, n). If �� follows
* 1/���(charLen), a simple correction factor suffices at runtime.
*
* <h3>Usage</h3>
* <pre>
* java LengthCalibrationReport \
* --model generative.bin \
* --corpus /path/to/pool_filtered \
* --langs eng,fra,zho,jpn,ara,kor \
* [--max-per-lang 50000]
* </pre>
*/
public class LengthCalibrationReport {
private static final int DEFAULT_MAX = 50_000;
private static final int[] CHAR_LENGTHS = {10, 20, 30, 50, 75, 100, 150, 200, 500, 99999};
public static void main(String[] args) throws Exception {
Path modelPath = null;
Path corpusPath = null;
String langsArg = "eng,fra,zho,jpn,ara";
int max = DEFAULT_MAX;
for (int i = 0; i < args.length; i++) {
switch (args[i]) {
case "--model":
modelPath = Paths.get(args[++i]);
break;
case "--corpus":
corpusPath = Paths.get(args[++i]);
break;
case "--langs":
langsArg = args[++i];
break;
case "--max-per-lang":
max = Integer.parseInt(args[++i]);
break;
default:
System.err.println("Unknown option: " + args[i]);
System.exit(1);
}
}
if (modelPath == null || corpusPath == null) {
System.err.println(
"Usage: LengthCalibrationReport --model <bin> --corpus <dir> "
+ "[--langs eng,fra,zho] [--max-per-lang 50000]");
System.exit(1);
}
GenerativeLanguageModel model;
try (InputStream is = new FileInputStream(modelPath.toFile())) {
model = GenerativeLanguageModel.load(is);
}
String[] langs = langsArg.split(",");
for (String lang : langs) {
lang = lang.trim();
if (!model.getLanguages().contains(lang)) {
System.err.println("Skipping unknown language: " + lang);
continue;
}
Path langFile = corpusPath.resolve(lang);
if (!Files.exists(langFile)) {
System.err.println("No corpus file for: " + lang);
continue;
}
System.out.printf(Locale.US, "%n=== %s ===%n", lang);
System.out.printf(Locale.US,
"%-10s %8s %10s %10s %12s %12s%n",
"MaxChars", "N", "��(score)", "��(score)",
"��*���(len/50)", "��(z-full)");
System.out.println("-".repeat(70));
// Read sentences once
String[] sentences = readSentences(langFile, max);
for (int maxLen : CHAR_LENGTHS) {
// Welford's online algorithm
long n = 0;
double mean = 0.0;
double m2 = 0.0;
double zSum = 0.0;
for (String sentence : sentences) {
String text = sentence.length() > maxLen
? sentence.substring(0, maxLen) : sentence;
float score = model.score(text, lang);
if (Float.isNaN(score)) {
continue;
}
n++;
double delta = score - mean;
mean += delta / n;
m2 += delta * (score - mean);
float z = model.zScore(text, lang);
if (!Float.isNaN(z)) {
zSum += z;
}
}
double stdDev = n > 1 ? Math.sqrt(m2 / (n - 1)) : 0.0;
// If �� ~ 1/���len, then ��*���(len/50) should be roughly constant
double normalized = stdDev * Math.sqrt((double) Math.min(maxLen, 200) / 50.0);
double meanZ = n > 0 ? zSum / n : 0.0;
String label = maxLen >= 99999 ? "full" : String.valueOf(maxLen);
System.out.printf(Locale.US,
"%-10s %,8d %10.4f %10.4f %12.4f %12.4f%n",
label, n, mean, stdDev, normalized, meanZ);
}
}
}
private static String[] readSentences(Path file, int max) throws Exception {
Map<Integer, String> lines = new LinkedHashMap<>();
try (BufferedReader reader = Files.newBufferedReader(
file, StandardCharsets.UTF_8)) {
String line;
int idx = 0;
while ((line = reader.readLine()) != null && idx < max) {
String text = line.trim();
if (!text.isEmpty()) {
lines.put(idx++, text);
}
}
}
return lines.values().toArray(new String[0]);
}
}