TestRowExpressionSerde.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql;

import com.facebook.airlift.bootstrap.Bootstrap;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.json.JsonModule;
import com.facebook.airlift.stats.cardinality.HyperLogLog;
import com.facebook.presto.block.BlockJsonSerde;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockEncoding;
import com.facebook.presto.common.block.BlockEncodingManager;
import com.facebook.presto.common.block.BlockEncodingSerde;
import com.facebook.presto.common.block.IntArrayBlock;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.HandleJsonModule;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.sql.analyzer.ExpressionAnalyzer;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.Scope;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.relational.RowExpressionOptimizer;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.type.TypeDeserializer;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.Scopes;
import io.airlift.slice.Slice;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Map;

import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
import static com.facebook.airlift.json.JsonBinder.jsonBinder;
import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.function.OperatorType.SUBSCRIPT;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.ROW_CONSTRUCTOR;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment.AS_DOUBLE;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.inject.multibindings.Multibinder.newSetBinder;
import static java.lang.Float.floatToIntBits;
import static java.util.Collections.emptyMap;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;

/**
 * Verifies that {@link RowExpression}s produced by translating (and optionally optimizing) SQL
 * expressions survive a JSON round trip through a Guice-bound {@link JsonCodec}, and that
 * expressions whose type cannot be serialized are rejected.
 */
public class TestRowExpressionSerde
{
    private final Metadata metadata = MetadataManager.createTestMetadataManager();
    private JsonCodec<RowExpression> codec;

    @BeforeClass
    public void setUp()
            throws Exception
    {
        codec = getJsonCodec();
    }

    @Test
    public void testSimpleLiteral()
    {
        assertLiteral("TRUE", constant(true, BOOLEAN));
        assertLiteral("FALSE", constant(false, BOOLEAN));
        assertLiteral("CAST(NULL AS BOOLEAN)", constant(null, BOOLEAN));

        assertLiteral("TINYINT '1'", constant(1L, TINYINT));
        assertLiteral("SMALLINT '1'", constant(1L, SMALLINT));
        assertLiteral("1", constant(1L, INTEGER));
        assertLiteral("BIGINT '1'", constant(1L, BIGINT));

        assertLiteral("1.1", constant(1.1, DOUBLE));
        assertLiteral("nan()", constant(Double.NaN, DOUBLE));
        assertLiteral("infinity()", constant(Double.POSITIVE_INFINITY, DOUBLE));
        assertLiteral("-infinity()", constant(Double.NEGATIVE_INFINITY, DOUBLE));

        // REAL constants are carried as the float's raw int bits widened to long.
        assertLiteral("CAST(1.1 AS REAL)", constant((long) floatToIntBits(1.1f), REAL));
        assertLiteral("CAST(nan() AS REAL)", constant((long) floatToIntBits(Float.NaN), REAL));
        assertLiteral("CAST(infinity() AS REAL)", constant((long) floatToIntBits(Float.POSITIVE_INFINITY), REAL));
        assertLiteral("CAST(-infinity() AS REAL)", constant((long) floatToIntBits(Float.NEGATIVE_INFINITY), REAL));

        assertStringLiteral("'String Literal'", "String Literal", VarcharType.createVarcharType(14));
        assertLiteral("CAST(NULL AS VARCHAR)", constant(null, VARCHAR));

        // 7670 = days from 1970-01-01 to 1991-01-01.
        assertLiteral("DATE '1991-01-01'", constant(7670L, DATE));
        // NOTE(review): the expected millis are 11 hours after 1991-01-01T00:00Z, so this assumes
        // TEST_SESSION's time zone is UTC-11 at that date — confirm before changing the session.
        assertLiteral("TIMESTAMP '1991-01-01 00:00:00.000'", constant(662727600000L, TIMESTAMP));
    }

    @Test
    public void testArrayLiteral()
    {
        // Once optimized, the array constructor folds into a single constant IntArrayBlock.
        RowExpression rowExpression = getRoundTrip("ARRAY [1, 2, 3]", true);
        assertTrue(rowExpression instanceof ConstantExpression);
        Object value = ((ConstantExpression) rowExpression).getValue();
        assertTrue(value instanceof IntArrayBlock);
        IntArrayBlock block = (IntArrayBlock) value;
        assertEquals(block.getPositionCount(), 3);
        assertEquals(block.getInt(0), 1);
        assertEquals(block.getInt(1), 2);
        assertEquals(block.getInt(2), 3);
    }

    @Test
    public void testArrayGet()
    {
        // Unoptimized: the subscript call tree, including its function handles, must round-trip intact.
        assertEquals(getRoundTrip("(ARRAY [1, 2, 3])[1]", false),
                call(SUBSCRIPT.name(),
                        operator(SUBSCRIPT, new ArrayType(INTEGER), BIGINT),
                        INTEGER,
                        call("array_constructor",
                                function("array_constructor", INTEGER, INTEGER, INTEGER),
                                new ArrayType(INTEGER),
                                constant(1L, INTEGER),
                                constant(2L, INTEGER),
                                constant(3L, INTEGER)),
                        constant(1L, INTEGER)));
        // Optimized: the whole expression folds to the selected element.
        assertEquals(getRoundTrip("(ARRAY [1, 2, 3])[1]", true), constant(1L, INTEGER));
    }

    @Test
    public void testRowLiteral()
    {
        assertEquals(getRoundTrip("ROW(1, 1.1)", false),
                specialForm(
                        ROW_CONSTRUCTOR,
                        RowType.anonymous(
                                ImmutableList.of(
                                        INTEGER,
                                        DOUBLE)),
                        constant(1L, INTEGER),
                        constant(1.1, DOUBLE)));
    }

    @Test
    public void testDereference()
    {
        // Compare against the same translation without serialization rather than a hand-built tree.
        String sql = "CAST(ROW(1) AS ROW(col1 integer)).col1";
        RowExpression before = translateSql(sql, false);
        RowExpression after = getRoundTrip(sql, false);
        assertEquals(before, after);
    }

    @Test
    public void testHllLiteral()
    {
        RowExpression rowExpression = getRoundTrip("empty_approx_set()", true);
        assertTrue(rowExpression instanceof ConstantExpression);
        Object value = ((ConstantExpression) rowExpression).getValue();
        assertEquals(HyperLogLog.newInstance((Slice) value).cardinality(), 0);
    }

    @Test
    public void testUnserializableType()
    {
        assertThrowsWhenSerialize("CAST('$.a' AS JsonPath)", true);
    }

    /**
     * Asserts that encoding the translated expression to JSON throws {@link IllegalArgumentException}.
     *
     * @param sql SQL expression text
     * @param optimize whether to run the row-expression optimizer before serializing
     */
    private void assertThrowsWhenSerialize(@Language("SQL") String sql, boolean optimize)
    {
        RowExpression rowExpression = translateSql(sql, optimize);
        assertThrows(IllegalArgumentException.class, () -> codec.toJson(rowExpression));
    }

    /**
     * Asserts that the optimized, round-tripped form of {@code sql} equals {@code expected}.
     */
    private void assertLiteral(@Language("SQL") String sql, ConstantExpression expected)
    {
        assertEquals(getRoundTrip(sql, true), expected);
    }

    /**
     * Asserts that the optimized, round-tripped form of {@code sql} is a constant whose value is a
     * {@link Slice} decoding (as UTF-8) to {@code expectedString} and whose type is {@code expectedType}.
     */
    private void assertStringLiteral(@Language("SQL") String sql, String expectedString, Type expectedType)
    {
        RowExpression roundTrip = getRoundTrip(sql, true);
        assertTrue(roundTrip instanceof ConstantExpression);
        String roundTripValue = ((Slice) ((ConstantExpression) roundTrip).getValue()).toStringUtf8();
        Type roundTripType = roundTrip.getType();
        assertEquals(roundTripValue, expectedString);
        assertEquals(roundTripType, expectedType);
    }

    /**
     * Translates {@code sql}, encodes the result to JSON with {@link #codec} and decodes it back.
     *
     * @param sql SQL expression text
     * @param optimize whether to run the row-expression optimizer before serializing
     * @return the decoded expression
     */
    private RowExpression getRoundTrip(@Language("SQL") String sql, boolean optimize)
    {
        RowExpression rowExpression = translateSql(sql, optimize);
        String json = codec.toJson(rowExpression);
        return codec.fromJson(json);
    }

    /**
     * Parses {@code sql} with {@code AS_DOUBLE} decimal-literal treatment (so {@code 1.1} becomes a
     * DOUBLE) and translates it into a {@link RowExpression}, optionally optimizing it.
     */
    private RowExpression translateSql(@Language("SQL") String sql, boolean optimize)
    {
        return translate(expression(sql, new ParsingOptions(AS_DOUBLE)), optimize);
    }

    /**
     * Resolves the handle of operator {@code operatorType} for the given argument types.
     */
    private FunctionHandle operator(OperatorType operatorType, Type... types)
    {
        return metadata.getFunctionAndTypeManager().resolveOperator(operatorType, fromTypes(types));
    }

    /**
     * Looks up the handle of function {@code name} for the given argument types.
     */
    private FunctionHandle function(String name, Type... types)
    {
        return metadata.getFunctionAndTypeManager().lookupFunction(name, fromTypes(types));
    }

    /**
     * Bootstraps a minimal Guice injector with JSON, handle, type and block serde bindings, and
     * returns the {@link JsonCodec} bound for {@link RowExpression}.
     */
    private JsonCodec<RowExpression> getJsonCodec()
            throws Exception
    {
        Module module = binder -> {
            binder.install(new JsonModule());
            binder.install(new HandleJsonModule());
            configBinder(binder).bindConfig(FeaturesConfig.class);

            // Types are deserialized by signature through a test FunctionAndTypeManager.
            FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
            binder.bind(TypeManager.class).toInstance(functionAndTypeManager);
            jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class);
            newSetBinder(binder, Type.class);

            // Block-valued constants (e.g. folded arrays) are encoded through BlockJsonSerde.
            binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON);
            newSetBinder(binder, BlockEncoding.class);
            jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class);
            jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class);
            jsonCodecBinder(binder).bindJsonCodec(RowExpression.class);
        };
        Bootstrap app = new Bootstrap(ImmutableList.of(module));
        Injector injector = app
                .doNotInitializeLogging()
                .quiet()
                .initialize();
        return injector.getInstance(new Key<JsonCodec<RowExpression>>() {});
    }

    /**
     * Translates an AST {@link Expression} into a {@link RowExpression} using the analyzer-derived
     * expression types. When {@code optimize} is true, the result is passed through
     * {@link RowExpressionOptimizer} at level {@code OPTIMIZED}.
     */
    private RowExpression translate(Expression expression, boolean optimize)
    {
        RowExpression rowExpression = SqlToRowExpressionTranslator.translate(expression, getExpressionTypes(expression), ImmutableMap.of(), metadata.getFunctionAndTypeManager(), TEST_SESSION);
        if (optimize) {
            RowExpressionOptimizer optimizer = new RowExpressionOptimizer(metadata);
            return optimizer.optimize(rowExpression, OPTIMIZED, TEST_SESSION.toConnectorSession());
        }
        return rowExpression;
    }

    /**
     * Runs the expression analyzer (no subqueries, empty type provider and parameters) over
     * {@code expression} and returns the type it assigned to each node.
     */
    private Map<NodeRef<Expression>, Type> getExpressionTypes(Expression expression)
    {
        ExpressionAnalyzer expressionAnalyzer = ExpressionAnalyzer.createWithoutSubqueries(
                metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver(),
                TEST_SESSION,
                TypeProvider.empty(),
                emptyMap(),
                // Plain concatenation: the previous "%s" placeholder was never formatted.
                node -> new IllegalStateException("Unexpected node: " + node),
                WarningCollector.NOOP,
                false);
        expressionAnalyzer.analyze(expression, Scope.create());
        return expressionAnalyzer.getExpressionTypes();
    }
}