ZScoreDistributionReport.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.FileInputStream;
import java.io.InputStream;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

import org.apache.tika.langdetect.charsoup.GenerativeLanguageModel;

/**
 * Report what fraction of sentences fall below various z-score thresholds,
 * at different text lengths. Uses {@link GenerativeLanguageModel#zScoreLengthAdjusted}.
 *
 * <p>For every requested length, each test sentence whose language is known to the
 * model is truncated to at most that many UTF-16 chars (without splitting a surrogate
 * pair) and scored with {@code zScoreLengthAdjusted(text, lang)}. Sentences whose score
 * is NaN are skipped. Per language, the number of scored sentences, the mean, the
 * sample standard deviation and the percentage of scores strictly below each
 * threshold are printed to stdout.
 *
 * <h3>Usage</h3>
 * <pre>
 *   java ZScoreDistributionReport \
 *       --model  generative.bin \
 *       --test   flores200_dev.tsv \
 *       [--langs eng,fra,deu,zho,ara,rus,jpn] \
 *       [--lengths 20,50,100,200,full]
 * </pre>
 *
 * <p>A length of {@code full} (or {@code 0}) scores the untruncated sentence.
 */
public class ZScoreDistributionReport {

    /** Default truncation lengths in chars; {@code 0} means "do not truncate". */
    private static final int[] DEFAULT_LENGTHS = {20, 50, 100, 200, 0};

    /** Z-score cut-offs, one report column each. */
    private static final float[] THRESHOLDS = {-1.0f, -1.5f, -2.0f, -2.5f, -3.0f, -4.0f, -5.0f};

    /**
     * Command-line entry point; see the class documentation for the options.
     * Exits with status 1 on an unknown option, a missing or malformed option
     * value, or when {@code --model} or {@code --test} is absent.
     *
     * @param args command-line arguments
     * @throws Exception if the model or the test file cannot be loaded
     */
    public static void main(String[] args) throws Exception {
        Path modelPath = null;
        Path testPath = null;
        Set<String> filterLangs = null;
        int[] lengths = DEFAULT_LENGTHS;

        for (int i = 0; i < args.length; i++) {
            switch (args[i]) {
                case "--model":
                    modelPath = Paths.get(optionValue(args, ++i));
                    break;
                case "--test":
                    testPath = Paths.get(optionValue(args, ++i));
                    break;
                case "--langs":
                    filterLangs = parseLangs(optionValue(args, ++i));
                    break;
                case "--lengths":
                    lengths = parseLengths(optionValue(args, ++i));
                    break;
                default:
                    System.err.println("Unknown option: " + args[i]);
                    System.exit(1);
            }
        }
        if (modelPath == null || testPath == null) {
            System.err.println("Usage: ZScoreDistributionReport --model <.bin> --test <.tsv>");
            System.exit(1);
        }

        GenerativeLanguageModel model;
        try (InputStream is = new FileInputStream(modelPath.toFile())) {
            model = GenerativeLanguageModel.load(is);
        }

        List<LabeledSentence> data = EvalGenerativeModel.loadTestFile(testPath);
        // Normalize FLORES codes: only done when at least one label contains '_'.
        if (data.stream().anyMatch(s -> s.getLanguage().contains("_"))) {
            List<LabeledSentence> norm = new ArrayList<>(data.size());
            for (LabeledSentence s : data) {
                norm.add(new LabeledSentence(
                        EvalGenerativeModel.normalizeLang(s.getLanguage()), s.getText()));
            }
            data = norm;
        }

        // Group sentence texts by language, dropping languages the model does not list.
        Set<String> modelLangs = new TreeSet<>(model.getLanguages());
        Map<String, List<String>> byLang = new HashMap<>();
        for (LabeledSentence s : data) {
            if (!modelLangs.contains(s.getLanguage())) {
                continue;
            }
            byLang.computeIfAbsent(s.getLanguage(), k -> new ArrayList<>())
                    .add(s.getText());
        }

        // Select languages to report: the --langs filter if given, else all, sorted.
        List<String> reportLangs;
        if (filterLangs != null) {
            reportLangs = new ArrayList<>();
            for (String l : filterLangs) {
                if (byLang.containsKey(l)) {
                    reportLangs.add(l);
                } else {
                    System.err.println("Warning: " + l + " not found in test data");
                }
            }
        } else {
            reportLangs = new ArrayList<>(new TreeSet<>(byLang.keySet()));
        }

        System.out.printf(Locale.US, "Model: %s  (%d languages)%n",
                modelPath, model.getLanguages().size());
        // The sentence count is that of the whole test file, including sentences
        // whose language was dropped during grouping.
        System.out.printf(Locale.US, "Test:  %s  (%d languages, %,d sentences)%n%n",
                testPath, byLang.size(), data.size());

        // Header and rule do not depend on the length, so build them once.
        String header = formatHeader();
        String rule = "-".repeat(header.length());

        for (int len : lengths) {
            String lenLabel = len > 0 ? len + " chars" : "full";
            System.out.printf(Locale.US, "=== Length: %s ===%n%n", lenLabel);
            System.out.println(header);
            System.out.println(rule);

            // Aggregate stats across all reported languages
            int totalN = 0;
            int[] totalBelow = new int[THRESHOLDS.length];

            for (String lang : reportLangs) {
                List<String> sentences = byLang.get(lang);
                if (sentences == null) {
                    continue;
                }
                float[] zScores = new float[sentences.size()];
                int valid = 0;
                for (String sentence : sentences) {
                    float z = model.zScoreLengthAdjusted(truncate(sentence, len), lang);
                    if (!Float.isNaN(z)) {
                        zScores[valid++] = z;
                    }
                }
                if (valid == 0) {
                    continue;
                }
                // Keep only the scored prefix and sort it so countBelow can bisect.
                zScores = Arrays.copyOf(zScores, valid);
                Arrays.sort(zScores);

                int[] below = new int[THRESHOLDS.length];
                for (int ti = 0; ti < THRESHOLDS.length; ti++) {
                    below[ti] = countBelow(zScores, THRESHOLDS[ti]);
                    totalBelow[ti] += below[ti];
                }
                totalN += valid;
                System.out.println(formatRow(lang, zScores, below));
            }

            // Summary row
            if (reportLangs.size() > 1) {
                System.out.println(rule);
                StringBuilder tot = new StringBuilder(
                        String.format(Locale.US, "%-8s  %5d  %6s  %6s", "ALL", totalN, "", ""));
                appendPercentages(tot, totalBelow, totalN);
                System.out.println(tot);
            }
            System.out.println();
        }
    }

    /**
     * Returns the value following an option, or reports the problem and exits
     * with status 1 when the option is the last argument.
     *
     * @param args       the full argument array
     * @param valueIndex index at which the option's value is expected
     * @return {@code args[valueIndex]}
     */
    private static String optionValue(String[] args, int valueIndex) {
        if (valueIndex >= args.length) {
            System.err.println("Missing value for option: " + args[valueIndex - 1]);
            System.exit(1);
        }
        return args[valueIndex];
    }

    /**
     * Parses a comma-separated language list into a sorted set. Tokens are
     * trimmed and empty tokens are ignored.
     */
    private static Set<String> parseLangs(String spec) {
        Set<String> langs = new TreeSet<>();
        for (String part : spec.split(",")) {
            String lang = part.trim();
            if (!lang.isEmpty()) {
                langs.add(lang);
            }
        }
        return langs;
    }

    /**
     * Parses a comma-separated list of truncation lengths. {@code full}
     * (case-insensitive) maps to {@code 0}. A token that is not {@code full}
     * or a non-negative integer is reported and the program exits with status 1.
     */
    private static int[] parseLengths(String spec) {
        String[] parts = spec.split(",");
        int[] parsed = new int[parts.length];
        for (int j = 0; j < parts.length; j++) {
            String p = parts[j].trim();
            if (p.equalsIgnoreCase("full")) {
                parsed[j] = 0;
                continue;
            }
            int value;
            try {
                value = Integer.parseInt(p);
            } catch (NumberFormatException e) {
                value = -1; // reported below together with negative values
            }
            if (value < 0) {
                System.err.println(
                        "Invalid length (expected a non-negative integer or 'full'): " + p);
                System.exit(1);
            }
            parsed[j] = value;
        }
        return parsed;
    }

    /**
     * Truncates {@code text} to at most {@code maxChars} UTF-16 chars. When the cut
     * would fall between the two halves of a surrogate pair, the cut is moved one
     * char earlier so that no lone surrogate is produced.
     *
     * @param text     the sentence to truncate
     * @param maxChars maximum length in chars; values {@code <= 0} disable truncation
     * @return the possibly shortened text
     */
    private static String truncate(String text, int maxChars) {
        if (maxChars <= 0 || text.length() <= maxChars) {
            return text;
        }
        int end = maxChars;
        if (Character.isHighSurrogate(text.charAt(end - 1))
                && Character.isLowSurrogate(text.charAt(end))) {
            end--;
        }
        return text.substring(0, end);
    }

    /** Builds the column header line: fixed columns followed by one column per threshold. */
    private static String formatHeader() {
        StringBuilder hdr = new StringBuilder(
                String.format(Locale.US, "%-8s  %5s  %6s  %6s", "Lang", "N", "mean-z", "std-z"));
        for (float t : THRESHOLDS) {
            hdr.append(String.format(Locale.US, "  z<%.1f", t));
        }
        return hdr.toString();
    }

    /**
     * Builds one report row for a language.
     *
     * @param lang    language label for the first column
     * @param zScores the language's non-NaN scores; must be non-empty
     * @param below   per-threshold counts of scores below the threshold
     * @return the formatted row
     */
    private static String formatRow(String lang, float[] zScores, int[] below) {
        int n = zScores.length;
        double sum = 0;
        for (float z : zScores) {
            sum += z;
        }
        double mean = sum / n;

        // Two-pass sum of squared deviations: unlike sum2 - sum*sum/n it cannot
        // become negative through cancellation, so the sqrt never yields NaN.
        double squaredDev = 0;
        for (float z : zScores) {
            double d = z - mean;
            squaredDev += d * d;
        }
        double std = n > 1 ? Math.sqrt(squaredDev / (n - 1)) : 0;

        StringBuilder row = new StringBuilder(
                String.format(Locale.US, "%-8s  %5d  %+6.2f  %6.3f", lang, n, mean, std));
        appendPercentages(row, below, n);
        return row.toString();
    }

    /**
     * Appends one percentage column per entry of {@code counts}, each computed
     * relative to {@code total}. A {@code total} of zero yields 0.0% rather than NaN.
     */
    private static void appendPercentages(StringBuilder sb, int[] counts, int total) {
        for (int count : counts) {
            double pct = total > 0 ? 100.0 * count / total : 0.0;
            sb.append(String.format(Locale.US, "  %5.1f%%", pct));
        }
    }

    /**
     * Counts the elements strictly less than {@code threshold}, matching the
     * {@code z<...} column label. Implemented as a lower-bound binary search so
     * that values equal to the threshold are not counted.
     *
     * @param sorted    scores in ascending order, without NaN
     * @param threshold the cut-off
     * @return number of elements {@code < threshold}
     */
    private static int countBelow(float[] sorted, float threshold) {
        int lo = 0;
        int hi = sorted.length;
        while (lo < hi) {
            int mid = (lo + hi) >>> 1;
            if (sorted[mid] < threshold) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }
}