// TrigramAblation.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedReader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Set;
/**
 * Quick ablation: bigrams+trigrams at 8k/16k/32k buckets and
 * bigrams+trigrams+wordUnigrams for comparison.
 *
 * <p>For every (configuration, bucket size) pair a {@link Phase2Trainer} is
 * trained on the training file, scored with macro-F1 on the dev and test
 * samples, and one result row is printed to stdout.
 *
 * <p>Usage: {@code TrigramAblation <prepDir>}, where {@code prepDir} contains
 * {@code dev.txt}, {@code test.txt} and {@code train_5m.txt} (or, if that file
 * is absent, {@code train.txt}).
 */
public class TrigramAblation {

    /** Bucket counts passed to the {@link Phase2Trainer} constructor; one run per size. */
    static final int[] BUCKET_SIZES = {8192, 16384, 32768};

    /**
     * Configurations to compare. Flag order: trigrams, skipgrams,
     * wordUnigrams, cjkUnigrams.
     */
    private static final Config[] CONFIGS = {
        new Config("+trigrams", true, false, false, false),
        new Config("+trigrams+wordUnigrams", true, false, true, false),
    };

    /** A display name plus the feature flags shown in the result table. */
    private static final class Config {
        final String name;
        final boolean trigrams;
        final boolean skipgrams;
        final boolean wordUnigrams;
        final boolean cjkUnigrams;

        Config(String name, boolean trigrams, boolean skipgrams,
               boolean wordUnigrams, boolean cjkUnigrams) {
            this.name = name;
            this.trigrams = trigrams;
            this.skipgrams = skipgrams;
            this.wordUnigrams = wordUnigrams;
            this.cjkUnigrams = cjkUnigrams;
        }
    }

    /**
     * Entry point: loads the dev/test samples, then trains and reports every
     * configuration at every bucket size.
     *
     * @param args {@code args[0]} is the directory holding the prepared data files
     * @throws Exception if reading the data files, training or evaluation fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println("Usage: TrigramAblation <prepDir>");
            System.exit(1);
        }

        Path prepDir = Paths.get(args[0]);
        // Prefer train_5m.txt when it exists; otherwise use train.txt.
        Path trainFile = prepDir.resolve("train_5m.txt");
        if (!Files.exists(trainFile)) {
            trainFile = prepDir.resolve("train.txt");
        }
        int threads = Runtime.getRuntime().availableProcessors();
        System.out.println("Train file: " + trainFile);
        System.out.printf(Locale.US, "Threads: %d%n", threads);

        System.out.println("Loading dev + test...");
        long loadStart = System.nanoTime();
        List<LabeledSentence> dev = readReservoir(prepDir.resolve("dev.txt"), 100_000);
        List<LabeledSentence> test = readReservoir(prepDir.resolve("test.txt"), 200_000);
        System.out.printf(Locale.US,
                "Loaded: dev=%,d (%d langs) test=%,d (%d langs) [%.1f s]%n%n",
                dev.size(), countLangs(dev),
                test.size(), countLangs(test),
                elapsed(loadStart));

        // Column layout is intended to match AblationRunner's table. The last
        // metric printed per row is macro-F1 on the test sample, hence "testF1".
        System.out.printf(Locale.US,
                "%-28s %7s %4s %4s %4s %4s "
                        + "%8s %8s %8s %7s%n",
                "config", "buckets",
                "tri", "skip", "word", "cjk",
                "devF1", "langs", "testF1", "time_s");
        System.out.println("-".repeat(105));

        for (Config config : CONFIGS) {
            for (int buckets : BUCKET_SIZES) {
                runOne(config, buckets, threads, trainFile, dev, test);
            }
        }
        System.out.println("\nDone.");
    }

    /**
     * Trains one model and prints its result row.
     *
     * <p>NOTE(review): the feature flags in {@code config} are only echoed in
     * the output row; nothing in this method passes them to
     * {@link Phase2Trainer}. Unless the trainer selects its features some
     * other way, every configuration trains with identical settings and the
     * rows differ only by label. Confirm and wire the flags through.
     *
     * @param config    configuration whose name and flags label the row
     * @param buckets   bucket count handed to the trainer constructor
     * @param threads   thread count handed to the trainer
     * @param trainFile training data file
     * @param dev       dev sample, used for training and for the devF1 column
     * @param test      test sample, used for the testF1 column
     * @throws Exception if training or evaluation fails
     */
    private static void runOne(Config config, int buckets, int threads, Path trainFile,
                               List<LabeledSentence> dev, List<LabeledSentence> test)
            throws Exception {
        // The reported time covers training plus both evaluations.
        long runStart = System.nanoTime();
        Phase2Trainer trainer = new Phase2Trainer(buckets)
                .setAdamLr(0.001f)
                .setSgdLr(0.01f, 0.001f)
                .setAdamEpochs(2)
                .setMaxEpochs(6)
                .setCheckpointInterval(500_000)
                .setPatience(2)
                .setDevSubsampleSize(10_000)
                .setNumThreads(threads)
                .setVerbose(false)
                .setPreprocessed(true);
        trainer.train(trainFile, dev);

        Phase2Trainer.F1Result devF1 = trainer.evaluateMacroF1(dev);
        Phase2Trainer.F1Result testF1 = trainer.evaluateMacroF1(test);
        double secs = elapsed(runStart);

        System.out.printf(Locale.US,
                "%-28s %7d %4s %4s %4s %4s "
                        + "%8.4f %8d %8.4f %7.1f%n",
                config.name, buckets,
                mark(config.trigrams),
                mark(config.skipgrams),
                mark(config.wordUnigrams),
                mark(config.cjkUnigrams),
                devF1.f1, devF1.numLangs,
                testF1.f1, secs);
    }

    /** Renders a feature flag for the table: {@code "Y"} when set, {@code "-"} otherwise. */
    private static String mark(boolean enabled) {
        return enabled ? "Y" : "-";
    }

    /**
     * Reads a tab-separated file and keeps a uniform random sample of at most
     * {@code maxLines} entries using reservoir sampling with a fixed seed, so
     * repeated runs over the same file select the same sample.
     *
     * <p>Each line is split at its first tab; the part before the tab and the
     * part after it are passed, in that order, to the {@code LabeledSentence}
     * constructor (presumably language label, then text). Lines without a tab
     * are skipped and do not count towards the sample.
     *
     * @param file     UTF-8 input file
     * @param maxLines maximum number of entries to keep
     * @return the sampled entries; all valid lines if the file has at most
     *         {@code maxLines} of them
     * @throws Exception if the file cannot be read
     */
    private static List<LabeledSentence> readReservoir(
            Path file, int maxLines) throws Exception {
        List<LabeledSentence> sample = new ArrayList<>();
        Random rng = new Random(42);
        int seen = 0;
        try (BufferedReader reader = Files.newBufferedReader(
                file, StandardCharsets.UTF_8)) {
            String line;
            while ((line = reader.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue;
                }
                LabeledSentence sentence = new LabeledSentence(
                        line.substring(0, tab),
                        line.substring(tab + 1));
                if (seen < maxLines) {
                    // Reservoir not full yet: keep everything.
                    sample.add(sentence);
                } else {
                    // Replace a random slot with probability maxLines / (seen + 1).
                    int slot = rng.nextInt(seen + 1);
                    if (slot < maxLines) {
                        sample.set(slot, sentence);
                    }
                }
                seen++;
            }
        }
        return sample;
    }

    /**
     * Counts the distinct values returned by {@code getLanguage()} across the
     * given sentences.
     *
     * @param data sentences to inspect
     * @return number of distinct languages
     */
    private static int countLangs(List<LabeledSentence> data) {
        Set<String> langs = new HashSet<>();
        for (LabeledSentence sentence : data) {
            langs.add(sentence.getLanguage());
        }
        return langs.size();
    }

    /**
     * Returns the seconds elapsed since a {@link System#nanoTime()} reading.
     *
     * @param startNanos earlier {@code System.nanoTime()} value
     * @return elapsed time in seconds
     */
    private static double elapsed(long startNanos) {
        return (System.nanoTime() - startNanos) / 1_000_000_000.0;
    }
}