// Phase2SmokeTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Random;
import java.util.Set;
import org.apache.tika.langdetect.charsoup.CharSoupModel;
/**
 * Smoke test for {@link Phase2Trainer}. Streams training data from
 * disk, trains at 8k buckets, quantizes, and evaluates.
 * <p>
 * Usage: {@code Phase2SmokeTest <prepDir> [outputModel]}
 * <p>
 * Expects {@code prepDir} to contain {@code train.txt} (or
 * {@code train_2m.txt}), {@code dev.txt}, and {@code test.txt} in
 * tab-separated format.
 */
public class Phase2SmokeTest {

    /** Utility class: static entry point only, never instantiated. */
    private Phase2SmokeTest() {
    }

    /**
     * Runs the smoke test end to end: picks a train file under
     * {@code prepDir}, loads capped dev/test samples, trains a Phase 2
     * model at 8192 buckets, quantizes it, evaluates on the test split,
     * and optionally writes the quantized model to disk.
     *
     * @param args {@code args[0]} = prep directory containing the
     *             tab-separated data files; optional {@code args[1]} =
     *             path for the saved quantized model
     * @throws Exception if reading data, training, or saving fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 1) {
            System.err.println(
                    "Usage: Phase2SmokeTest <prepDir> "
                            + "[outputModel]");
            System.exit(1);
        }
        Path prepDir = Paths.get(args[0]);
        Path outputModel = args.length > 1
                ? Paths.get(args[1]) : null;
        int numBuckets = 8192;
        int threads = Runtime.getRuntime()
                .availableProcessors();

        // Prefer the larger 2M-line train file when it exists,
        // otherwise fall back to the standard one.
        Path trainFile = prepDir.resolve("train_2m.txt");
        if (!Files.exists(trainFile)) {
            trainFile = prepDir.resolve("train.txt");
        }
        System.out.println("Train file: " + trainFile);

        // Load dev and test (these are small enough for RAM).
        System.out.println("Loading dev + test...");
        long t0 = System.nanoTime();
        List<LabeledSentence> dev = readCapped(
                prepDir.resolve("dev.txt"), 100_000);
        List<LabeledSentence> test = readCapped(
                prepDir.resolve("test.txt"), 200_000);
        System.out.printf(Locale.US,
                "Loaded: dev=%,d (%d langs) "
                        + "test=%,d (%d langs) [%.1f s]%n",
                dev.size(), countLangs(dev),
                test.size(), countLangs(test),
                elapsed(t0));

        // Train -- streams training data from disk.
        System.out.println(
                "\n=== Phase 2 Training (8k, streaming) ===");
        // Hyperparameters overridable via system properties so runs
        // can be tuned without recompiling.
        float adamLr = Float.parseFloat(
                System.getProperty("adamLr", "0.001"));
        int adamT = Integer.parseInt(
                System.getProperty("adamThreads", "1"));
        int sgdT = Integer.parseInt(
                System.getProperty("sgdThreads",
                        String.valueOf(threads)));
        Phase2Trainer trainer =
                new Phase2Trainer(numBuckets)
                        .setAdamLr(adamLr)
                        .setSgdLr(0.01f, 0.001f)
                        .setAdamEpochs(2)
                        .setMaxEpochs(4)
                        .setCheckpointInterval(200_000)
                        .setPatience(2)
                        .setDevSubsampleSize(10_000)
                        .setAdamThreads(adamT)
                        .setSgdThreads(sgdT)
                        .setPreprocessed(true);
        t0 = System.nanoTime();
        trainer.train(trainFile, dev);
        System.out.printf(Locale.US,
                "Training time: %.1f s%n", elapsed(t0));

        // Quantize the float32 trainer weights into the compact model.
        System.out.println("\nQuantizing...");
        t0 = System.nanoTime();
        CharSoupModel model =
                ModelQuantizer.quantize(trainer);
        System.out.printf(Locale.US,
                "Quantized [%.1f s]%n", elapsed(t0));

        // Evaluate the quantized model on the held-out test split.
        System.out.println(
                "Evaluating on test set...");
        t0 = System.nanoTime();
        double testAcc =
                TrainLanguageModel.evaluateQuantized(
                        model, test).accuracy;
        System.out.printf(Locale.US,
                "Test accuracy (quantized): %.4f "
                        + "[%.1f s]%n",
                testAcc, elapsed(t0));

        // Also report float32 dev macro F1 for comparison against the
        // quantized accuracy above.
        Phase2Trainer.F1Result devF1 =
                trainer.evaluateMacroF1(dev);
        System.out.printf(Locale.US,
                "Dev macro F1 (float32): %.4f (%d langs)%n",
                devF1.f1, devF1.numLangs);

        // Save the quantized model if an output path was given.
        if (outputModel != null) {
            Files.createDirectories(
                    outputModel.getParent() != null
                            ? outputModel.getParent()
                            : Paths.get("."));
            try (OutputStream os = new BufferedOutputStream(
                    Files.newOutputStream(outputModel))) {
                model.save(os);
            }
            System.out.printf(Locale.US,
                    "Model saved: %s (%.1f KB)%n",
                    outputModel,
                    Files.size(outputModel) / 1024.0);
        }
        System.out.println("\nDone.");
    }

    /**
     * Reservoir-samples up to {@code maxLines} sentences from a
     * preprocessed tab-separated file ({@code lang\ttext} per line).
     * Guarantees a uniform random sample across the entire file,
     * critical when the file is sorted by language. Uses a fixed seed
     * so repeated runs sample the same lines.
     *
     * @param file     tab-separated input file (UTF-8)
     * @param maxLines maximum number of sentences to keep
     * @return up to {@code maxLines} uniformly sampled sentences;
     *         lines without a tab separator are skipped
     * @throws Exception if the file cannot be read
     */
    private static List<LabeledSentence> readCapped(
            Path file, int maxLines) throws Exception {
        LabeledSentence[] reservoir =
                new LabeledSentence[maxLines];
        Random rng = new Random(42); // fixed seed: reproducible sample
        int seen = 0;
        try (BufferedReader br = Files.newBufferedReader(
                file, StandardCharsets.UTF_8)) {
            String line;
            while ((line = br.readLine()) != null) {
                int tab = line.indexOf('\t');
                if (tab < 0) {
                    continue; // malformed line: no label/text separator
                }
                LabeledSentence s = new LabeledSentence(
                        line.substring(0, tab),
                        line.substring(tab + 1));
                if (seen < maxLines) {
                    // Fill phase: keep everything until full.
                    reservoir[seen] = s;
                } else {
                    // Algorithm R: replace a random slot with
                    // probability maxLines / (seen + 1).
                    int j = rng.nextInt(seen + 1);
                    if (j < maxLines) {
                        reservoir[j] = s;
                    }
                }
                seen++;
            }
        }
        int fill = Math.min(seen, maxLines);
        List<LabeledSentence> result =
                new ArrayList<>(fill);
        for (int i = 0; i < fill; i++) {
            result.add(reservoir[i]);
        }
        return result;
    }

    /**
     * Counts the number of distinct language labels in {@code data}.
     *
     * @param data labeled sentences to scan
     * @return number of unique languages
     */
    private static int countLangs(
            List<LabeledSentence> data) {
        Set<String> langs = new HashSet<>();
        for (LabeledSentence s : data) {
            langs.add(s.getLanguage());
        }
        return langs.size();
    }

    /**
     * Returns seconds elapsed since {@code startNanos}
     * (a {@link System#nanoTime()} reading).
     *
     * @param startNanos start timestamp in nanoseconds
     * @return elapsed wall-clock time in seconds
     */
    private static double elapsed(long startNanos) {
        return (System.nanoTime() - startNanos)
                / 1_000_000_000.0;
    }
}