// AbstractTestFunctions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.DecimalParseResult;
import com.facebook.presto.common.type.Decimals;
import com.facebook.presto.common.type.SqlDecimal;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.FunctionListBuilder;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.SqlScalarFunction;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.FunctionsConfig;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;

import java.math.BigInteger;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.fail;

/**
 * Base class for scalar function tests.
 *
 * <p>Holds a {@link FunctionAssertions} instance that is created before the test class runs
 * and closed after it, and exposes thin {@code assert*} / {@code register*} wrappers around it,
 * plus a few static helpers for building expected values.
 */
public abstract class AbstractTestFunctions
{
    // Default tolerance used by assertFunctionWithError(String, Type, Double).
    private static final double DELTA = 1e-5;

    protected final Session session;
    private final FeaturesConfig featuresConfig;
    private final FunctionsConfig functionsConfig;
    // Created in initTestFunctions() and reset to null in destroyTestFunctions().
    protected FunctionAssertions functionAssertions;

    /**
     * Uses {@code TEST_SESSION} with freshly constructed feature and function configs.
     */
    protected AbstractTestFunctions()
    {
        this(TEST_SESSION);
    }

    /**
     * Uses the given session with freshly constructed feature and function configs.
     */
    protected AbstractTestFunctions(Session session)
    {
        this(session, new FeaturesConfig(), new FunctionsConfig());
    }

    /**
     * Uses {@code TEST_SESSION} and the given features config with a fresh function config.
     */
    protected AbstractTestFunctions(FeaturesConfig featuresConfig)
    {
        this(TEST_SESSION, featuresConfig, new FunctionsConfig());
    }

    /**
     * Uses {@code TEST_SESSION} and the given function config with a fresh features config.
     */
    protected AbstractTestFunctions(FunctionsConfig functionsConfig)
    {
        this(TEST_SESSION, new FeaturesConfig(), functionsConfig);
    }

    /**
     * Stores the session and configs used later to build {@link FunctionAssertions}.
     *
     * <p>Note: {@code setLegacyLogFunction(true)} and {@code setUseNewNanDefinition(true)} are
     * invoked on the supplied {@code functionsConfig}, so any values the caller set for those
     * two options are overridden.
     *
     * @throws NullPointerException if any argument is null
     */
    protected AbstractTestFunctions(Session session, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig)
    {
        this.session = requireNonNull(session, "session is null");
        this.featuresConfig = requireNonNull(featuresConfig, "featuresConfig is null");
        this.functionsConfig = requireNonNull(functionsConfig, "functionsConfig is null")
                .setLegacyLogFunction(true)
                .setUseNewNanDefinition(true);
    }

    /**
     * Creates the {@link FunctionAssertions} shared by all tests of the class.
     */
    @BeforeClass
    public final void initTestFunctions()
    {
        // NOTE(review): the meaning of the trailing 'false' flag is defined by FunctionAssertions, not visible here.
        functionAssertions = new FunctionAssertions(session, featuresConfig, functionsConfig, false);
    }

    /**
     * Closes the shared {@link FunctionAssertions} and drops the reference to it.
     */
    @AfterClass(alwaysRun = true)
    public final void destroyTestFunctions()
    {
        closeAllRuntimeException(functionAssertions);
        functionAssertions = null;
    }

    /**
     * Returns the {@link FunctionAndTypeManager} of the current {@link FunctionAssertions}.
     */
    public FunctionAndTypeManager getFunctionAndTypeManager()
    {
        return functionAssertions.getFunctionAndTypeManager();
    }

    /**
     * Delegates to {@link FunctionAssertions#assertFunction} with the given projection,
     * expected type and expected value.
     */
    protected void assertFunction(String projection, Type expectedType, Object expected)
    {
        functionAssertions.assertFunction(projection, expectedType, expected);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertFunctionString}.
     */
    protected void assertFunctionString(String projection, Type expectedType, String expected)
    {
        functionAssertions.assertFunctionString(projection, expectedType, expected);
    }

    /**
     * Asserts a double-valued projection using the default tolerance {@code DELTA}.
     *
     * <p>A {@code null} expectation cannot be compared with a tolerance, so it is checked
     * through {@link #assertFunction} instead.
     */
    protected void assertFunctionWithError(String projection, Type expectedType, Double expected)
    {
        if (expected == null) {
            assertFunction(projection, expectedType, null);
            return;
        }
        assertFunctionWithError(projection, expectedType, expected, DELTA);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertFunctionWithError} with an explicit tolerance.
     */
    protected void assertFunctionWithError(String projection, Type expectedType, double expected, double delta)
    {
        functionAssertions.assertFunctionWithError(projection, expectedType, expected, delta);
    }

    /**
     * Asserts an operator by invoking it as a quoted function call, i.e. the projection
     * {@code "<operator function name>"(<value>)}.
     */
    protected void assertOperator(OperatorType operator, String value, Type expectedType, Object expected)
    {
        functionAssertions.assertFunction(format("\"%s\"(%s)", operator.getFunctionName().getObjectName(), value), expectedType, expected);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertFunctionDoubleArrayWithError}.
     */
    protected void assertFunctionDoubleArrayWithError(String projection, Type expectedType, List<Double> expected, double delta)
    {
        functionAssertions.assertFunctionDoubleArrayWithError(projection, expectedType, expected, delta);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertFunctionFloatArrayWithError}.
     */
    protected void assertFunctionFloatArrayWithError(String projection, Type expectedType, List<Float> expected, float delta)
    {
        functionAssertions.assertFunctionFloatArrayWithError(projection, expectedType, expected, delta);
    }

    /**
     * Asserts a decimal-valued statement; the expected type is derived from the precision and
     * scale of {@code expectedResult}.
     */
    protected void assertDecimalFunction(String statement, SqlDecimal expectedResult)
    {
        assertFunction(
                statement,
                createDecimalType(expectedResult.getPrecision(), expectedResult.getScale()),
                expectedResult);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidFunction} with an explicit error code
     * and message pattern.
     */
    protected void assertInvalidFunction(String projection, StandardErrorCode errorCode, String messagePattern)
    {
        functionAssertions.assertInvalidFunction(projection, errorCode, messagePattern);
    }

    /**
     * Same as the three-argument overload, using {@code INVALID_FUNCTION_ARGUMENT} as the error code.
     */
    protected void assertInvalidFunction(String projection, String messagePattern)
    {
        functionAssertions.assertInvalidFunction(projection, INVALID_FUNCTION_ARGUMENT, messagePattern);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidFunction} with a semantic error code.
     */
    protected void assertInvalidFunction(String projection, SemanticErrorCode expectedErrorCode)
    {
        functionAssertions.assertInvalidFunction(projection, expectedErrorCode);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidFunction} with a semantic error code
     * and message.
     */
    protected void assertInvalidFunction(String projection, SemanticErrorCode expectedErrorCode, String message)
    {
        functionAssertions.assertInvalidFunction(projection, expectedErrorCode, message);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidFunction} with an arbitrary
     * {@link ErrorCodeSupplier}.
     */
    protected void assertInvalidFunction(String projection, ErrorCodeSupplier expectedErrorCode)
    {
        functionAssertions.assertInvalidFunction(projection, expectedErrorCode);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertFunctionThrowsIncorrectly}.
     */
    protected void assertFunctionThrowsIncorrectly(@Language("SQL") String projection, Class<? extends Throwable> throwableClass, @Language("RegExp") String message)
    {
        functionAssertions.assertFunctionThrowsIncorrectly(projection, throwableClass, message);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertNumericOverflow}.
     */
    protected void assertNumericOverflow(String projection, String message)
    {
        functionAssertions.assertNumericOverflow(projection, message);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidTypeDefinition}.
     */
    protected void assertInvalidTypeDefinition(String projection, String message)
    {
        functionAssertions.assertInvalidTypeDefinition(projection, message);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidCast}.
     */
    protected void assertInvalidCast(String projection)
    {
        functionAssertions.assertInvalidCast(projection);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertInvalidCast} with an expected message.
     */
    protected void assertInvalidCast(@Language("SQL") String projection, String message)
    {
        functionAssertions.assertInvalidCast(projection, message);
    }

    /**
     * Delegates to {@link FunctionAssertions#assertCachedInstanceHasBoundedRetainedSize}.
     */
    public void assertCachedInstanceHasBoundedRetainedSize(String projection)
    {
        functionAssertions.assertCachedInstanceHasBoundedRetainedSize(projection);
    }

    /**
     * Runs the projection through {@code executeProjectionWithFullEngine} and asserts that it
     * throws a {@link PrestoException} with error code {@code NOT_SUPPORTED} and exactly the
     * given message. Fails if no exception is thrown.
     */
    protected void assertNotSupported(String projection, String message)
    {
        try {
            functionAssertions.executeProjectionWithFullEngine(projection);
            fail("expected exception");
        }
        catch (PrestoException e) {
            try {
                assertEquals(e.getErrorCode(), NOT_SUPPORTED.toErrorCode());
                assertEquals(e.getMessage(), message);
            }
            catch (Throwable failure) {
                // Keep the original exception attached so its stack trace is visible in the report.
                failure.addSuppressed(e);
                throw failure;
            }
        }
    }

    /**
     * Delegates to {@link FunctionAssertions#tryEvaluateWithAll}.
     */
    protected void tryEvaluateWithAll(String projection, Type expectedType)
    {
        functionAssertions.tryEvaluateWithAll(projection, expectedType);
    }

    /**
     * Registers a single scalar function as a built-in function with the function manager of
     * the current {@link FunctionAssertions}.
     */
    protected void registerScalarFunction(SqlScalarFunction sqlScalarFunction)
    {
        Metadata metadata = functionAssertions.getMetadata();
        metadata.getFunctionAndTypeManager().registerBuiltInFunctions(ImmutableList.of(sqlScalarFunction));
    }

    /**
     * Registers the functions produced by {@code FunctionListBuilder.scalars(clazz)} as built-ins.
     */
    protected void registerScalar(Class<?> clazz)
    {
        Metadata metadata = functionAssertions.getMetadata();
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalars(clazz)
                .getFunctions();
        metadata.getFunctionAndTypeManager().registerBuiltInFunctions(functions);
    }

    /**
     * Registers the functions produced by {@code FunctionListBuilder.scalar(clazz)} as built-ins.
     */
    protected void registerParametricScalar(Class<?> clazz)
    {
        Metadata metadata = functionAssertions.getMetadata();
        List<SqlFunction> functions = new FunctionListBuilder()
                .scalar(clazz)
                .getFunctions();
        metadata.getFunctionAndTypeManager().registerBuiltInFunctions(functions);
    }

    /**
     * Builds a {@link SqlDecimal} from its string form using
     * {@code Decimals.parseIncludeLeadingZerosInPrecision}; precision and scale are taken from
     * the parsed type.
     */
    protected static SqlDecimal decimal(String decimalString)
    {
        DecimalParseResult parseResult = Decimals.parseIncludeLeadingZerosInPrecision(decimalString);
        BigInteger unscaledValue;
        if (parseResult.getType().isShort()) {
            // Short decimals carry their unscaled value as a Long.
            unscaledValue = BigInteger.valueOf((Long) parseResult.getObject());
        }
        else {
            // Other decimals carry their unscaled value as a Slice that must be decoded.
            unscaledValue = Decimals.decodeUnscaledValue((Slice) parseResult.getObject());
        }
        return new SqlDecimal(unscaledValue, parseResult.getType().getPrecision(), parseResult.getType().getScale());
    }

    /**
     * Builds a decimal from {@code value} zero-padded to {@code Decimals.MAX_PRECISION} digits.
     */
    protected static SqlDecimal maxPrecisionDecimal(long value)
    {
        // The format width counts the minus sign, so negative values need one extra column.
        final String maxPrecisionFormat = "%0" + (Decimals.MAX_PRECISION + (value < 0 ? 1 : 0)) + "d";
        return decimal(format(maxPrecisionFormat, value));
    }

    /**
     * Builds a mutable map from parallel key and value lists.
     *
     * <p>This helper should only be used when the map contains null values;
     * otherwise, use {@code ImmutableMap.of()}.
     *
     * <p>Fails the test if the lists differ in size or if {@code keyList} contains a duplicate key.
     */
    protected static Map asMap(List keyList, List valueList)
    {
        if (keyList.size() != valueList.size()) {
            fail("keyList should have same size with valueList");
        }
        Map<Object, Object> map = new HashMap<>();
        for (int i = 0; i < keyList.size(); i++) {
            Object key = keyList.get(i);
            // put() returning null cannot signal "key was absent" here, because null values are
            // expected; check for the key explicitly so duplicates are always detected.
            if (map.containsKey(key)) {
                fail("keyList should not contain duplicate keys: " + key);
            }
            map.put(key, valueList.get(i));
        }
        return map;
    }
}