// TrainShortModel.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup.tools;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.tika.langdetect.charsoup.CharSoupModel;
import org.apache.tika.langdetect.charsoup.ShortTextFeatureExtractor;
/**
 * Trains and saves the production short-text language model.
 * Always uses the ShortTextFeatureExtractor feature set
 * (flags 0x0a1: trigrams + word unigrams + 4-grams) at 32 768 buckets.
 * The flag value written into the model is read from
 * {@link ShortTextFeatureExtractor#FEATURE_FLAGS}.
 *
 * <p>Usage:
 * <pre>
 *   TrainShortModel prepDir trainFile outputModel
 *       --allowed-langs file   one lang code per line (# = comment)
 *       --flores file          FLORES-200 dev TSV for post-train eval
 * </pre>
 *
 * <p>{@code prepDir} must contain a {@code dev.txt} file; it is used as the
 * dev set during training. Both options are optional and each one requires a value.
 */
public class TrainShortModel {

    /** Number of hash buckets passed to the trainer. */
    private static final int NUM_BUCKETS = 32_768;

    /** Truncation lengths used for the optional FLORES evaluation. */
    private static final int[] FLORES_EVAL_LENGTHS = {20, 50, 100};

    /**
     * Command-line entry point: parses the arguments, trains the model,
     * quantizes and saves it, and optionally evaluates it on FLORES data.
     *
     * <p>Exits the JVM with status 1 when fewer than three positional
     * arguments are given, when an option is unknown, or when an option
     * is missing its value.
     *
     * @param args {@code prepDir trainFile outputModel} followed by the
     *             optional {@code --allowed-langs <file>} and
     *             {@code --flores <file>} options
     * @throws Exception if loading, training or saving fails
     */
    public static void main(String[] args) throws Exception {
        if (args.length < 3) {
            printUsage();
            System.exit(1);
        }
        Path prepDir = Paths.get(args[0]);
        Path trainFile = Paths.get(args[1]);
        Path outputModel = Paths.get(args[2]);
        Path allowedFile = null;
        Path floresFile = null;
        for (int i = 3; i < args.length; i++) {
            String option = args[i];
            switch (option) {
                case "--allowed-langs":
                    // requireValue guards against the option being the last argument
                    allowedFile = Paths.get(requireValue(args, ++i, option));
                    break;
                case "--flores":
                    floresFile = Paths.get(requireValue(args, ++i, option));
                    break;
                default:
                    System.err.println("Unknown arg: " + option);
                    System.exit(1);
            }
        }

        // null means "no restriction" (see the "all" branch in the banner below)
        Set<String> allowedLangs = loadAllowedLangs(allowedFile);
        // floresFile is null when --flores was not given; the null check before the
        // evaluation step relies on loadFlores returning null then -- TODO confirm
        List<LabeledSentence> floresData =
                AblationRunner.loadFlores(floresFile);
        int threads = Runtime.getRuntime().availableProcessors();

        System.out.printf(Locale.ROOT,
                "Training short-text model: buckets=%d flags=0x%03x threads=%d%n",
                NUM_BUCKETS, ShortTextFeatureExtractor.FEATURE_FLAGS, threads);
        System.out.println("Train file : " + trainFile);
        System.out.println("Allowed langs : "
                + (allowedLangs != null ? allowedLangs.size() + " langs" : "all"));
        System.out.println("Output : " + outputModel);
        System.out.println();

        // The timer starts before the dev set is read, so the reported
        // training time includes dev loading.
        long t0 = System.nanoTime();
        List<LabeledSentence> dev =
                AblationRunner.readReservoir(prepDir.resolve("dev.txt"), 100_000, allowedLangs);
        System.out.printf(Locale.ROOT, "Dev loaded: %,d sentences (%d langs)%n%n",
                dev.size(), countLangs(dev));

        // Feature toggles: word unigrams, trigrams and 4-grams on; suffixes,
        // prefix and 5-grams off.
        Phase2Trainer trainer = new Phase2Trainer(NUM_BUCKETS)
                .setAdamLr(0.001f)
                .setSgdLr(0.01f, 0.001f)
                .setAdamEpochs(2)
                .setMaxEpochs(6)
                .setCheckpointInterval(500_000)
                .setPatience(2)
                .setDevSubsampleSize(10_000)
                .setNumThreads(threads)
                .setVerbose(true)
                .setPreprocessed(true)
                .setUseWordUnigrams(true)
                .setUseTrigrams(true)
                .setUseWordSuffixes(false)
                .setUseWordPrefix(false)
                .setUse4grams(true)
                .setUse5grams(false)
                .setAllowedLangs(allowedLangs);
        trainer.train(trainFile, dev);
        double trainSecs = (System.nanoTime() - t0) / 1e9;
        System.out.printf(Locale.ROOT, "%nTraining complete in %.1f s%n", trainSecs);

        int flags = ShortTextFeatureExtractor.FEATURE_FLAGS;
        CharSoupModel model = ModelQuantizer.quantize(
                trainer.getLabels(),
                trainer.getWeightsClassMajor(),
                trainer.getBiases(),
                trainer.getNumBuckets(),
                flags);
        saveModel(model, outputModel);
        System.out.println("Model saved: " + outputModel);
        System.out.printf(Locale.ROOT, " classes=%d buckets=%d flags=0x%03x%n",
                model.getNumClasses(), model.getNumBuckets(), model.getFeatureFlags());

        if (floresData != null) {
            evaluateOnFlores(trainer, floresData);
        }
    }

    /** Prints the command-line synopsis to standard error. */
    private static void printUsage() {
        System.err.println("Usage: TrainShortModel <prepDir> <trainFile> <outputModel>");
        System.err.println(" --allowed-langs <file>");
        System.err.println(" --flores <file> FLORES-200 dev TSV");
    }

    /**
     * Returns the value that follows an option, or reports the problem and
     * exits with status 1 when the option is the last argument.
     *
     * @param args       the full argument array
     * @param valueIndex index at which the option's value is expected
     * @param option     the option name, used in the error message
     * @return {@code args[valueIndex]}
     */
    private static String requireValue(String[] args, int valueIndex, String option) {
        if (valueIndex >= args.length) {
            System.err.println("Missing value for " + option);
            printUsage();
            System.exit(1);
        }
        return args[valueIndex];
    }

    /**
     * Writes the model to {@code target}, creating the parent directory
     * first when the path has one.
     *
     * @param model  the quantized model to write
     * @param target destination file
     * @throws Exception if the directory or the file cannot be written
     */
    private static void saveModel(CharSoupModel model, Path target) throws Exception {
        Path parent = target.getParent();
        if (parent != null) {
            Files.createDirectories(parent);
        }
        try (OutputStream os = new BufferedOutputStream(Files.newOutputStream(target))) {
            model.save(os);
        }
    }

    /**
     * Prints the trainer's macro-F1 on the FLORES sentences, truncated to
     * each length in {@link #FLORES_EVAL_LENGTHS}. Sentences whose language
     * is not a key of the trainer's label index are dropped first.
     *
     * @param trainer    the trained trainer
     * @param floresData the FLORES sentences; must not be null
     */
    private static void evaluateOnFlores(Phase2Trainer trainer,
                                         List<LabeledSentence> floresData) {
        Set<String> known = new HashSet<>(trainer.getLabelIndex().keySet());
        List<LabeledSentence> inModel = floresData.stream()
                .filter(s -> known.contains(s.getLanguage()))
                .collect(Collectors.toList());
        System.out.println();
        for (int len : FLORES_EVAL_LENGTHS) {
            List<LabeledSentence> trunc =
                    CompareDetectors.truncate(inModel, len);
            double f1 = trainer.evaluateMacroF1(trunc).f1;
            System.out.printf(Locale.ROOT, " FLORES @%-4d macro-F1: %.2f%%%n",
                    len, 100 * f1);
        }
    }

    /**
     * Reads the allowed-language list: one entry per line, trimmed, with
     * blank lines and lines starting with {@code #} skipped.
     *
     * @param file the list file (read as UTF-8), or {@code null}
     * @return the set of entries, or {@code null} when {@code file} is
     *         {@code null}
     * @throws Exception if the file cannot be read
     */
    private static Set<String> loadAllowedLangs(Path file) throws Exception {
        if (file == null) {
            return null;
        }
        Set<String> set = new HashSet<>();
        try (BufferedReader r = Files.newBufferedReader(file, StandardCharsets.UTF_8)) {
            String line;
            while ((line = r.readLine()) != null) {
                line = line.trim();
                if (!line.isEmpty() && !line.startsWith("#")) {
                    set.add(line);
                }
            }
        }
        return set;
    }

    /**
     * Counts the distinct language labels in the given sentences.
     *
     * @param data the sentences to inspect
     * @return the number of distinct values returned by {@code getLanguage()}
     */
    private static int countLangs(List<LabeledSentence> data) {
        Set<String> langs = new HashSet<>();
        for (LabeledSentence s : data) {
            langs.add(s.getLanguage());
        }
        return langs.size();
    }
}