// LinearModelTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tika.langdetect.charsoup;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import org.junit.jupiter.api.Test;
import org.apache.tika.ml.LinearModel;
/**
 * Unit tests for {@link LinearModel}.
 *
 * <p>The tests cover four areas: binary save/load round-tripping, the static
 * {@code softmax} helper, end-to-end {@code predict} on a hand-built model, and
 * the handling of the leading magic number in the serialized form.
 */
public class LinearModelTest {

    /**
     * Saves a small hand-built model to a byte array, loads it back, and checks
     * that dimensions, labels, scales, biases and weights all survive unchanged.
     *
     * @throws IOException if saving or loading the in-memory model fails
     */
    @Test
    public void testRoundTrip() throws IOException {
        // Build a tiny 3-class, 256-bucket model
        int numBuckets = 256;
        int numClasses = 3;
        String[] labels = {"eng", "deu", "fra"};
        float[] scales = {0.01f, 0.02f, 0.015f};
        float[] biases = {0.1f, -0.05f, 0.0f};
        byte[][] weights = new byte[numClasses][numBuckets];

        // A handful of non-zero weights (positive and negative) so the
        // comparison below is not trivially satisfied by an all-zero matrix.
        weights[0][0] = 127;
        weights[0][1] = -127;
        weights[1][10] = 50;
        weights[2][100] = -100;

        LinearModel original =
                new LinearModel(numBuckets, numClasses, labels, scales, biases, weights);

        // Save and reload entirely in memory
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        original.save(baos);
        byte[] bytes = baos.toByteArray();
        LinearModel loaded = LinearModel.load(new ByteArrayInputStream(bytes));

        assertEquals(numBuckets, loaded.getNumBuckets());
        assertEquals(numClasses, loaded.getNumClasses());
        assertArrayEquals(labels, loaded.getLabels());
        for (int c = 0; c < numClasses; c++) {
            assertEquals(scales[c], loaded.getScales()[c], 1e-6);
            assertEquals(biases[c], loaded.getBiases()[c], 1e-6);
            assertArrayEquals(weights[c], loaded.getWeights()[c]);
        }
    }

    /**
     * Checks that softmax output sums to approximately one and preserves the
     * ordering of the input logits.
     */
    @Test
    public void testSoftmax() {
        float[] logits = {1.0f, 2.0f, 3.0f};
        float[] probs = LinearModel.softmax(logits);

        // Should sum to ~1.0
        float sum = 0f;
        for (float p : probs) {
            sum += p;
        }
        assertEquals(1.0f, sum, 1e-5);

        // Highest logit should have highest probability
        assertTrue(probs[2] > probs[1]);
        assertTrue(probs[1] > probs[0]);
    }

    /**
     * Checks that softmax stays finite and normalized for very large logits,
     * where exponentiating the raw values directly would overflow.
     */
    @Test
    public void testSoftmaxNumericalStability() {
        // Very large logits should not overflow
        float[] logits = {1000.0f, 1001.0f, 999.0f};
        float[] probs = LinearModel.softmax(logits);

        float sum = 0f;
        for (float p : probs) {
            assertTrue(Float.isFinite(p));
            sum += p;
        }
        assertEquals(1.0f, sum, 1e-5);
    }

    /**
     * Builds a two-class model in which each class has a single large weight on
     * its own bucket, then checks that activating a bucket makes the matching
     * class the most probable one.
     */
    @Test
    public void testPredict() {
        // Simple model where class 0 has high weight on bucket 0
        int numBuckets = 4;
        int numClasses = 2;
        String[] labels = {"eng", "deu"};
        float[] scales = {1.0f, 1.0f};
        float[] biases = {0.0f, 0.0f};
        byte[][] weights = new byte[numClasses][numBuckets];
        weights[0][0] = 127; // eng strongly triggered by bucket 0
        weights[1][1] = 127; // deu strongly triggered by bucket 1

        LinearModel model =
                new LinearModel(numBuckets, numClasses, labels, scales, biases, weights);

        // Feature vector activating bucket 0
        int[] features0 = {10, 0, 0, 0};
        float[] probs0 = model.predict(features0);
        assertTrue(probs0[0] > probs0[1], "eng should win when bucket 0 is active");

        // Feature vector activating bucket 1
        int[] features1 = {0, 10, 0, 0};
        float[] probs1 = model.predict(features1);
        assertTrue(probs1[1] > probs1[0], "deu should win when bucket 1 is active");
    }

    /**
     * Checks that loading a stream whose first four bytes are not the expected
     * magic number is rejected with an {@link IOException}.
     *
     * @throws IOException if writing the in-memory fixture fails
     */
    @Test
    public void testCorruptMagicThrows() throws IOException {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        // try-with-resources flushes and closes the wrapper before the bytes are read.
        try (DataOutputStream dos = new DataOutputStream(baos)) {
            dos.writeInt(0xDEADBEEF); // wrong magic
        }
        byte[] corrupt = baos.toByteArray();

        assertThrows(IOException.class,
                () -> LinearModel.load(new ByteArrayInputStream(corrupt)));
    }

    /**
     * Checks that a saved model starts with the four magic bytes {@code "LDM1"}.
     *
     * @throws IOException if saving the in-memory model fails
     */
    @Test
    public void testSaveHeaderFormat() throws IOException {
        LinearModel model = new LinearModel(128, 2, new String[]{"a", "b"},
                new float[]{1f, 2f}, new float[]{0f, 0f}, new byte[2][128]);
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        model.save(baos);
        byte[] data = baos.toByteArray();

        // Check magic bytes (big-endian "LDM1" = 0x4C444D31)
        int[] expectedMagic = {0x4C, 0x44, 0x4D, 0x31};

        // Guard the indexing below so a truncated header is reported as an
        // assertion failure rather than an ArrayIndexOutOfBoundsException.
        assertTrue(data.length >= expectedMagic.length,
                "saved model is shorter than the 4-byte magic header");
        for (int i = 0; i < expectedMagic.length; i++) {
            // Mask to compare as an unsigned byte value.
            assertEquals(expectedMagic[i], data[i] & 0xFF, "magic byte " + i);
        }
    }
}