GlmNoiseSensitivityReport.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.tika.langdetect.charsoup.tools;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;

import org.apache.tika.langdetect.charsoup.GenerativeLanguageModel;

/**
 * Measures GLM sensitivity to synthetic text corruption at multiple text lengths.
 * <p>
 * For each language and text length, randomly extracts runs of codepoints from the
 * concatenated FLORES-200 corpus (not sentence-initial truncation, which biases toward
 * sentence-opening patterns).  Each window is scored clean and under several noise types:
 * <ul>
 *   <li><b>clean</b>      ��� original run</li>
 *   <li><b>sub10/30/50</b> ��� X% of codepoints replaced with random same-script chars</li>
 *   <li><b>shuffle</b>    ��� codepoints randomly permuted (breaks context, preserves distribution)</li>
 *   <li><b>reversed</b>   ��� codepoint order reversed (directionality; PDFs store RTL visually)</li>
 *   <li><b>wrong-lang</b> ��� replaced with a run from a different language</li>
 *   <li><b>spc-ins</b>    ��� spaces injected randomly (PDF over-segmentation)</li>
 *   <li><b>spc-rem</b>    ��� existing spaces removed (PDF under-segmentation)</li>
 * </ul>
 * Output: one table per text length.  The "sep-*" columns are (clean ��� noised), showing
 * how much each noise type degrades the z-score ��� higher separation = better sensitivity.
 * <p>
 * Usage:
 * <pre>
 *   GlmNoiseSensitivityReport [floresDevTsv [samplesPerLang [seed [summaryTsv]]]]
 * </pre>
 */
public class GlmNoiseSensitivityReport {

    /** Text lengths (in codepoints) to evaluate. */
    private static final int[] LENGTHS = {20, 50, 100, 200};

    /** How many random windows to sample per language per length. */
    private static final int DEFAULT_SAMPLES = 500;

    /** Noise levels for the random-substitution sweep. */
    private static final double[] SUBST_RATES = {0.10, 0.30, 0.50};

    /** Rate at which spaces are inserted after non-space codepoints. */
    private static final double SPACE_INSERT_RATE = 0.20;

    /** Rate at which existing space codepoints are dropped. */
    private static final double SPACE_REMOVE_RATE = 0.80;

    // Column indices (must match buildNoiseLabels order)
    private static final Charset WIN1252 = Charset.forName("windows-1252");
    private static final Charset WIN1251 = Charset.forName("windows-1251");

    private static final int COL_CLEAN      = 0;
    private static final int COL_SUB10      = 1;
    private static final int COL_SUB30      = 2;
    private static final int COL_SUB50      = 3;
    private static final int COL_SHUFFLE    = 4;
    private static final int COL_REVERSED   = 5;
    private static final int COL_WRONGLANG  = 6;
    private static final int COL_SPC_INS    = 7;
    private static final int COL_SPC_REM    = 8;
    private static final int COL_MBK_LAT1   = 9;
    private static final int COL_MBK_WIN1252 = 10;
    private static final int COL_MBK_WIN1251 = 11;
    private static final int NUM_COLS        = 12;

    public static void main(String[] args) throws Exception {
        if (args.length < 2) {
            System.err.println("Usage: GlmNoiseSensitivityReport <modelFile> <floresDevTsv>"
                    + " [samplesPerLang] [seed] [summaryTsv]");
            System.err.println("  modelFile    -- path to GLM binary (required, no classpath default)");
            System.err.println("  floresDevTsv -- path to flores200_dev.tsv (required)");
            System.exit(1);
        }
        String modelPath = args[0];
        String floresPath = args[1];
        int samplesPerLang = args.length > 2 ? Integer.parseInt(args[2]) : DEFAULT_SAMPLES;
        long seed = args.length > 3 ? Long.parseLong(args[3]) : 42L;
        String tsvPath = args.length > 4 ? args[4] : null;

        System.out.println("Loading GLM model: " + modelPath);
        GenerativeLanguageModel glm;
        try (java.io.InputStream is = new java.io.FileInputStream(modelPath)) {
            glm = GenerativeLanguageModel.load(is);
        }

        System.out.println("Loading FLORES-200 dev: " + floresPath);
        Map<String, List<String>> byLang = loadFloresByLang(floresPath);
        System.out.printf(Locale.ROOT, "Loaded %d languages, %d total sentences%n",
                byLang.size(),
                byLang.values().stream().mapToInt(List::size).sum());
        System.out.printf(Locale.ROOT, "Samples per language per length: %d  seed: %d%n%n",
                samplesPerLang, seed);

        // Build per-language codepoint pools (all sentences concatenated with a space)
        Map<String, int[]> langPools = new LinkedHashMap<>();
        for (Map.Entry<String, List<String>> e : byLang.entrySet()) {
            langPools.put(e.getKey(), buildPool(e.getValue()));
        }

        // Build a flat wrong-lang pool per language: concatenation of all OTHER languages
        // (lazily ��� we pick from a random other-lang pool at sample time)
        List<String> allLangCodes = new ArrayList<>(langPools.keySet());

        String[] noiseLabels = buildNoiseLabels();

        // Accumulate grand means per length for TSV output
        double[][] tsvMeans = new double[LENGTHS.length][];

        for (int li = 0; li < LENGTHS.length; li++) {
            int targetLen = LENGTHS[li];
            System.out.printf(Locale.ROOT, "=== Length @%d codepoints ===%n%n", targetLen);

            printHeader(noiseLabels);

            double[] grandSum = new double[NUM_COLS];
            int[]    grandN   = new int[NUM_COLS];
            int      langCount = 0;

            for (String lang : glm.getLanguages()) {
                int[] pool = langPools.get(lang);
                if (pool == null || pool.length < targetLen) {
                    continue;
                }

                // Pick a different language's pool for wrong-lang substitution
                int[] wrongPool = pickWrongLangPool(lang, allLangCodes, langPools,
                        new Random(seed), targetLen);

                double[] sums = new double[NUM_COLS];
                int[]    ns   = new int[NUM_COLS];
                Random rng = new Random(seed ^ lang.hashCode());

                for (int s = 0; s < samplesPerLang; s++) {
                    // Randomly extract a window of targetLen codepoints from the pool
                    int[] window = randomWindow(pool, targetLen, rng);

                    // clean
                    addScore(glm, lang, fromCodepoints(window), sums, ns, COL_CLEAN);

                    // random substitutions
                    addScore(glm, lang, substituteRandom(window, 0.10, pool, rng), sums, ns, COL_SUB10);
                    addScore(glm, lang, substituteRandom(window, 0.30, pool, rng), sums, ns, COL_SUB30);
                    addScore(glm, lang, substituteRandom(window, 0.50, pool, rng), sums, ns, COL_SUB50);

                    // shuffle
                    addScore(glm, lang, shuffleCodepoints(window, rng), sums, ns, COL_SHUFFLE);

                    // reversed
                    addScore(glm, lang, reverseCodepoints(window), sums, ns, COL_REVERSED);

                    // wrong-lang: random window from a different language's pool
                    if (wrongPool != null) {
                        int[] wrongWindow = randomWindow(wrongPool, targetLen, rng);
                        addScore(glm, lang, fromCodepoints(wrongWindow), sums, ns, COL_WRONGLANG);
                    }

                    // space insertion
                    addScore(glm, lang, insertSpaces(window, rng), sums, ns, COL_SPC_INS);

                    // space removal
                    addScore(glm, lang, removeSpaces(window, rng), sums, ns, COL_SPC_REM);

                    // mojibake: UTF-8 bytes misread as Latin-1 / Win-1252 / Win-1251
                    // Only score if the mis-decoded string actually differs (ASCII-only windows
                    // are unchanged and should not inflate the mean).
                    String clean = fromCodepoints(window);
                    addMojibakeScore(glm, lang, clean, StandardCharsets.ISO_8859_1, sums, ns, COL_MBK_LAT1);
                    addMojibakeScore(glm, lang, clean, WIN1252,                     sums, ns, COL_MBK_WIN1252);
                    addMojibakeScore(glm, lang, clean, WIN1251,                     sums, ns, COL_MBK_WIN1251);
                }

                double[] means = computeMeans(sums, ns);
                printLangRow(lang, samplesPerLang, means);

                for (int i = 0; i < NUM_COLS; i++) {
                    grandSum[i] += sums[i];
                    grandN[i]   += ns[i];
                }
                langCount++;
            }

            double[] grandMeans = computeMeans(grandSum, grandN);
            printSeparator(noiseLabels);
            printLangRow("MEAN(" + langCount + ")", -1, grandMeans);
            System.out.println();
            tsvMeans[li] = grandMeans;
        }

        System.out.println("Columns: " + Arrays.toString(noiseLabels));
        System.out.println("sep-sub10  = clean ��� sub10 z-score  (char noise sensitivity)");
        System.out.println("sep-rev    = clean ��� reversed       (directionality)");
        System.out.println("sep-spc+   = clean ��� spc-ins        (over-segmentation sensitivity)");
        System.out.println("sep-spc-   = clean ��� spc-rem        (under-segmentation sensitivity)");
        System.out.println("mbk-lat1   = UTF-8 misread as ISO-8859-1 (only non-ASCII windows counted)");
        System.out.println("mbk-1252   = UTF-8 misread as Windows-1252");
        System.out.println("mbk-1251   = UTF-8 misread as Windows-1251 (Cyrillic)");

        if (tsvPath != null) {
            writeSummaryTsv(tsvPath, noiseLabels, tsvMeans);
            System.out.println("\nSummary TSV written to: " + tsvPath);
        }
    }

    // ---- sampling ----

    /** Extract a random window of {@code len} codepoints from {@code pool}. */
    private static int[] randomWindow(int[] pool, int len, Random rng) {
        int start = rng.nextInt(pool.length - len + 1);
        return Arrays.copyOfRange(pool, start, start + len);
    }

    private static int[] pickWrongLangPool(String lang, List<String> allCodes,
                                            Map<String, int[]> pools,
                                            Random rng, int minLen) {
        // Try up to 20 random other languages with a large enough pool
        for (int attempt = 0; attempt < 20; attempt++) {
            String other = allCodes.get(rng.nextInt(allCodes.size()));
            if (!other.equals(lang)) {
                int[] p = pools.get(other);
                if (p != null && p.length >= minLen) return p;
            }
        }
        return null;
    }

    // ---- noise functions ----

    private static String substituteRandom(int[] cps, double rate, int[] pool, Random rng) {
        int[] out = cps.clone();
        for (int i = 0; i < out.length; i++) {
            if (rng.nextDouble() < rate && pool.length > 0) {
                out[i] = pool[rng.nextInt(pool.length)];
            }
        }
        return fromCodepoints(out);
    }

    private static String shuffleCodepoints(int[] cps, Random rng) {
        int[] out = cps.clone();
        for (int i = out.length - 1; i > 0; i--) {
            int j = rng.nextInt(i + 1);
            int tmp = out[i];
            out[i] = out[j];
            out[j] = tmp;
        }
        return fromCodepoints(out);
    }

    private static String reverseCodepoints(int[] cps) {
        int[] out = new int[cps.length];
        for (int i = 0; i < cps.length; i++) out[i] = cps[cps.length - 1 - i];
        return fromCodepoints(out);
    }

    private static String insertSpaces(int[] cps, Random rng) {
        StringBuilder sb = new StringBuilder(cps.length * 2);
        for (int cp : cps) {
            sb.appendCodePoint(cp);
            if (cp != ' ' && rng.nextDouble() < SPACE_INSERT_RATE) sb.append(' ');
        }
        return sb.toString();
    }

    private static String removeSpaces(int[] cps, Random rng) {
        StringBuilder sb = new StringBuilder(cps.length);
        for (int cp : cps) {
            if (cp == ' ' && rng.nextDouble() < SPACE_REMOVE_RATE) continue;
            sb.appendCodePoint(cp);
        }
        return sb.toString();
    }

    // ---- output ----

    private static String[] buildNoiseLabels() {
        List<String> labels = new ArrayList<>();
        labels.add("clean");
        for (double r : SUBST_RATES) labels.add(String.format(Locale.ROOT, "sub%d%%", (int)(r * 100)));
        labels.add("shuffle");
        labels.add("reversed");
        labels.add("wrng-lng");
        labels.add("spc-ins");
        labels.add("spc-rem");
        labels.add("mbk-lat1");
        labels.add("mbk-1252");
        labels.add("mbk-1251");
        return labels.toArray(new String[0]);
    }

    private static void printHeader(String[] noiseLabels) {
        System.out.printf(Locale.ROOT, "%-14s", "lang");
        for (String label : noiseLabels) {
            System.out.printf(Locale.ROOT, "  %8s", label);
        }
        System.out.printf(Locale.ROOT, "  %8s  %8s  %8s  %8s%n", "sep-sub", "sep-rev", "sep-spc+", "sep-spc-");
        printSeparator(noiseLabels);
    }

    private static void printSeparator(String[] noiseLabels) {
        System.out.println("-".repeat(14 + noiseLabels.length * 10 + 38));
    }

    private static void printLangRow(String lang, int n, double[] means) {
        System.out.printf(Locale.ROOT, "%-14s", n >= 0 ? String.format(Locale.ROOT, "%s(%d)", lang, n) : lang);
        for (double m : means) {
            System.out.printf(Locale.ROOT, "  %8.3f", m);
        }
        double sepSub = means[COL_CLEAN] - means[COL_SUB10];
        double sepRev = means[COL_CLEAN] - means[COL_REVERSED];
        double sepSpcIns = means[COL_CLEAN] - means[COL_SPC_INS];
        double sepSpcRem = means[COL_CLEAN] - means[COL_SPC_REM];
        System.out.printf(Locale.ROOT, "  %8.3f  %8.3f  %8.3f  %8.3f%n", sepSub, sepRev, sepSpcIns, sepSpcRem);
    }

    // ---- TSV output ----

    /**
     * Write a compact TSV of grand-mean z-scores per length.
     * One header row, then one data row per length.
     * Separation columns (clean ��� noised) are appended for the key noise types.
     * <p>
     * Compare two runs with e.g.:
     * <pre>
     *   python3 -c "
     *   import csv, sys
     *   a = {r['length']: r for r in csv.DictReader(open(sys.argv[1]), delimiter='\t')}
     *   b = {r['length']: r for r in csv.DictReader(open(sys.argv[2]), delimiter='\t')}
     *   for ln in a:
     *       for k in a[ln]:
     *           if k != 'length':
     *               delta = float(b[ln][k]) - float(a[ln][k])
     *               if abs(delta) > 0.01: print(f'{ln}\t{k}\t{delta:+.3f}')
     *   " baseline.tsv new.tsv
     * </pre>
     */
    private static void writeSummaryTsv(String path, String[] noiseLabels,
                                         double[][] means) throws IOException {
        try (PrintWriter pw = new PrintWriter(path, StandardCharsets.UTF_8)) {
            // Header
            pw.print("length");
            for (String label : noiseLabels) {
                pw.print('\t');
                pw.print(label);
            }
            pw.print("\tsep-sub10\tsep-rev\tsep-spc+\tsep-spc-");
            pw.println();

            // One row per length
            for (int i = 0; i < LENGTHS.length; i++) {
                pw.print(LENGTHS[i]);
                double[] m = means[i];
                for (double v : m) {
                    pw.print('\t');
                    if (Double.isNaN(v)) pw.print("NaN");
                    else pw.printf(Locale.ROOT, "%.4f", v);
                }
                // Separation columns
                pw.printf(Locale.ROOT, "\t%.4f", m[COL_CLEAN] - m[COL_SUB10]);
                pw.printf(Locale.ROOT, "\t%.4f", m[COL_CLEAN] - m[COL_REVERSED]);
                pw.printf(Locale.ROOT, "\t%.4f", m[COL_CLEAN] - m[COL_SPC_INS]);
                pw.printf(Locale.ROOT, "\t%.4f", m[COL_CLEAN] - m[COL_SPC_REM]);
                pw.println();
            }
        }
    }

    // ---- helpers ----

    private static void addScore(GenerativeLanguageModel glm, String lang,
                                  String text, double[] sums, int[] ns, int col) {
        float z = glm.zScoreLengthAdjusted(text, lang);
        if (!Float.isNaN(z)) {
            sums[col] += z;
            ns[col]++;
        }
    }

    /**
     * Simulate charset misidentification: encode {@code clean} as UTF-8, then
     * decode those bytes as {@code wrongCharset}.  Only scores the result if it
     * actually differs from the original ��� pure-ASCII windows are unaffected by
     * this type of error and should not count toward the mean.
     */
    private static void addMojibakeScore(GenerativeLanguageModel glm, String lang,
                                          String clean, Charset wrongCharset,
                                          double[] sums, int[] ns, int col) {
        String mojibake = new String(clean.getBytes(StandardCharsets.UTF_8), wrongCharset);
        if (mojibake.equals(clean)) {
            return; // no non-ASCII content affected ��� skip this sample
        }
        float z = glm.zScoreLengthAdjusted(mojibake, lang);
        if (!Float.isNaN(z)) {
            sums[col] += z;
            ns[col]++;
        }
    }

    private static double[] computeMeans(double[] sums, int[] ns) {
        double[] means = new double[sums.length];
        for (int i = 0; i < sums.length; i++) {
            means[i] = ns[i] > 0 ? sums[i] / ns[i] : Double.NaN;
        }
        return means;
    }

    /**
     * Concatenate all sentences for a language into a single codepoint array,
     * joining with a space so sentence boundaries are natural.
     */
    private static int[] buildPool(List<String> sentences) {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < sentences.size(); i++) {
            if (i > 0) sb.append(' ');
            sb.append(sentences.get(i));
        }
        return sb.toString().codePoints().toArray();
    }

    private static int[] toCodepoints(String s) {
        return s.codePoints().toArray();
    }

    private static String fromCodepoints(int[] cps) {
        StringBuilder sb = new StringBuilder(cps.length);
        for (int cp : cps) sb.appendCodePoint(cp);
        return sb.toString();
    }

    private static Map<String, List<String>> loadFloresByLang(String path) throws IOException {
        Map<String, List<String>> map = new LinkedHashMap<>();
        try (BufferedReader br = new BufferedReader(new FileReader(path, StandardCharsets.UTF_8))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] parts = line.split("\t", 2);
                if (parts.length < 2) continue;
                String lang = FloresNorm.normalize(parts[0]);
                map.computeIfAbsent(lang, k -> new ArrayList<>()).add(parts[1]);
            }
        }
        return map;
    }
}