FunctionAssertions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.Session;
import com.facebook.presto.common.InvalidTypeDefinitionException;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.Utils;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.FunctionListBuilder;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.operator.DriverContext;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.FilterAndProjectOperator.FilterAndProjectOperatorFactory;
import com.facebook.presto.operator.Operator;
import com.facebook.presto.operator.OperatorFactory;
import com.facebook.presto.operator.ScanFilterAndProjectOperator;
import com.facebook.presto.operator.SourceOperator;
import com.facebook.presto.operator.SourceOperatorFactory;
import com.facebook.presto.operator.project.CursorProcessor;
import com.facebook.presto.operator.project.PageProcessor;
import com.facebook.presto.operator.project.PageProjectionWithOutputs;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorPageSource;
import com.facebook.presto.spi.ConnectorSplit;
import com.facebook.presto.spi.ConnectorTableHandle;
import com.facebook.presto.spi.ErrorCodeSupplier;
import com.facebook.presto.spi.FixedPageSource;
import com.facebook.presto.spi.HostAddress;
import com.facebook.presto.spi.InMemoryRecordSet;
import com.facebook.presto.spi.NodeProvider;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.RecordPageSource;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.StandardErrorCode;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
import com.facebook.presto.spi.function.SqlFunction;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.spi.schedule.NodeSelectionStrategy;
import com.facebook.presto.split.PageSourceProvider;
import com.facebook.presto.sql.analyzer.ExpressionAnalysis;
import com.facebook.presto.sql.analyzer.FeaturesConfig;
import com.facebook.presto.sql.analyzer.FunctionsConfig;
import com.facebook.presto.sql.analyzer.SemanticErrorCode;
import com.facebook.presto.sql.analyzer.SemanticException;
import com.facebook.presto.sql.gen.ExpressionCompiler;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.DefaultTraversalVisitor;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.testing.LocalQueryRunner;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Iterators;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import org.intellij.lang.annotations.Language;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;
import org.openjdk.jol.info.ClassLayout;

import java.io.Closeable;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Supplier;

import static com.facebook.airlift.concurrent.Threads.daemonThreadsNamed;
import static com.facebook.airlift.testing.Assertions.assertInstanceOf;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.block.BlockAssertions.createBooleansBlock;
import static com.facebook.presto.block.BlockAssertions.createDoublesBlock;
import static com.facebook.presto.block.BlockAssertions.createIntsBlock;
import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.block.BlockAssertions.createRowBlock;
import static com.facebook.presto.block.BlockAssertions.createSlicesBlock;
import static com.facebook.presto.block.BlockAssertions.createStringsBlock;
import static com.facebook.presto.block.BlockAssertions.createTimestampsWithTimezoneBlock;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateTimeEncoding.packDateTimeWithZone;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.geospatial.type.GeometryType.GEOMETRY;
import static com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_TYPE_DEFINITION;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.spi.schedule.NodeSelectionStrategy.HARD_AFFINITY;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.analyzeExpressions;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.planner.iterative.rule.CanonicalizeExpressionRewriter.canonicalizeExpression;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.SqlToRowExpressionTranslator.translate;
import static com.facebook.presto.testing.TestingTaskContext.createTaskContext;
import static com.facebook.presto.util.AnalyzerUtil.createParsingOptions;
import static com.facebook.presto.util.Failures.toFailure;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.lang.String.format;
import static java.util.Collections.emptyMap;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.Executors.newCachedThreadPool;
import static java.util.concurrent.Executors.newScheduledThreadPool;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public final class FunctionAssertions
        implements Closeable
{
    // Daemon-thread pools shared by every instance of this helper (static).
    private static final ExecutorService EXECUTOR = newCachedThreadPool(daemonThreadsNamed("test-%s"));
    private static final ScheduledExecutorService SCHEDULED_EXECUTOR = newScheduledThreadPool(2, daemonThreadsNamed("test-scheduledExecutor-%s"));

    private static final SqlParser SQL_PARSER = new SqlParser();

    // Increase the number of fields to generate a wide column
    private static final int TEST_ROW_NUMBER_OF_FIELDS = 2500;
    private static final RowType TEST_ROW_TYPE = createTestRowType(TEST_ROW_NUMBER_OF_FIELDS);
    private static final Block TEST_ROW_DATA = createTestRowData(TEST_ROW_TYPE);

    // Input page fed to the operators under test. Each channel is built from a single value,
    // and the channel order matches the indexes registered in INPUT_MAPPING below.
    private static final Page SOURCE_PAGE = new Page(
            createLongsBlock(1234L),
            createStringsBlock("hello"),
            createDoublesBlock(12.34),
            createBooleansBlock(true),
            createLongsBlock(new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis()),
            createStringsBlock("%el%"),
            createStringsBlock((String) null),
            createTimestampsWithTimezoneBlock(packDateTimeWithZone(new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(), TimeZoneKey.getTimeZoneKey("Z"))),
            createSlicesBlock(Slices.wrappedBuffer((byte) 0xab)),
            createIntsBlock(1234),
            TEST_ROW_DATA);

    // Page constructed without any blocks; NOTE(review): the argument is presumably the position count - confirm against Page
    private static final Page ZERO_CHANNEL_PAGE = new Page(1);

    // Bound variables usable in test expressions, each mapped to the index of its SOURCE_PAGE channel.
    // Note that bound_timestamp is typed BIGINT and its channel holds epoch milliseconds.
    private static final Map<VariableReferenceExpression, Integer> INPUT_MAPPING = ImmutableMap.<VariableReferenceExpression, Integer>builder()
            .put(new VariableReferenceExpression(Optional.empty(), "bound_long", BIGINT), 0)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_string", VARCHAR), 1)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_double", DOUBLE), 2)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_boolean", BOOLEAN), 3)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_timestamp", BIGINT), 4)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_pattern", VARCHAR), 5)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_null_string", VARCHAR), 6)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_timestamp_with_timezone", TIMESTAMP_WITH_TIME_ZONE), 7)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_binary_literal", VARBINARY), 8)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_integer", INTEGER), 9)
            .put(new VariableReferenceExpression(Optional.empty(), "bound_row", TEST_ROW_TYPE), 10)
            .build();

    // Type information for the bound variables, derived from the INPUT_MAPPING keys
    private static final TypeProvider SYMBOL_TYPES = TypeProvider.fromVariables(INPUT_MAPPING.keySet());

    private static final PageSourceProvider PAGE_SOURCE_PROVIDER = new TestPageSourceProvider();
    private static final PlanNodeId SOURCE_ID = new PlanNodeId("scan");

    // Per-instance state; metadata and compiler are obtained from the runner in the constructor
    private final Session session;
    private final LocalQueryRunner runner;
    private final Metadata metadata;
    private final ExpressionCompiler compiler;

    /**
     * Creates assertions bound to {@code TEST_SESSION} with default feature and function configs.
     */
    public FunctionAssertions()
    {
        this(TEST_SESSION);
    }

    /**
     * Creates assertions for the given session using a default {@link FeaturesConfig} and
     * {@link FunctionsConfig}; the supplied session is kept as-is (no refresh).
     *
     * @param session session used for evaluation; must not be null
     */
    public FunctionAssertions(Session session)
    {
        this(session, new FeaturesConfig(), new FunctionsConfig(), false);
    }

    /**
     * Creates assertions for the given session and feature configuration, with a default
     * {@link FunctionsConfig}; the supplied session is kept as-is (no refresh).
     *
     * @param session session used for evaluation; must not be null
     * @param featuresConfig feature configuration handed to the underlying query runner
     */
    public FunctionAssertions(Session session, FeaturesConfig featuresConfig)
    {
        this(session, featuresConfig, new FunctionsConfig(), false);
    }

    /**
     * Creates assertions for the given session and functions configuration, with a default
     * {@link FeaturesConfig}; the supplied session is kept as-is (no refresh).
     *
     * @param session session used for evaluation; must not be null
     * @param functionsConfig functions configuration handed to the underlying query runner
     */
    public FunctionAssertions(Session session, FunctionsConfig functionsConfig)
    {
        this(session, new FeaturesConfig(), functionsConfig, false);
    }

    /**
     * Primary constructor: builds the {@link LocalQueryRunner} and captures its metadata and
     * expression compiler.
     *
     * @param session session handed to the query runner; must not be null
     * @param featuresConfig feature configuration for the query runner
     * @param functionsConfig functions configuration for the query runner
     * @param refreshSession when true, evaluation uses the runner's default session instead of
     *         the supplied {@code session}
     */
    public FunctionAssertions(Session session, FeaturesConfig featuresConfig, FunctionsConfig functionsConfig, boolean refreshSession)
    {
        requireNonNull(session, "session is null");
        this.runner = new LocalQueryRunner(session, featuresConfig, functionsConfig);
        this.session = refreshSession ? this.runner.getDefaultSession() : session;
        this.metadata = this.runner.getMetadata();
        this.compiler = this.runner.getExpressionCompiler();
    }
    /**
     * Returns the function and type manager of the underlying query runner.
     */
    public FunctionAndTypeManager getFunctionAndTypeManager()
    {
        return runner.getFunctionAndTypeManager();
    }

    /**
     * Returns the metadata captured from the underlying query runner at construction time.
     */
    public Metadata getMetadata()
    {
        return metadata;
    }

    /**
     * Registers the given functions as built-in functions on this instance's metadata.
     *
     * @param functionInfos functions to register
     * @return this instance, to allow call chaining
     */
    public FunctionAssertions addFunctions(List<? extends SqlFunction> functionInfos)
    {
        metadata.registerBuiltInFunctions(functionInfos);
        return this;
    }

    /**
     * Registers the scalar functions that {@link FunctionListBuilder} extracts from the given
     * class as built-in functions.
     *
     * @param clazz class to extract scalar functions from
     * @return this instance, to allow call chaining
     */
    public FunctionAssertions addScalarFunctions(Class<?> clazz)
    {
        return addFunctions(new FunctionListBuilder().scalars(clazz).getFunctions());
    }

    /**
     * Evaluates {@code projection} through every execution strategy and asserts that the single
     * agreed-upon value equals {@code expected}.
     *
     * @param projection SQL expression to evaluate
     * @param expectedType type the expression is expected to have
     * @param expected expected value; a {@link Slice} is compared by its UTF-8 string form
     */
    public void assertFunction(String projection, Type expectedType, Object expected)
    {
        // Slices are compared through their decoded UTF-8 text
        Object expectedValue = (expected instanceof Slice) ? ((Slice) expected).toStringUtf8() : expected;

        Object actualValue = selectSingleValue(projection, expectedType, compiler);
        assertEquals(actualValue, expectedValue);
    }

    /**
     * Evaluates a numeric {@code projection} and asserts that its double value is within
     * {@code delta} of {@code expected}.
     *
     * @param projection SQL expression to evaluate; its value must be a {@link Number}
     * @param expectedType type the expression is expected to have
     * @param expected expected value
     * @param delta tolerated absolute difference
     */
    public void assertFunctionWithError(String projection, Type expectedType, double expected, double delta)
    {
        Object result = selectSingleValue(projection, expectedType, compiler);
        double actualValue = ((Number) result).doubleValue();
        assertEquals(actualValue, expected, delta);
    }

    /**
     * Evaluates an array-valued {@code projection} and asserts element-wise that each double is
     * within {@code delta} of the corresponding entry in {@code expected}.
     *
     * @param projection SQL expression to evaluate; its value must be an {@link ArrayList}
     * @param expectedType type the expression is expected to have
     * @param expected expected elements, in order
     * @param delta tolerated absolute difference per element
     */
    public void assertFunctionDoubleArrayWithError(String projection, Type expectedType, List<Double> expected, double delta)
    {
        Object result = selectSingleValue(projection, expectedType, compiler);
        assertTrue(result instanceof ArrayList);

        ArrayList<Object> actualValues = (ArrayList) result;
        assertEquals(actualValues.size(), expected.size());

        int index = 0;
        for (Object actualValue : actualValues) {
            assertEquals((double) actualValue, expected.get(index), delta);
            index++;
        }
    }

    /**
     * Evaluates an array-valued {@code projection} and asserts element-wise that each float is
     * within {@code delta} of the corresponding entry in {@code expected}.
     *
     * @param projection SQL expression to evaluate; its value must be an {@link ArrayList}
     * @param expectedType type the expression is expected to have
     * @param expected expected elements, in order
     * @param delta tolerated absolute difference per element
     */
    public void assertFunctionFloatArrayWithError(String projection, Type expectedType, List<Float> expected, float delta)
    {
        Object result = selectSingleValue(projection, expectedType, compiler);
        assertTrue(result instanceof ArrayList);

        ArrayList<Object> actualValues = (ArrayList) result;
        assertEquals(actualValues.size(), expected.size());

        int index = 0;
        for (Object actualValue : actualValues) {
            assertEquals((float) actualValue, expected.get(index), delta);
            index++;
        }
    }

    /**
     * Evaluates {@code projection} and asserts that the {@code toString()} form of its value
     * equals {@code expected}.
     *
     * @param projection SQL expression to evaluate
     * @param expectedType type the expression is expected to have
     * @param expected expected string rendering of the value
     */
    public void assertFunctionString(String projection, Type expectedType, String expected)
    {
        Object result = selectSingleValue(projection, expectedType, compiler);
        String rendered = result.toString();
        assertEquals(rendered, expected);
    }

    /**
     * Evaluates {@code expression} with this instance's session, discarding the value. Any
     * exception or assertion failure raised during evaluation propagates to the caller.
     *
     * @param expression SQL expression to evaluate
     * @param expectedType type the expression is expected to have
     */
    public void tryEvaluate(String expression, Type expectedType)
    {
        tryEvaluate(expression, expectedType, session);
    }

    /**
     * Evaluates {@code expression} with the given session via {@code selectUniqueValue} (so all
     * execution strategies must agree on one value) and discards the result.
     */
    private void tryEvaluate(String expression, Type expectedType, Session session)
    {
        selectUniqueValue(expression, expectedType, session, compiler);
    }

    /**
     * Runs {@code expression} through every execution strategy using this instance's session,
     * discarding the values.
     *
     * @param expression SQL expression to evaluate
     * @param expectedType type the expression is expected to have
     */
    public void tryEvaluateWithAll(String expression, Type expectedType)
    {
        tryEvaluateWithAll(expression, expectedType, session);
    }

    /**
     * Runs {@code expression} through every execution strategy using the given session,
     * discarding the values. Unlike {@code tryEvaluate}, the results are not required to agree.
     *
     * @param expression SQL expression to evaluate
     * @param expectedType type the expression is expected to have
     * @param session session to evaluate with
     */
    public void tryEvaluateWithAll(String expression, Type expectedType, Session session)
    {
        executeProjectionWithAll(expression, expectedType, session, compiler);
    }

    /**
     * Executes {@code SELECT <projection>} through the query runner only. The result is not
     * inspected; the call is useful for its side effects and for any exception it raises.
     *
     * @param projection SQL expression appended to {@code "SELECT "}
     */
    public void executeProjectionWithFullEngine(String projection)
    {
        runner.execute("SELECT " + projection);
    }

    /**
     * Evaluates {@code projection} and returns its value after asserting that the value's
     * runtime class is exactly {@code clazz}.
     *
     * @param projection SQL expression to evaluate
     * @param expectedType type the expression is expected to have
     * @param clazz exact runtime class the value must have
     * @return the evaluated value, typed as {@code T}
     */
    public <T> T selectSingleValue(String projection, Type expectedType, Class<T> clazz)
    {
        Object value = selectSingleValue(projection, expectedType, compiler);
        assertEquals(value.getClass(), clazz);
        // the class was just asserted to be exactly clazz, so this cast cannot fail
        return clazz.cast(value);
    }

    /**
     * Evaluates {@code projection} with this instance's session and returns the single value
     * that all execution strategies agree on.
     */
    private Object selectSingleValue(String projection, Type expectedType, ExpressionCompiler compiler)
    {
        return selectUniqueValue(projection, expectedType, session, compiler);
    }

    /**
     * Runs {@code projection} through every execution strategy and returns the value, asserting
     * that all strategies produced the same one.
     */
    private Object selectUniqueValue(String projection, Type expectedType, Session session, ExpressionCompiler compiler)
    {
        // collapsing into a set leaves exactly one entry when every strategy agrees
        HashSet<Object> distinctResults = new HashSet<>(executeProjectionWithAll(projection, expectedType, session, compiler));

        assertEquals(distinctResults.size(), 1, "Expected only one result unique result, but got " + distinctResults);

        return distinctResults.iterator().next();
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with the given
     * error code and a message that either equals {@code messagePattern} or matches it as a
     * regular expression.
     *
     * @param projection SQL expression expected to fail
     * @param errorCode expected error code
     * @param messagePattern expected message, literal or regex
     */
    public void assertInvalidFunction(String projection, StandardErrorCode errorCode, String messagePattern)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail("Expected to throw a PrestoException with message matching " + messagePattern + " but got " + unexpectedResult);
        }
        catch (PrestoException thrown) {
            try {
                assertEquals(thrown.getErrorCode(), errorCode.toErrorCode());

                String actualMessage = thrown.getMessage();
                boolean messageMatches = actualMessage.equals(messagePattern) || actualMessage.matches(messagePattern);
                assertTrue(messageMatches, format("Error message [%s] doesn't match [%s]", actualMessage, messagePattern));
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} fails with {@code INVALID_FUNCTION_ARGUMENT}
     * and a message equal to or matching {@code messagePattern}.
     *
     * @param projection SQL expression expected to fail
     * @param messagePattern expected message, literal or regex
     */
    public void assertInvalidFunction(String projection, String messagePattern)
    {
        assertInvalidFunction(projection, INVALID_FUNCTION_ARGUMENT, messagePattern);
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link SemanticException} carrying
     * {@code expectedErrorCode}.
     *
     * @param projection SQL expression expected to fail
     * @param expectedErrorCode expected semantic error code
     */
    public void assertInvalidFunction(String projection, SemanticErrorCode expectedErrorCode)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail(format("Expected to throw %s exception but got %s", expectedErrorCode, unexpectedResult));
        }
        catch (SemanticException thrown) {
            try {
                assertEquals(thrown.getCode(), expectedErrorCode);
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link SemanticException} carrying
     * {@code expectedErrorCode} and exactly the given message.
     *
     * @param projection SQL expression expected to fail
     * @param expectedErrorCode expected semantic error code
     * @param message exact expected exception message
     */
    public void assertInvalidFunction(String projection, SemanticErrorCode expectedErrorCode, String message)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail(format("Expected to throw %s exception but got %s", expectedErrorCode, unexpectedResult));
        }
        catch (SemanticException thrown) {
            try {
                assertEquals(thrown.getCode(), expectedErrorCode);
                assertEquals(thrown.getMessage(), message);
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws an {@link InvalidTypeDefinitionException}
     * that maps to the {@code INVALID_TYPE_DEFINITION} error code and carries exactly the given
     * message.
     *
     * @param projection SQL expression expected to fail
     * @param message exact expected exception message
     */
    public void assertInvalidTypeDefinition(String projection, String message)
    {
        try {
            Object value = evaluateInvalid(projection);
            // name the error code this method actually checks for
            fail("Expected to throw an INVALID_TYPE_DEFINITION exception, but got " + value);
        }
        catch (InvalidTypeDefinitionException e) {
            try {
                assertEquals(toFailure(e).getErrorCode(), INVALID_TYPE_DEFINITION.toErrorCode());
                assertEquals(e.getMessage(), message);
            }
            catch (Throwable failure) {
                // keep the original exception visible in the failure report
                failure.addSuppressed(e);
                throw failure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} whose error
     * code equals the one supplied by {@code expectedErrorCode}.
     *
     * @param projection SQL expression expected to fail
     * @param expectedErrorCode supplier of the expected error code
     */
    public void assertInvalidFunction(String projection, ErrorCodeSupplier expectedErrorCode)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail(format("Expected to throw %s exception but got %s", expectedErrorCode, unexpectedResult));
        }
        catch (PrestoException thrown) {
            try {
                assertEquals(thrown.getErrorCode(), expectedErrorCode.toErrorCode());
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws an instance of {@code throwableClass}
     * whose message matches the regular expression {@code message}.
     *
     * @param projection SQL expression expected to fail
     * @param throwableClass expected exception type
     * @param message regular expression the exception message must match
     */
    public void assertFunctionThrowsIncorrectly(@Language("SQL") String projection, Class<? extends Throwable> throwableClass, @Language("RegExp") String message)
    {
        assertThatThrownBy(() -> evaluateInvalid(projection))
                .isInstanceOf(throwableClass)
                .hasMessageMatching(message);
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error
     * code {@code NUMERIC_VALUE_OUT_OF_RANGE} and exactly the given message.
     *
     * @param projection SQL expression expected to overflow
     * @param message exact expected exception message
     */
    public void assertNumericOverflow(String projection, String message)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail("Expected to throw an NUMERIC_VALUE_OUT_OF_RANGE exception with message " + message + " but got " + unexpectedResult);
        }
        catch (PrestoException thrown) {
            try {
                assertEquals(thrown.getErrorCode(), NUMERIC_VALUE_OUT_OF_RANGE.toErrorCode());
                assertEquals(thrown.getMessage(), message);
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error
     * code {@code INVALID_CAST_ARGUMENT}.
     *
     * @param projection SQL expression expected to fail
     */
    public void assertInvalidCast(String projection)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail("Expected to throw an INVALID_CAST_ARGUMENT exception but got " + unexpectedResult);
        }
        catch (PrestoException thrown) {
            try {
                assertEquals(thrown.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Asserts that evaluating {@code projection} throws a {@link PrestoException} with error
     * code {@code INVALID_CAST_ARGUMENT} and exactly the given message.
     *
     * @param projection SQL expression expected to fail
     * @param message exact expected exception message
     */
    public void assertInvalidCast(String projection, String message)
    {
        try {
            Object unexpectedResult = evaluateInvalid(projection);
            fail("Expected to throw an INVALID_CAST_ARGUMENT exception, but got " + unexpectedResult);
        }
        catch (PrestoException thrown) {
            try {
                assertEquals(thrown.getErrorCode(), INVALID_CAST_ARGUMENT.toErrorCode());
                assertEquals(thrown.getMessage(), message);
            }
            catch (Throwable assertionFailure) {
                // keep the original exception visible in the failure report
                assertionFailure.addSuppressed(thrown);
                throw assertionFailure;
            }
        }
    }

    /**
     * Evaluates a projection that the caller expects to fail, returning the value only if no
     * exception was thrown.
     */
    private Object evaluateInvalid(String projection)
    {
        // The type passed here acts as a placeholder: callers expect evaluation to throw before
        // any type assertion is reached.
        // NOTE(review): a comment in executeProjectionWithAll says UNKNOWN is used for expected
        // failures, but GEOMETRY is passed here - confirm which is intended.
        return selectSingleValue(projection, GEOMETRY, compiler);
    }

    /**
     * Heuristically asserts that the cached instances held by the compiled projection for
     * {@code projection} do not grow without bound when the same page is processed repeatedly.
     *
     * @param projection SQL expression to compile and run; must not be null
     */
    public void assertCachedInstanceHasBoundedRetainedSize(String projection)
    {
        requireNonNull(projection, "projection is null");

        Expression parsedProjection = createExpression(session, projection, metadata, SYMBOL_TYPES);
        RowExpression rowExpression = toRowExpression(session, parsedProjection);
        PageProcessor processor = compiler.compilePageProcessor(session.getSqlFunctionProperties(), Optional.empty(), ImmutableList.of(rowExpression)).get();

        // Heuristic for deciding whether the retained size of the cached instances is bounded:
        // * at least 1000 iterations are always executed;
        // * the loop stops (test passes) once the iteration count reaches 4x the iteration at
        //   which the maximum retained size was last raised;
        // * the test fails as soon as the maximum retained size reaches 1MB.
        // The 1MB threshold is arbitrary and may be raised if a function implementation
        // legitimately needs more.

        long peakRetainedSize = 0;
        int peakIteration = 0;
        int iteration = 0;
        while (iteration < Math.max(1000, peakIteration * 4)) {
            Iterator<Optional<Page>> output = processor.process(
                    session.getSqlFunctionProperties(),
                    new DriverYieldSignal(),
                    newSimpleAggregatedMemoryContext().newLocalMemoryContext(PageProcessor.class.getSimpleName()),
                    SOURCE_PAGE);
            // drain the iterator so the projection actually runs
            Iterators.getOnlyElement(output);

            long currentRetainedSize = 0;
            for (PageProjectionWithOutputs pageProjection : processor.getProjections()) {
                currentRetainedSize += getRetainedSizeOfCachedInstance(pageProjection);
            }

            if (currentRetainedSize > peakRetainedSize) {
                peakRetainedSize = currentRetainedSize;
                peakIteration = iteration;
            }

            if (peakRetainedSize >= 1048576) {
                fail(format("The retained size of cached instance of function invocation is likely unbounded: %s", projection));
            }

            iteration++;
        }
    }

    /**
     * Sums the retained size of every field of the wrapped page projection whose name starts
     * with {@code __cachedInstance}.
     *
     * @param projection wrapper whose underlying page projection is inspected reflectively
     * @return total retained size, in bytes, of the cached-instance fields
     */
    private long getRetainedSizeOfCachedInstance(PageProjectionWithOutputs projection)
    {
        // The fields are declared on the page projection's class, so they must be read from the
        // page projection itself; reading them from the wrapper makes Field.get throw
        // IllegalArgumentException because the wrapper is not an instance of the declaring class.
        Object pageProjection = projection.getPageProjection();

        long retainedSize = 0;
        for (Field field : pageProjection.getClass().getDeclaredFields()) {
            if (!field.getName().startsWith("__cachedInstance")) {
                continue;
            }
            // only open up the fields that are actually read
            field.setAccessible(true);
            try {
                retainedSize += getRetainedSizeOf(field.get(pageProjection));
            }
            catch (IllegalAccessException e) {
                throw new RuntimeException(e);
            }
        }
        return retainedSize;
    }

    /**
     * Computes the retained size of {@code object} in bytes.
     * <p>
     * {@link PageBuilder} and {@link Block} report their own retained size. Supported arrays
     * ({@code int[]}, {@code boolean[]}, {@code byte[]}, {@code long[]}, {@code short[]} and
     * {@code Block[]}) are sized directly. Any other object is sized as its shallow instance
     * size (as reported by {@link ClassLayout}) plus, recursively, the retained size of each
     * non-static, non-primitive field declared on its class.
     *
     * @param object object to measure; {@code null} contributes zero bytes
     * @return retained size in bytes
     * @throws IllegalArgumentException if an array of an unsupported type is encountered
     */
    private long getRetainedSizeOf(Object object)
    {
        // A null reference (unset field or empty array slot) retains nothing beyond the
        // reference slot that is already counted in its holder.
        if (object == null) {
            return 0;
        }
        if (object instanceof PageBuilder) {
            return ((PageBuilder) object).getRetainedSizeInBytes();
        }
        if (object instanceof Block) {
            return ((Block) object).getRetainedSizeInBytes();
        }

        Class<?> type = object.getClass();
        if (type.isArray()) {
            if (type == int[].class) {
                return sizeOf((int[]) object);
            }
            else if (type == boolean[].class) {
                return sizeOf((boolean[]) object);
            }
            else if (type == byte[].class) {
                return sizeOf((byte[]) object);
            }
            else if (type == long[].class) {
                return sizeOf((long[]) object);
            }
            else if (type == short[].class) {
                return sizeOf((short[]) object);
            }
            else if (type == Block[].class) {
                Object[] objects = (Object[]) object;
                return Arrays.stream(objects)
                        .mapToLong(this::getRetainedSizeOf)
                        .sum();
            }
            else {
                throw new IllegalArgumentException(format("Unknown type encountered: %s", type));
            }
        }

        long retainedSize = ClassLayout.parseClass(type).instanceSize();
        Field[] fields = type.getDeclaredFields();
        for (Field field : fields) {
            try {
                // primitives are part of the shallow size; statics do not belong to the instance
                if (field.getType().isPrimitive() || Modifier.isStatic(field.getModifiers())) {
                    continue;
                }
                field.setAccessible(true);
                retainedSize += getRetainedSizeOf(field.get(object));
            }
            catch (IllegalAccessException t) {
                throw new RuntimeException(t);
            }
        }
        return retainedSize;
    }

    /**
     * Evaluates {@code projection} through each available execution strategy and returns every
     * value produced, in execution order: full engine (only when no bound values are needed),
     * standalone filter/project operator, interpreter, scan operator over a normal split, scan
     * operator over a record-set split, and the full engine once more.
     *
     * @param projection SQL expression to evaluate; must not be null
     * @param expectedType type the expression is expected to have
     * @param session session to evaluate with
     * @param compiler compiler used to build the operators
     * @return the values produced by the individual strategies
     */
    private List<Object> executeProjectionWithAll(String projection, Type expectedType, Session session, ExpressionCompiler compiler)
    {
        requireNonNull(projection, "projection is null");

        Expression parsedProjection = createExpression(session, projection, metadata, SYMBOL_TYPES);
        RowExpression rowExpression = toRowExpression(session, parsedProjection);

        List<Object> collected = new ArrayList<>();

        // full engine, only possible when the projection does not reference bound values
        if (!needsBoundValue(parsedProjection)) {
            MaterializedResult engineResult = runner.execute("SELECT " + projection);
            assertType(engineResult.getTypes(), expectedType);
            assertEquals(engineResult.getTypes().size(), 1);
            assertEquals(engineResult.getMaterializedRows().size(), 1);
            collected.add(Iterables.getOnlyElement(engineResult.getMaterializedRows()).getField(0));
        }

        // standalone filter/project operator
        OperatorFactory filterProjectFactory = compileFilterProject(session.getSqlFunctionProperties(), Optional.empty(), rowExpression, compiler);
        collected.add(selectSingleValue(filterProjectFactory, expectedType, session));

        // interpreter
        collected.add(interpret(parsedProjection, expectedType, session));

        // scan/filter/project operator, first over a normal split and then over a record set
        SourceOperatorFactory scanFactory = compileScanFilterProject(session.getSqlFunctionProperties(), Optional.empty(), rowExpression, compiler);
        collected.add(selectSingleValue(scanFactory, expectedType, createNormalSplit(), session));
        collected.add(selectSingleValue(scanFactory, expectedType, createRecordSetSplit(), session));

        // full engine again, under the same condition as the first run
        if (!needsBoundValue(parsedProjection)) {
            MaterializedResult engineResult = runner.execute("SELECT " + projection);
            assertType(engineResult.getTypes(), expectedType);
            assertEquals(engineResult.getTypes().size(), 1);
            assertEquals(engineResult.getMaterializedRows().size(), 1);
            collected.add(Iterables.getOnlyElement(engineResult.getMaterializedRows()).getField(0));
        }

        // the type is validated last because tests that expect a failure pass a placeholder
        // type rather than the real one
        assertEquals(rowExpression.getType(), expectedType);
        return collected;
    }

    /**
     * Analyzes the types of {@code projectionExpression} against the bound-variable types and
     * translates it into a {@link RowExpression} using {@code INPUT_MAPPING}.
     */
    private RowExpression toRowExpression(Session session, Expression projectionExpression)
    {
        return toRowExpression(
                projectionExpression,
                getExpressionTypes(session, metadata, SQL_PARSER, SYMBOL_TYPES, projectionExpression, ImmutableMap.of(), WarningCollector.NOOP),
                INPUT_MAPPING);
    }

    /**
     * Creates an operator from {@code operatorFactory} in a fresh driver context and returns the
     * single value it produces.
     */
    private Object selectSingleValue(OperatorFactory operatorFactory, Type type, Session session)
    {
        return selectSingleValue(operatorFactory.createOperator(createDriverContext(session)), type);
    }

    /**
     * Creates a source operator from {@code operatorFactory}, feeds it exactly one split, and
     * returns the single value it produces.
     */
    private Object selectSingleValue(SourceOperatorFactory operatorFactory, Type type, Split split, Session session)
    {
        SourceOperator sourceOperator = operatorFactory.createOperator(createDriverContext(session));

        // hand over the one split and signal that nothing else will follow
        ScheduledSplit scheduledSplit = new ScheduledSplit(0, sourceOperator.getSourceId(), split);
        sourceOperator.addSplit(scheduledSplit);
        sourceOperator.noMoreSplits();

        return selectSingleValue(sourceOperator, type);
    }

    /**
     * Runs {@code operator} against {@code SOURCE_PAGE}, asserts that it yields exactly one
     * position in exactly one channel, and returns that value decoded through {@code type}
     * using this instance's session properties.
     */
    private Object selectSingleValue(Operator operator, Type type)
    {
        Page resultPage = getAtMostOnePage(operator, SOURCE_PAGE);

        assertNotNull(resultPage);
        assertEquals(resultPage.getPositionCount(), 1);
        assertEquals(resultPage.getChannelCount(), 1);

        Block onlyChannel = resultPage.getBlock(0);
        assertEquals(onlyChannel.getPositionCount(), 1);

        return type.getObjectValue(session.getSqlFunctionProperties(), onlyChannel, 0);
    }

    /**
     * Asserts that {@code filter} evaluates to {@code expected} under every execution strategy.
     *
     * @param filter SQL boolean expression to evaluate
     * @param expected expected filter outcome
     * @param withNoInputColumns whether to additionally run the filter with no input columns
     */
    public void assertFilter(String filter, boolean expected, boolean withNoInputColumns)
    {
        assertFilter(filter, expected, withNoInputColumns, compiler);
    }

    /**
     * Evaluates {@code filter} through every execution path covered by
     * {@code executeFilterWithAll} (under {@code TEST_SESSION}) and asserts that all of them
     * agree on {@code expected}.
     *
     * @param filter the filter expression text
     * @param expected the result every execution path must yield
     * @param withNoInputColumns whether the no-input-columns path is also exercised
     * @param compiler the expression compiler used for the compiled paths
     */
    private void assertFilter(String filter, boolean expected, boolean withNoInputColumns, ExpressionCompiler compiler)
    {
        List<Boolean> results = executeFilterWithAll(filter, TEST_SESSION, withNoInputColumns, compiler);
        HashSet<Boolean> distinctResults = new HashSet<>(results);

        // every execution path must agree, so the distinct results collapse to a single value
        assertEquals(distinctResults.size(), 1, "Expected only [" + expected + "] as the unique result, but got " + distinctResults);

        // the explicit unboxing cast selects the boolean overload of assertEquals
        assertEquals((boolean) Iterables.getOnlyElement(distinctResults), expected);
    }

    /**
     * Evaluates {@code filter} through each available execution path and returns one boolean
     * per path, in execution order: compiled filter/project operator, optionally the
     * no-input-columns operator, the interpreter, the scan operator over a normal split, the
     * scan operator over a record-set split, and (only when the filter has no symbol
     * references) a full query through {@code runner}.
     */
    private List<Boolean> executeFilterWithAll(String filter, Session session, boolean executeWithNoInputColumns, ExpressionCompiler compiler)
    {
        requireNonNull(filter, "filter is null");

        Expression parsedFilter = createExpression(session, filter, metadata, SYMBOL_TYPES);
        RowExpression compiledFilter = toRowExpression(session, parsedFilter);
        SqlFunctionProperties properties = session.getSqlFunctionProperties();

        List<Boolean> outcomes = new ArrayList<>();

        // compiled filter feeding a constant TRUE projection, run as a standalone operator
        OperatorFactory standaloneFactory = compileFilterProject(properties, Optional.of(compiledFilter), constant(true, BOOLEAN), compiler);
        outcomes.add(executeFilter(standaloneFactory, session));

        // compiled filter with no projections, run against a page with no input columns
        if (executeWithNoInputColumns) {
            standaloneFactory = compileFilterWithNoInputColumns(properties, compiledFilter, compiler);
            outcomes.add(executeFilterWithNoInputColumns(standaloneFactory, session));
        }

        // interpreter; a null result counts as "filtered out"
        Boolean interpreted = (Boolean) interpret(parsedFilter, BOOLEAN, session);
        outcomes.add(interpreted != null ? interpreted : Boolean.FALSE);

        // scan operator, first over the normal split and then over the record-set split
        SourceOperatorFactory scanFactory = compileScanFilterProject(properties, Optional.of(compiledFilter), constant(true, BOOLEAN), compiler);
        outcomes.add(executeFilter(scanFactory, createNormalSplit(), session));
        outcomes.add(executeFilter(scanFactory, createRecordSetSplit(), session));

        // the full engine can only be used when the filter does not need bound values
        if (needsBoundValue(parsedFilter)) {
            return outcomes;
        }

        MaterializedResult queryResult = runner.execute("SELECT TRUE WHERE " + filter);
        assertEquals(queryResult.getTypes().size(), 1);

        if (queryResult.getMaterializedRows().isEmpty()) {
            // no row survived the WHERE clause
            outcomes.add(false);
        }
        else {
            assertEquals(queryResult.getMaterializedRows().size(), 1);
            outcomes.add((Boolean) Iterables.getOnlyElement(queryResult.getMaterializedRows()).getField(0));
        }

        return outcomes;
    }

    /**
     * Convenience overload of {@link #createExpression(Session, String, Metadata, TypeProvider)}
     * that uses {@code TEST_SESSION}.
     */
    public static Expression createExpression(String expression, Metadata metadata, TypeProvider symbolTypes)
    {
        return createExpression(TEST_SESSION, expression, metadata, symbolTypes);
    }

    /**
     * Parses {@code expression}, rewrites identifiers to symbol references, analyzes it against
     * {@code symbolTypes}, wraps every node that has a registered coercion in an explicit
     * {@link Cast}, and returns the canonicalized result.
     */
    public static Expression createExpression(Session session, String expression, Metadata metadata, TypeProvider symbolTypes)
    {
        Expression parsed = rewriteIdentifiersToSymbolReferences(
                SQL_PARSER.createExpression(expression, createParsingOptions(session)));

        final ExpressionAnalysis analysis = analyzeExpressions(
                session,
                metadata,
                SQL_PARSER,
                symbolTypes,
                ImmutableList.of(parsed),
                ImmutableMap.of(),
                WarningCollector.NOOP,
                false);

        ExpressionRewriter<Void> coercionRewriter = new ExpressionRewriter<Void>()
        {
            @Override
            public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
            {
                Expression rewritten = treeRewriter.defaultRewrite(node, context);

                // leave the node untouched unless a coercion is registered for it
                Type coercion = analysis.getCoercion(node);
                if (coercion == null) {
                    return rewritten;
                }
                return new Cast(
                        rewritten,
                        coercion.getTypeSignature().toString(),
                        false,
                        analysis.isTypeOnlyCoercion(node));
            }

            @Override
            public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
            {
                // column references take the generic rewrite path
                if (analysis.isColumnReference(node)) {
                    return rewriteExpression(node, context, treeRewriter);
                }

                Expression rewritten = treeRewriter.defaultRewrite(node, context);

                // leave the node untouched unless a coercion is registered for it
                Type coercion = analysis.getCoercion(node);
                if (coercion == null) {
                    return rewritten;
                }
                return new Cast(rewritten, coercion.getTypeSignature().toString());
            }
        };

        return canonicalizeExpression(ExpressionTreeRewriter.rewriteWith(coercionRewriter, parsed));
    }

    /**
     * Instantiates an operator from {@code operatorFactory} under a fresh driver context and
     * evaluates it against the zero-channel page.
     */
    private static boolean executeFilterWithNoInputColumns(OperatorFactory operatorFactory, Session session)
    {
        Operator operator = operatorFactory.createOperator(createDriverContext(session));
        return executeFilterWithNoInputColumns(operator);
    }

    /**
     * Instantiates an operator from {@code operatorFactory} under a fresh driver context and
     * reports whether the source page passes its filter.
     */
    private static boolean executeFilter(OperatorFactory operatorFactory, Session session)
    {
        Operator operator = operatorFactory.createOperator(createDriverContext(session));
        return executeFilter(operator);
    }

    /**
     * Instantiates a source operator from {@code operatorFactory}, hands it {@code split} as its
     * only split, and reports whether a row passes its filter.
     */
    private static boolean executeFilter(SourceOperatorFactory operatorFactory, Split split, Session session)
    {
        SourceOperator sourceOperator = operatorFactory.createOperator(createDriverContext(session));

        // feed exactly one split, then signal that no more will follow
        ScheduledSplit onlySplit = new ScheduledSplit(0, sourceOperator.getSourceId(), split);
        sourceOperator.addSplit(onlySplit);
        sourceOperator.noMoreSplits();

        return executeFilter(sourceOperator);
    }

    /**
     * Runs {@code operator} against {@code SOURCE_PAGE} and returns {@code true} if it emitted a
     * page, {@code false} if it emitted nothing.
     * <p>
     * An emitted page must hold exactly one position in one channel whose boolean value is true.
     */
    private static boolean executeFilter(Operator operator)
    {
        Page page = getAtMostOnePage(operator, SOURCE_PAGE);

        // no output means the row was filtered out
        if (page == null) {
            return false;
        }

        assertEquals(page.getPositionCount(), 1);
        assertEquals(page.getChannelCount(), 1);
        assertTrue(BOOLEAN.getBoolean(page.getBlock(0), 0));
        return true;
    }

    /**
     * Runs {@code operator} against {@code ZERO_CHANNEL_PAGE} and returns {@code true} if it
     * emitted a page, {@code false} if it emitted nothing.
     * <p>
     * An emitted page must hold exactly one position and no channels.
     */
    private static boolean executeFilterWithNoInputColumns(Operator operator)
    {
        Page page = getAtMostOnePage(operator, ZERO_CHANNEL_PAGE);

        // no output means the row was filtered out
        if (page == null) {
            return false;
        }

        assertEquals(page.getPositionCount(), 1);
        assertEquals(page.getChannelCount(), 0);
        return true;
    }

    /**
     * Returns {@code true} if traversing {@code projectionExpression} visits at least one
     * {@link SymbolReference}.
     */
    private static boolean needsBoundValue(Expression projectionExpression)
    {
        // flag flipped by the visitor; a holder is needed because the anonymous class cannot assign a local
        final AtomicBoolean symbolSeen = new AtomicBoolean();

        DefaultTraversalVisitor<Void, Void> symbolDetector = new DefaultTraversalVisitor<Void, Void>()
        {
            @Override
            protected Void visitSymbolReference(SymbolReference node, Void context)
            {
                symbolSeen.set(true);
                return null;
            }
        };
        symbolDetector.process(projectionExpression, null);

        return symbolSeen.get();
    }

    /**
     * Evaluates {@code expression} with the expression interpreter, resolving each variable from
     * position 0 of {@code SOURCE_PAGE}, and returns the result as the object value of
     * {@code expectedType}.
     */
    private Object interpret(Expression expression, Type expectedType, Session session)
    {
        Map<NodeRef<Expression>, Type> typesByNode = getExpressionTypes(session, metadata, SQL_PARSER, SYMBOL_TYPES, expression, emptyMap(), WarningCollector.NOOP);
        ExpressionInterpreter interpreter = ExpressionInterpreter.expressionInterpreter(expression, metadata, session, typesByNode);

        Object nativeResult = interpreter.evaluate(variable -> {
            // locate the channel of SOURCE_PAGE that backs this variable
            Symbol symbol = new Symbol(variable.getName());
            Type symbolType = SYMBOL_TYPES.get(symbol.toSymbolReference());
            int channel = INPUT_MAPPING.get(new VariableReferenceExpression(Optional.empty(), symbol.getName(), symbolType));
            Block source = SOURCE_PAGE.getBlock(channel);

            if (source.isNull(0)) {
                return null;
            }

            // read the value through the accessor matching the type's Java representation
            Class<?> javaType = symbolType.getJavaType();
            if (javaType == boolean.class) {
                return symbolType.getBoolean(source, 0);
            }
            if (javaType == long.class) {
                return symbolType.getLong(source, 0);
            }
            if (javaType == double.class) {
                return symbolType.getDouble(source, 0);
            }
            if (javaType == Slice.class) {
                return symbolType.getSlice(source, 0);
            }
            if (javaType == Block.class) {
                return symbolType.getObject(source, 0);
            }
            throw new UnsupportedOperationException("not yet implemented");
        });

        // convert result from stack type to Type ObjectValue
        Block resultBlock = Utils.nativeValueToBlock(expectedType, nativeResult);
        return expectedType.getObjectValue(session.getSqlFunctionProperties(), resultBlock, 0);
    }

    /**
     * Compiles {@code filter} into a page processor with no projections and wraps it in a
     * {@link FilterAndProjectOperatorFactory} that declares no output types.
     *
     * @throws RuntimeException wrapping any failure raised during compilation; an
     *         {@code UncheckedExecutionException} is replaced by its cause first
     */
    private static OperatorFactory compileFilterWithNoInputColumns(SqlFunctionProperties sqlFunctionProperties, RowExpression filter, ExpressionCompiler compiler)
    {
        try {
            Supplier<PageProcessor> filterOnlyProcessor = compiler.compilePageProcessor(sqlFunctionProperties, Optional.of(filter), ImmutableList.of());
            PlanNodeId planNodeId = new PlanNodeId("test");
            DataSize zeroBytes = new DataSize(0, BYTE);

            return new FilterAndProjectOperatorFactory(0, planNodeId, filterOnlyProcessor, ImmutableList.of(), zeroBytes, 0);
        }
        catch (Throwable e) {
            // report the underlying failure rather than the unchecked wrapper
            Throwable failure = (e instanceof UncheckedExecutionException) ? e.getCause() : e;
            throw new RuntimeException("Error compiling " + filter + ": " + failure.getMessage(), failure);
        }
    }

    /**
     * Compiles {@code filter} and the single {@code projection} into a page processor and wraps
     * it in a {@link FilterAndProjectOperatorFactory} whose only output type is the projection's type.
     *
     * @throws RuntimeException wrapping any failure raised during compilation; an
     *         {@code UncheckedExecutionException} is replaced by its cause first
     */
    private static OperatorFactory compileFilterProject(SqlFunctionProperties sqlFunctionProperties, Optional<RowExpression> filter, RowExpression projection, ExpressionCompiler compiler)
    {
        try {
            Supplier<PageProcessor> pageProcessor = compiler.compilePageProcessor(sqlFunctionProperties, filter, ImmutableList.of(projection));
            PlanNodeId planNodeId = new PlanNodeId("test");

            return new FilterAndProjectOperatorFactory(0, planNodeId, pageProcessor, ImmutableList.of(projection.getType()), new DataSize(0, BYTE), 0);
        }
        catch (Throwable e) {
            // report the underlying failure rather than the unchecked wrapper
            Throwable failure = (e instanceof UncheckedExecutionException) ? e.getCause() : e;
            throw new RuntimeException("Error compiling " + projection + ": " + failure.getMessage(), failure);
        }
    }

    /**
     * Compiles {@code filter} and the single {@code projection} into both a cursor processor and
     * a page processor, and wraps them in a scan-filter-project operator factory that reads from
     * {@code PAGE_SOURCE_PROVIDER} under {@code SOURCE_ID}.
     *
     * @throws RuntimeException wrapping any failure raised during compilation; an
     *         {@code UncheckedExecutionException} is replaced by its cause first
     */
    private static SourceOperatorFactory compileScanFilterProject(SqlFunctionProperties sqlFunctionProperties, Optional<RowExpression> filter, RowExpression projection, ExpressionCompiler compiler)
    {
        try {
            Supplier<CursorProcessor> cursorProcessor = compiler.compileCursorProcessor(sqlFunctionProperties, filter, ImmutableList.of(projection), SOURCE_ID);
            Supplier<PageProcessor> pageProcessor = compiler.compilePageProcessor(sqlFunctionProperties, filter, ImmutableList.of(projection));

            PlanNodeId planNodeId = new PlanNodeId("test");

            // placeholder table handle built from empty anonymous connector handles
            TableHandle testTable = new TableHandle(
                    new ConnectorId("test"),
                    new ConnectorTableHandle() {},
                    new ConnectorTransactionHandle() {},
                    Optional.empty());

            return new ScanFilterAndProjectOperator.ScanFilterAndProjectOperatorFactory(
                    0,
                    planNodeId,
                    SOURCE_ID,
                    PAGE_SOURCE_PROVIDER,
                    cursorProcessor,
                    pageProcessor,
                    testTable,
                    ImmutableList.of(),
                    ImmutableList.of(projection.getType()),
                    Optional.empty(),
                    new DataSize(0, BYTE),
                    0);
        }
        catch (Throwable e) {
            // report the underlying failure rather than the unchecked wrapper
            Throwable failure = (e instanceof UncheckedExecutionException) ? e.getCause() : e;
            throw new RuntimeException("Error compiling filter " + filter + ": " + failure.getMessage(), failure);
        }
    }

    /**
     * Translates {@code projection} into a {@link RowExpression} using the supplied expression
     * types and channel layout, the metadata's function and type manager, and this instance's
     * {@code session} field.
     */
    private RowExpression toRowExpression(Expression projection, Map<NodeRef<Expression>, Type> expressionTypes, Map<VariableReferenceExpression, Integer> layout)
    {
        FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager();
        return translate(projection, expressionTypes, layout, functionAndTypeManager, session);
    }

    /**
     * Drives {@code operator} until it reports finished and returns the single page it emitted,
     * or {@code null} if it emitted none.
     * <p>
     * {@code sourcePage} is handed to the operator only if it reports that it needs input.
     * While draining after {@code finish()}, the method asserts that the operator is not blocked
     * and that no second non-null page is produced.
     *
     * @param operator the operator to run
     * @param sourcePage the input page offered to the operator
     * @return the only page produced, or {@code null} when nothing was produced
     */
    private static Page getAtMostOnePage(Operator operator, Page sourcePage)
    {
        // add our input page if needed
        if (operator.needsInput()) {
            operator.addInput(sourcePage);
        }

        // try to get the output page
        Page result = operator.getOutput();

        // tell operator to finish
        operator.finish();

        // try to get output until the operator is finished
        while (!operator.isFinished()) {
            // operator should never block
            assertTrue(operator.isBlocked().isDone());

            Page output = operator.getOutput();
            if (output != null) {
                // a page was already captured, so a second one violates the "at most one" contract
                assertNull(result);
                result = output;
            }
        }

        return result;
    }

    /**
     * Builds a new {@link DriverContext} for {@code session}: creates a task context on the
     * shared {@code EXECUTOR} and {@code SCHEDULED_EXECUTOR}, adds pipeline context 0 to it,
     * and adds a driver context to that pipeline.
     */
    private static DriverContext createDriverContext(Session session)
    {
        return createTaskContext(EXECUTOR, SCHEDULED_EXECUTOR, session)
                .addPipelineContext(0, true, true, false)
                .addDriverContext();
    }

    /**
     * Asserts that {@code types} holds exactly one type and that it equals {@code expectedType}.
     */
    private static void assertType(List<Type> types, Type expectedType)
    {
        boolean hasSingleType = types.size() == 1;
        assertTrue(hasSingleType, "Expected one type, but got " + types);
        assertEquals(types.get(0), expectedType);
    }

    /**
     * Returns the session held by this instance.
     */
    public Session getSession()
    {
        return session;
    }

    /**
     * Closes the {@code runner} used for full-engine query execution.
     */
    @Override
    public void close()
    {
        runner.close();
    }

    /**
     * Page source provider for {@link TestSplit}s. A record-set split yields a
     * {@link RecordPageSource} over a single in-memory row; any other split yields a
     * {@link FixedPageSource} over {@code SOURCE_PAGE}.
     */
    private static class TestPageSourceProvider
            implements PageSourceProvider
    {
        @Override
        public ConnectorPageSource createPageSource(Session session, Split split, TableHandle table, List<ColumnHandle> columns, RuntimeStats runtimeStats)
        {
            assertInstanceOf(split.getConnectorSplit(), FunctionAssertions.TestSplit.class);
            FunctionAssertions.TestSplit testSplit = (FunctionAssertions.TestSplit) split.getConnectorSplit();

            // page-based path: serve the shared source page directly
            if (!testSplit.isRecordSet()) {
                return new FixedPageSource(ImmutableList.of(SOURCE_PAGE));
            }

            // record-set path: one row, with one value per declared column type
            RecordSet singleRow = InMemoryRecordSet.builder(ImmutableList.of(BIGINT, VARCHAR, DOUBLE, BOOLEAN, BIGINT, VARCHAR, VARCHAR, TIMESTAMP_WITH_TIME_ZONE, VARBINARY, INTEGER, TEST_ROW_TYPE))
                    .addRow(
                            1234L,
                            "hello",
                            12.34,
                            true,
                            new DateTime(2001, 8, 22, 3, 4, 5, 321, DateTimeZone.UTC).getMillis(),
                            "%el%",
                            null,
                            packDateTimeWithZone(new DateTime(1970, 1, 1, 0, 1, 0, 999, DateTimeZone.UTC).getMillis(), TimeZoneKey.getTimeZoneKey("Z")),
                            Slices.wrappedBuffer((byte) 0xab),
                            1234,
                            TEST_ROW_DATA.getBlock(0))
                    .build();
            return new RecordPageSource(singleRow);
        }
    }

    /**
     * Creates a split for the "test" connector whose {@link TestSplit} is flagged as a record set.
     */
    private static Split createRecordSetSplit()
    {
        ConnectorId connectorId = new ConnectorId("test");
        return new Split(connectorId, TestingTransactionHandle.create(), new TestSplit(true));
    }

    /**
     * Creates a split for the "test" connector whose {@link TestSplit} is not flagged as a record set.
     */
    private static Split createNormalSplit()
    {
        ConnectorId connectorId = new ConnectorId("test");
        return new Split(connectorId, TestingTransactionHandle.create(), new TestSplit(false));
    }

    /**
     * Builds a row type with {@code numberOfFields} fields named {@code nested_column_<index>}.
     * Field types are taken from a fixed seven-type sequence that repeats as needed.
     */
    private static RowType createTestRowType(int numberOfFields)
    {
        RowType nestedRow = RowType.from(ImmutableList.of(RowType.field("nested_nested_column", VARCHAR)));
        Iterator<Type> repeatingTypes = Iterables.<Type>cycle(BIGINT, INTEGER, VARCHAR, DOUBLE, BOOLEAN, VARBINARY, nestedRow).iterator();

        List<RowType.Field> fields = new ArrayList<>();
        int index = 0;
        while (index < numberOfFields) {
            fields.add(new RowType.Field(Optional.of("nested_column_" + index), repeatingTypes.next()));
            index++;
        }

        return RowType.from(fields);
    }

    /**
     * Builds a row block for {@code rowType} with one value per field. Values are taken from a
     * fixed seven-value sequence that repeats as needed.
     */
    private static Block createTestRowData(RowType rowType)
    {
        Iterator<Object> repeatingValues = Iterables.cycle(
                1234L,
                34,
                "hello",
                12.34d,
                true,
                Slices.wrappedBuffer((byte) 0xab),
                createRowBlock(ImmutableList.of(VARCHAR), Collections.singleton("innerFieldValue").toArray()).getBlock(0)).iterator();

        // one value per declared field, in field order
        Object[] rowValues = new Object[rowType.getFields().size()];
        int index = 0;
        while (index < rowValues.length) {
            rowValues[index] = repeatingValues.next();
            index++;
        }

        return createRowBlock(rowType.getTypeParameters(), rowValues);
    }

    /**
     * Connector split carrying a single flag that tells {@link TestPageSourceProvider} whether
     * to serve a record set or the fixed source page.
     */
    private static class TestSplit
            implements ConnectorSplit
    {
        private final boolean servedAsRecordSet;

        private TestSplit(boolean recordSet)
        {
            servedAsRecordSet = recordSet;
        }

        private boolean isRecordSet()
        {
            return servedAsRecordSet;
        }

        @Override
        public NodeSelectionStrategy getNodeSelectionStrategy()
        {
            return HARD_AFFINITY;
        }

        // no preferred nodes are reported
        @Override
        public List<HostAddress> getPreferredNodes(NodeProvider nodeProvider)
        {
            return ImmutableList.of();
        }

        // the split itself serves as its info object
        @Override
        public Object getInfo()
        {
            return this;
        }
    }
}