TestModelSerialization.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.ml;

import io.airlift.slice.Slice;
import org.testng.annotations.Test;

import static com.facebook.presto.ml.TestUtils.getDataset;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;

/**
 * Round-trip tests for {@link ModelUtils#serialize} and {@link ModelUtils#deserialize}.
 * <p>
 * Each round-trip test trains a model on {@link TestUtils#getDataset()}, serializes it,
 * deserializes the resulting slice and checks that a non-null instance of the expected model
 * type comes back. {@link #testSerializationIds()} pins the numeric id registered for each
 * model class in {@link ModelUtils#MODEL_SERIALIZATION_IDS}.
 */
public class TestModelSerialization
{
    @Test
    public void testSvmClassifier()
    {
        assertRoundTrip(new SvmClassifier(), SvmClassifier.class, "deserialized model is not a svm classifier");
    }

    @Test
    public void testSvmRegressor()
    {
        assertRoundTrip(new SvmRegressor(), SvmRegressor.class, "deserialized model is not a svm regressor");
    }

    @Test
    public void testRegressorFeatureTransformer()
    {
        assertRoundTrip(
                new RegressorFeatureTransformer(new SvmRegressor(), new FeatureVectorUnitNormalizer()),
                RegressorFeatureTransformer.class,
                "deserialized model is not a regressor feature transformer");
    }

    @Test
    public void testClassifierFeatureTransformer()
    {
        assertRoundTrip(
                new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer()),
                ClassifierFeatureTransformer.class,
                "deserialized model is not a classifier feature transformer");
    }

    @Test
    public void testVarcharClassifierAdapter()
    {
        assertRoundTrip(
                new StringClassifierAdapter(new ClassifierFeatureTransformer(new SvmClassifier(), new FeatureVectorUnitNormalizer())),
                StringClassifierAdapter.class,
                "deserialized model is not a varchar classifier adapter");
    }

    @Test
    public void testSerializationIds()
    {
        // NOTE(review): these ids are presumably written into serialized models, so changing an
        // existing assignment would likely break previously serialized data -- confirm in ModelUtils.
        assertSerializationId(SvmClassifier.class, 1);
        assertSerializationId(SvmRegressor.class, 2);
        assertSerializationId(FeatureVectorUnitNormalizer.class, 3);
        assertSerializationId(ClassifierFeatureTransformer.class, 4);
        assertSerializationId(RegressorFeatureTransformer.class, 5);
        assertSerializationId(FeatureUnitNormalizer.class, 6);
        assertSerializationId(StringClassifierAdapter.class, 7);
    }

    /**
     * Trains {@code model} on the shared test dataset, serializes and deserializes it, and asserts
     * that the result is a non-null instance of {@code expectedType}.
     *
     * @param model the untrained model to round-trip
     * @param expectedType the type the deserialized model must be an instance of
     * @param message failure message used when the type check fails
     */
    private static void assertRoundTrip(Model model, Class<? extends Model> expectedType, String message)
    {
        model.train(getDataset());
        Slice serialized = ModelUtils.serialize(model);
        Model deserialized = ModelUtils.deserialize(serialized);
        assertNotNull(deserialized, "deserialization failed");
        assertTrue(expectedType.isInstance(deserialized), message);
    }

    /**
     * Asserts that {@code modelClass} is registered in {@link ModelUtils#MODEL_SERIALIZATION_IDS}
     * with {@code expectedId}.
     *
     * @param modelClass the class whose registered id is checked
     * @param expectedId the id the class must be registered with
     */
    private static void assertSerializationId(Class<?> modelClass, int expectedId)
    {
        // Compare boxed values: an unregistered class yields null from get(), which should fail
        // with a readable assertion message rather than a NullPointerException from unboxing.
        assertEquals(
                ModelUtils.MODEL_SERIALIZATION_IDS.get(modelClass),
                Integer.valueOf(expectedId),
                "unexpected serialization id for " + modelClass.getSimpleName());
    }
}