RowExpressionInterpreter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner;

import com.facebook.airlift.json.JsonCodec;
import com.facebook.presto.client.FailureInfo;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.RowBlockBuilder;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.FunctionType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionImplementationType;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.SqlInvokedScalarFunctionImplementation;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.IntermediateFormExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.InterpretedFunctionInvoker;
import com.facebook.presto.sql.planner.Interpreters.LambdaVariableResolver;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.util.Failures;
import com.google.common.collect.ImmutableList;
import com.google.common.primitives.Primitives;
import io.airlift.joni.Regex;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;

import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.JsonType.JSON;
import static com.facebook.presto.common.type.StandardTypes.ARRAY;
import static com.facebook.presto.common.type.StandardTypes.MAP;
import static com.facebook.presto.common.type.StandardTypes.ROW;
import static com.facebook.presto.common.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.expressions.DynamicFilters.isDynamicFilter;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.metadata.CastType.CAST;
import static com.facebook.presto.metadata.CastType.JSON_TO_ARRAY_CAST;
import static com.facebook.presto.metadata.CastType.JSON_TO_MAP_CAST;
import static com.facebook.presto.metadata.CastType.JSON_TO_ROW_CAST;
import static com.facebook.presto.spi.function.FunctionImplementationType.JAVA;
import static com.facebook.presto.spi.function.FunctionImplementationType.SQL;
import static com.facebook.presto.spi.function.FunctionKind.SCALAR;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.EVALUATED;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.SERIALIZABLE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.BIND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.NULL_IF;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.ROW_CONSTRUCTOR;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.gen.VarArgsToMapAdapterGenerator.generateVarArgsToMapAdapter;
import static com.facebook.presto.sql.planner.Interpreters.interpretDereference;
import static com.facebook.presto.sql.planner.Interpreters.interpretLikePredicate;
import static com.facebook.presto.sql.planner.LiteralEncoder.estimatedSizeInBytes;
import static com.facebook.presto.sql.planner.LiteralEncoder.isSupportedLiteralType;
import static com.facebook.presto.sql.planner.RowExpressionInterpreter.SpecialCallResult.changed;
import static com.facebook.presto.sql.planner.RowExpressionInterpreter.SpecialCallResult.notChanged;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.SqlFunctionUtils.getSqlFunctionRowExpression;
import static com.facebook.presto.type.LikeFunctions.isLikePattern;
import static com.facebook.presto.type.LikeFunctions.unescapeLiteralLikePattern;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Predicates.instanceOf;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.invoke.MethodHandles.insertArguments;
import static java.util.Arrays.asList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public class RowExpressionInterpreter
{
    private static final long MAX_SERIALIZABLE_OBJECT_SIZE = 1000;
    private final RowExpression expression;
    private final ConnectorSession session;
    private final Level optimizationLevel;
    private final InterpretedFunctionInvoker functionInvoker;
    private final RowExpressionDeterminismEvaluator determinismEvaluator;
    private final FunctionAndTypeManager functionAndTypeManager;
    private final FunctionResolution resolution;

    private final Visitor visitor;

    public static Object evaluateConstantRowExpression(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session)
    {
        // evaluate the expression
        Object result = new RowExpressionInterpreter(expression, functionAndTypeManager, session, EVALUATED).evaluate();
        verify(!(result instanceof RowExpression), "RowExpression interpreter returned an unresolved expression");
        return result;
    }

    public static RowExpressionInterpreter rowExpressionInterpreter(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session)
    {
        return new RowExpressionInterpreter(expression, functionAndTypeManager, session, EVALUATED);
    }

    public RowExpressionInterpreter(RowExpression expression, Metadata metadata, ConnectorSession session, Level optimizationLevel)
    {
        this(expression, metadata.getFunctionAndTypeManager(), session, optimizationLevel);
    }

    public RowExpressionInterpreter(RowExpression expression, FunctionAndTypeManager functionAndTypeManager, ConnectorSession session, Level optimizationLevel)
    {
        this.expression = requireNonNull(expression, "expression is null");
        this.session = requireNonNull(session, "session is null");
        this.optimizationLevel = optimizationLevel;
        requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionInvoker = new InterpretedFunctionInvoker(functionAndTypeManager);
        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
        this.resolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
        this.functionAndTypeManager = functionAndTypeManager;

        this.visitor = new Visitor();
    }

    public Type getType()
    {
        return expression.getType();
    }

    public Object evaluate()
    {
        checkState(optimizationLevel.ordinal() >= EVALUATED.ordinal(), "evaluate() not allowed for optimizer");
        return expression.accept(visitor, null);
    }

    public Object optimize()
    {
        checkState(optimizationLevel.ordinal() < EVALUATED.ordinal(), "optimize() not allowed for interpreter");
        return optimize(null);
    }

    /**
     * Replace symbol with constants
     */
    public Object optimize(VariableResolver inputs)
    {
        checkState(optimizationLevel.ordinal() <= EVALUATED.ordinal(), "optimize(SymbolResolver) not allowed for interpreter");
        return expression.accept(visitor, inputs);
    }

    private class Visitor
            implements RowExpressionVisitor<Object, Object>
    {
        @Override
        public Object visitInputReference(InputReferenceExpression node, Object context)
        {
            return node;
        }

        @Override
        public Object visitConstant(ConstantExpression node, Object context)
        {
            return node.getValue();
        }

        @Override
        public Object visitVariableReference(VariableReferenceExpression node, Object context)
        {
            if (context instanceof VariableResolver) {
                return ((VariableResolver) context).getValue(node);
            }
            return node;
        }

        @Override
        public Object visitCall(CallExpression node, Object context)
        {
            List<Type> argumentTypes = new ArrayList<>();
            List<Object> argumentValues = new ArrayList<>();
            for (RowExpression expression : node.getArguments()) {
                Object value = expression.accept(this, context);
                argumentValues.add(value);
                argumentTypes.add(expression.getType());
            }

            FunctionHandle functionHandle = node.getFunctionHandle();
            FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(node.getFunctionHandle());
            if (!functionMetadata.isCalledOnNullInput()) {
                for (Object value : argumentValues) {
                    if (value == null) {
                        return null;
                    }
                }
            }

            // Special casing for large constant array construction
            if (resolution.isArrayConstructor(functionHandle)) {
                SpecialCallResult result = tryHandleArrayConstructor(node, argumentValues);
                if (result.isChanged()) {
                    return result.getValue();
                }
            }

            // Special casing for cast
            if (resolution.isCastFunction(functionHandle)) {
                SpecialCallResult result = tryHandleCast(node, argumentValues);
                if (result.isChanged()) {
                    return result.getValue();
                }
            }

            // Special casing for like
            if (resolution.isLikeFunction(functionHandle)) {
                SpecialCallResult result = tryHandleLike(node, argumentValues, argumentTypes, context);
                if (result.isChanged()) {
                    return result.getValue();
                }
            }

            if (functionMetadata.getFunctionKind() != SCALAR) {
                if (optimizationLevel.ordinal() < EVALUATED.ordinal()) {
                    return call(node.getDisplayName(), functionHandle, node.getType(), toRowExpressions(argumentValues, node.getArguments()));
                }

                throw new RuntimeException("Cannot evaluate non-scalar function: " + node.getDisplayName());
            }

            // do not optimize non-deterministic functions
            if (optimizationLevel.ordinal() < EVALUATED.ordinal() &&
                    (!functionMetadata.isDeterministic() ||
                            hasUnresolvedValue(argumentValues) ||
                            isDynamicFilter(node) ||
                            resolution.isFailFunction(functionHandle))) {
                return call(node.getDisplayName(), functionHandle, node.getType(), toRowExpressions(argumentValues, node.getArguments()));
            }

            Object value;
            FunctionImplementationType implementationType = functionMetadata.getImplementationType();
            if (!implementationType.canBeEvaluatedInCoordinator()) {
                // do not interpret remote functions or cpp UDF on coordinator
                return call(node.getDisplayName(), functionHandle, node.getType(), toRowExpressions(argumentValues, node.getArguments()));
            }
            else if (implementationType.equals(JAVA)) {
                value = functionInvoker.invoke(functionHandle, session.getSqlFunctionProperties(), argumentValues);
            }
            else {
                checkState(implementationType.equals(SQL));
                SqlInvokedScalarFunctionImplementation functionImplementation = (SqlInvokedScalarFunctionImplementation) functionAndTypeManager.getScalarFunctionImplementation(functionHandle);
                RowExpression function = getSqlFunctionRowExpression(
                        functionMetadata,
                        functionImplementation,
                        functionAndTypeManager,
                        session.getSqlFunctionProperties(),
                        session.getSessionFunctions(),
                        node.getArguments());
                RowExpressionInterpreter rowExpressionInterpreter = new RowExpressionInterpreter(function, functionAndTypeManager, session, optimizationLevel);
                if (optimizationLevel.ordinal() >= EVALUATED.ordinal()) {
                    value = rowExpressionInterpreter.evaluate();
                }
                else {
                    value = rowExpressionInterpreter.optimize();
                }
                if (value instanceof CallExpression) {
                    return call(node.getDisplayName(), functionHandle, node.getType(), toRowExpressions(argumentValues, node.getArguments()));
                }
            }
            if (optimizationLevel.ordinal() <= SERIALIZABLE.ordinal() && !isSerializable(value, node.getType())) {
                return call(node.getDisplayName(), functionHandle, node.getType(), toRowExpressions(argumentValues, node.getArguments()));
            }
            return value;
        }

        @Override
        public Object visitLambda(LambdaDefinitionExpression node, Object context)
        {
            if (optimizationLevel.ordinal() < EVALUATED.ordinal()) {
                // TODO: enable optimization related to lambda expression
                // Currently, we are not able to determine if lambda is deterministic.
                // context is passed down as null here since lambda argument can only be resolved under the evaluation context.
                RowExpression rewrittenBody = toRowExpression(processWithExceptionHandling(node.getBody(), null), node.getBody());
                if (!rewrittenBody.equals(node.getBody())) {
                    return new LambdaDefinitionExpression(node.getSourceLocation(), node.getArgumentTypes(), node.getArguments(), rewrittenBody);
                }
                return node;
            }
            RowExpression body = node.getBody();
            FunctionType functionType = (FunctionType) node.getType();
            checkArgument(node.getArguments().size() == functionType.getArgumentTypes().size());

            return generateVarArgsToMapAdapter(
                    Primitives.wrap(functionType.getReturnType().getJavaType()),
                    functionType.getArgumentTypes().stream()
                            .map(Type::getJavaType)
                            .map(Primitives::wrap)
                            .collect(toImmutableList()),
                    node.getArguments(),
                    map -> body.accept(this, new LambdaVariableResolver(map)));
        }

        @Override
        public Object visitIntermediateFormExpression(IntermediateFormExpression intermediateFormExpression, Object context)
        {
            return intermediateFormExpression;
        }

        @Override
        public Object visitSpecialForm(SpecialFormExpression node, Object context)
        {
            switch (node.getForm()) {
                case IF: {
                    checkArgument(node.getArguments().size() == 3);
                    Object condition = processWithExceptionHandling(node.getArguments().get(0), context);

                    if (condition instanceof RowExpression) {
                        return new SpecialFormExpression(
                                ((RowExpression) condition).getSourceLocation(),
                                IF,
                                node.getType(),
                                toRowExpression(condition, node.getArguments().get(0)),
                                toRowExpression(processWithExceptionHandling(node.getArguments().get(1), context), node.getArguments().get(1)),
                                toRowExpression(processWithExceptionHandling(node.getArguments().get(2), context), node.getArguments().get(2)));
                    }
                    else if (Boolean.TRUE.equals(condition)) {
                        return processWithExceptionHandling(node.getArguments().get(1), context);
                    }

                    return processWithExceptionHandling(node.getArguments().get(2), context);
                }
                case NULL_IF: {
                    checkArgument(node.getArguments().size() == 2);
                    Object left = processWithExceptionHandling(node.getArguments().get(0), context);
                    if (left == null) {
                        return null;
                    }

                    Object right = processWithExceptionHandling(node.getArguments().get(1), context);
                    if (right == null) {
                        return left;
                    }

                    if (hasUnresolvedValue(left, right)) {
                        return new SpecialFormExpression(
                                node.getSourceLocation(),
                                NULL_IF,
                                node.getType(),
                                toRowExpression(left, node.getArguments().get(0)),
                                toRowExpression(right, node.getArguments().get(1)));
                    }

                    Type leftType = node.getArguments().get(0).getType();
                    Type rightType = node.getArguments().get(1).getType();
                    Type commonType = functionAndTypeManager.getCommonSuperType(leftType, rightType).get();
                    FunctionHandle firstCast = functionAndTypeManager.lookupCast(CAST, leftType, commonType);
                    FunctionHandle secondCast = functionAndTypeManager.lookupCast(CAST, rightType, commonType);

                    // cast(first as <common type>) == cast(second as <common type>)
                    boolean equal = Boolean.TRUE.equals(invokeOperator(
                            EQUAL,
                            ImmutableList.of(commonType, commonType),
                            ImmutableList.of(
                                    functionInvoker.invoke(firstCast, session.getSqlFunctionProperties(), left),
                                    functionInvoker.invoke(secondCast, session.getSqlFunctionProperties(), right))));

                    if (equal) {
                        return null;
                    }
                    return left;
                }
                case IS_NULL: {
                    checkArgument(node.getArguments().size() == 1);
                    Object value = processWithExceptionHandling(node.getArguments().get(0), context);
                    if (value instanceof RowExpression) {
                        return new SpecialFormExpression(
                                node.getSourceLocation(),
                                IS_NULL,
                                node.getType(),
                                toRowExpression(value, node.getArguments().get(0)));
                    }
                    return value == null;
                }
                case AND: {
                    Object left = node.getArguments().get(0).accept(this, context);
                    Object right;

                    if (Boolean.FALSE.equals(left)) {
                        return false;
                    }

                    right = node.getArguments().get(1).accept(this, context);

                    if (Boolean.TRUE.equals(right)) {
                        return left;
                    }

                    if (Boolean.FALSE.equals(right) || Boolean.TRUE.equals(left)) {
                        return right;
                    }

                    if (left == null && right == null) {
                        return null;
                    }
                    return new SpecialFormExpression(
                            AND,
                            node.getType(),
                            toRowExpressions(
                                    asList(left, right),
                                    node.getArguments().subList(0, 2)));
                }
                case OR: {
                    Object left = node.getArguments().get(0).accept(this, context);
                    Object right;

                    if (Boolean.TRUE.equals(left)) {
                        return true;
                    }

                    right = node.getArguments().get(1).accept(this, context);

                    if (Boolean.FALSE.equals(right)) {
                        return left;
                    }

                    if (Boolean.TRUE.equals(right) || Boolean.FALSE.equals(left)) {
                        return right;
                    }

                    if (left == null && right == null) {
                        return null;
                    }
                    return new SpecialFormExpression(
                            OR,
                            node.getType(),
                            toRowExpressions(
                                    asList(left, right),
                                    node.getArguments().subList(0, 2)));
                }
                case ROW_CONSTRUCTOR: {
                    RowType rowType = (RowType) node.getType();
                    List<Type> parameterTypes = rowType.getTypeParameters();
                    List<RowExpression> arguments = node.getArguments();
                    checkArgument(parameterTypes.size() == arguments.size(), "RowConstructor does not contain all fields");
                    for (int i = 0; i < parameterTypes.size(); i++) {
                        checkArgument(parameterTypes.get(i).equals(arguments.get(i).getType()), "RowConstructor has field with incorrect type");
                    }

                    int cardinality = arguments.size();
                    List<Object> values = new ArrayList<>(cardinality);
                    arguments.forEach(argument -> values.add(argument.accept(this, context)));
                    if (hasUnresolvedValue(values)) {
                        return new SpecialFormExpression(ROW_CONSTRUCTOR, node.getType(), toRowExpressions(values, node.getArguments()));
                    }
                    else {
                        BlockBuilder blockBuilder = new RowBlockBuilder(parameterTypes, null, 1);
                        BlockBuilder singleRowBlockWriter = blockBuilder.beginBlockEntry();
                        for (int i = 0; i < cardinality; ++i) {
                            writeNativeValue(parameterTypes.get(i), singleRowBlockWriter, values.get(i));
                        }
                        blockBuilder.closeEntry();
                        return rowType.getObject(blockBuilder, 0);
                    }
                }
                case COALESCE: {
                    Type type = node.getType();
                    List<Object> values = node.getArguments().stream()
                            .map(value -> processWithExceptionHandling(value, context))
                            .filter(Objects::nonNull)
                            .flatMap(expression -> {
                                if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == COALESCE) {
                                    return ((SpecialFormExpression) expression).getArguments().stream();
                                }
                                return Stream.of(expression);
                            })
                            .collect(toList());

                    if ((!values.isEmpty() && !(values.get(0) instanceof RowExpression)) || values.size() == 1) {
                        return values.get(0);
                    }
                    ImmutableList.Builder<RowExpression> operandsBuilder = ImmutableList.builder();
                    Set<RowExpression> visitedExpression = new HashSet<>();
                    int i = 0;
                    for (Object value : values) {
                        RowExpression expression = LiteralEncoder.toRowExpression(node.getArguments().get(i).getSourceLocation(), value, type);
                        if (!determinismEvaluator.isDeterministic(expression) || visitedExpression.add(expression)) {
                            operandsBuilder.add(expression);
                        }
                        if (expression instanceof ConstantExpression && !(((ConstantExpression) expression).getValue() == null)) {
                            break;
                        }
                    }
                    List<RowExpression> expressions = operandsBuilder.build();

                    if (expressions.isEmpty()) {
                        return null;
                    }

                    if (expressions.size() == 1) {
                        return getOnlyElement(expressions);
                    }
                    return new SpecialFormExpression(node.getSourceLocation(), COALESCE, node.getType(), expressions);
                }
                case IN: {
                    checkArgument(node.getArguments().size() >= 2, "values must not be empty");

                    // use toList to handle null values
                    List<RowExpression> valueExpressions = node.getArguments().subList(1, node.getArguments().size());
                    List<Object> values = valueExpressions.stream().map(value -> value.accept(this, context)).collect(toList());
                    List<Type> valuesTypes = valueExpressions.stream().map(RowExpression::getType).collect(toImmutableList());
                    Object target = node.getArguments().get(0).accept(this, context);
                    Type targetType = node.getArguments().get(0).getType();

                    if (target == null) {
                        return null;
                    }

                    boolean hasUnresolvedValue = false;
                    if (target instanceof RowExpression) {
                        hasUnresolvedValue = true;
                    }

                    boolean hasNullValue = false;
                    boolean found = false;
                    List<RowExpression> unresolvedValues = new ArrayList<>(values.size());
                    for (int i = 0; i < values.size(); i++) {
                        Object value = values.get(i);
                        Type valueType = valuesTypes.get(i);
                        if (value instanceof RowExpression || target instanceof RowExpression) {
                            hasUnresolvedValue = true;
                            unresolvedValues.add(toRowExpression(value, valueExpressions.get(i)));
                            continue;
                        }

                        if (value == null) {
                            hasNullValue = true;
                        }
                        else {
                            Boolean result = (Boolean) invokeOperator(EQUAL, ImmutableList.of(targetType, valueType), ImmutableList.of(target, value));
                            if (result == null) {
                                hasNullValue = true;
                            }
                            else if (!found && result) {
                                // in does not short-circuit so we must evaluate all value in the list
                                found = true;
                            }
                        }
                    }
                    if (found) {
                        return true;
                    }

                    if (hasUnresolvedValue) {
                        List<RowExpression> simplifiedExpressionValues = Stream.concat(
                                        Stream.concat(
                                                Stream.of(toRowExpression(target, node.getArguments().get(0))),
                                                unresolvedValues.stream().filter(determinismEvaluator::isDeterministic).distinct()),
                                        unresolvedValues.stream().filter((expression -> !determinismEvaluator.isDeterministic(expression))))
                                .collect(toImmutableList());
                        return new SpecialFormExpression(IN, node.getType(), simplifiedExpressionValues);
                    }
                    if (hasNullValue) {
                        return null;
                    }
                    return false;
                }
                case DEREFERENCE: {
                    checkArgument(node.getArguments().size() == 2);

                    Object base = node.getArguments().get(0).accept(this, context);
                    int index = ((Number) node.getArguments().get(1).accept(this, context)).intValue();

                    // if the base part is evaluated to be null, the dereference expression should also be null
                    if (base == null) {
                        return null;
                    }

                    if (hasUnresolvedValue(base)) {
                        return new SpecialFormExpression(
                                node.getSourceLocation(),
                                DEREFERENCE,
                                node.getType(),
                                toRowExpression(base, node.getArguments().get(0)),
                                toRowExpression((long) index, node.getArguments().get(1)));
                    }
                    return interpretDereference(base, node.getType(), index);
                }
                case BIND: {
                    List<Object> values = node.getArguments()
                            .stream()
                            .map(value -> value.accept(this, context))
                            .collect(toList());
                    if (hasUnresolvedValue(values)) {
                        return new SpecialFormExpression(
                                BIND,
                                node.getType(),
                                toRowExpressions(values, node.getArguments()));
                    }
                    return insertArguments((MethodHandle) values.get(values.size() - 1), 0, values.subList(0, values.size() - 1).toArray());
                }
                case SWITCH: {
                    List<RowExpression> whenClauses;
                    Object elseValue = null;
                    RowExpression last = node.getArguments().get(node.getArguments().size() - 1);
                    if (last instanceof SpecialFormExpression && ((SpecialFormExpression) last).getForm().equals(WHEN)) {
                        whenClauses = node.getArguments().subList(1, node.getArguments().size());
                    }
                    else {
                        whenClauses = node.getArguments().subList(1, node.getArguments().size() - 1);
                    }

                    List<RowExpression> simplifiedWhenClauses = new ArrayList<>();
                    Object value = processWithExceptionHandling(node.getArguments().get(0), context);
                    if (value != null) {
                        for (RowExpression whenClause : whenClauses) {
                            checkArgument(whenClause instanceof SpecialFormExpression && ((SpecialFormExpression) whenClause).getForm().equals(WHEN));

                            RowExpression operand = ((SpecialFormExpression) whenClause).getArguments().get(0);
                            RowExpression result = ((SpecialFormExpression) whenClause).getArguments().get(1);

                            Object operandValue = processWithExceptionHandling(operand, context);

                            // call equals(value, operand)
                            if (operandValue instanceof RowExpression || value instanceof RowExpression) {
                                // cannot fully evaluate, add updated whenClause
                                simplifiedWhenClauses.add(new SpecialFormExpression(operand.getSourceLocation(), WHEN, whenClause.getType(), toRowExpression(operandValue, operand), toRowExpression(processWithExceptionHandling(result, context), result)));
                            }
                            else if (operandValue != null) {
                                Boolean isEqual = (Boolean) invokeOperator(
                                        EQUAL,
                                        ImmutableList.of(node.getArguments().get(0).getType(), operand.getType()),
                                        ImmutableList.of(value, operandValue));
                                if (isEqual != null && isEqual) {
                                    if (simplifiedWhenClauses.isEmpty()) {
                                        // this is the left-most true predicate. So return it.
                                        return processWithExceptionHandling(result, context);
                                    }

                                    elseValue = processWithExceptionHandling(result, context);
                                    break; // Done we found the last match. Don't need to go any further.
                                }
                            }
                        }
                    }

                    if (elseValue == null) {
                        elseValue = processWithExceptionHandling(last, context);
                    }

                    if (simplifiedWhenClauses.isEmpty()) {
                        return elseValue;
                    }

                    ImmutableList.Builder<RowExpression> argumentsBuilder = ImmutableList.builder();
                    argumentsBuilder.add(toRowExpression(value, node.getArguments().get(0)))
                            .addAll(simplifiedWhenClauses)
                            .add(toRowExpression(elseValue, last));
                    return new SpecialFormExpression(SWITCH, node.getType(), argumentsBuilder.build());
                }
                default:
                    throw new IllegalStateException("Can not compile special form: " + node.getForm());
            }
        }

        private Object processWithExceptionHandling(RowExpression expression, Object context)
        {
            if (expression == null) {
                return null;
            }
            try {
                return expression.accept(this, context);
            }
            catch (RuntimeException e) {
                // HACK
                // Certain operations like 0 / 0 or likeExpression may throw exceptions.
                // Wrap them in a call that will throw the exception if the expression is actually executed
                return createFailureFunction(e, expression.getType());
            }
        }

        private RowExpression createFailureFunction(RuntimeException exception, Type type)
        {
            requireNonNull(exception, "Exception is null");

            String failureInfo = JsonCodec.jsonCodec(FailureInfo.class).toJson(Failures.toFailure(exception).toFailureInfo());
            FunctionHandle jsonParse = functionAndTypeManager.lookupFunction(QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, "json_parse"), fromTypes(VARCHAR));
            Object json = functionInvoker.invoke(jsonParse, session.getSqlFunctionProperties(), utf8Slice(failureInfo));
            FunctionHandle cast = functionAndTypeManager.lookupCast(CAST, UNKNOWN, type);
            if (exception instanceof PrestoException) {
                long errorCode = ((PrestoException) exception).getErrorCode().getCode();
                FunctionHandle failureFunction = functionAndTypeManager.lookupFunction("fail", fromTypes(INTEGER, JSON));
                return call(CAST.name(), cast, type, call("fail", failureFunction, UNKNOWN, constant(errorCode, INTEGER), LiteralEncoder.toRowExpression(json, JSON)));
            }

            FunctionHandle failureFunction = functionAndTypeManager.lookupFunction("fail", fromTypes(JSON));
            return call(CAST.name(), cast, type, call("fail", failureFunction, UNKNOWN, LiteralEncoder.toRowExpression(json, JSON)));
        }

        private boolean hasUnresolvedValue(Object... values)
        {
            return hasUnresolvedValue(ImmutableList.copyOf(values));
        }

        private boolean hasUnresolvedValue(List<Object> values)
        {
            return values.stream().anyMatch(instanceOf(RowExpression.class)::apply);
        }

        private Object invokeOperator(OperatorType operatorType, List<? extends Type> argumentTypes, List<Object> argumentValues)
        {
            FunctionHandle operatorHandle = functionAndTypeManager.resolveOperator(operatorType, fromTypes(argumentTypes));
            return functionInvoker.invoke(operatorHandle, session.getSqlFunctionProperties(), argumentValues);
        }

        private List<RowExpression> toRowExpressions(List<Object> values, List<RowExpression> unchangedValues)
        {
            checkArgument(values != null, "value is null");
            checkArgument(unchangedValues != null, "value is null");
            checkArgument(values.size() == unchangedValues.size());
            ImmutableList.Builder<RowExpression> rowExpressions = ImmutableList.builder();
            for (int i = 0; i < values.size(); i++) {
                rowExpressions.add(toRowExpression(values.get(i), unchangedValues.get(i)));
            }
            return rowExpressions.build();
        }

        private RowExpression toRowExpression(Object value, RowExpression originalRowExpression)
        {
            if (optimizationLevel.ordinal() <= SERIALIZABLE.ordinal() && !isSerializable(value, originalRowExpression.getType())) {
                return originalRowExpression;
            }
            // handle lambda
            if (optimizationLevel.ordinal() < EVALUATED.ordinal() && value instanceof MethodHandle) {
                return originalRowExpression;
            }
            return LiteralEncoder.toRowExpression(originalRowExpression.getSourceLocation(), value, originalRowExpression.getType());
        }

        private boolean isSerializable(Object value, Type type)
        {
            // If value is already RowExpression, constant values contained inside should already have been made serializable. Otherwise, we make sure the object is small and serializable.
            return value instanceof RowExpression || (isSupportedLiteralType(type) && estimatedSizeInBytes(value) <= MAX_SERIALIZABLE_OBJECT_SIZE);
        }

        private SpecialCallResult tryHandleArrayConstructor(CallExpression callExpression, List<Object> argumentValues)
        {
            checkArgument(resolution.isArrayConstructor(callExpression.getFunctionHandle()));
            boolean allConstants = true;
            for (Object values : argumentValues) {
                if (values instanceof RowExpression) {
                    allConstants = false;
                    break;
                }
            }
            if (allConstants) {
                Type elementType = ((ArrayType) callExpression.getType()).getElementType();
                BlockBuilder arrayBlockBuilder = elementType.createBlockBuilder(null, argumentValues.size());
                for (Object value : argumentValues) {
                    writeNativeValue(elementType, arrayBlockBuilder, value);
                }
                return changed(arrayBlockBuilder.build());
            }
            return notChanged();
        }

        private SpecialCallResult tryHandleCast(CallExpression callExpression, List<Object> argumentValues)
        {
            checkArgument(resolution.isCastFunction(callExpression.getFunctionHandle()));
            checkArgument(callExpression.getArguments().size() == 1);
            RowExpression source = callExpression.getArguments().get(0);
            Type sourceType = source.getType();
            Type targetType = callExpression.getType();

            Object value = argumentValues.get(0);

            if (value == null) {
                return changed(null);
            }

            if (value instanceof RowExpression) {
                if (sourceType.equals(targetType)) {
                    return changed(value);
                }
                if (callExpression.getArguments().get(0) instanceof CallExpression) {
                    // Optimization for CAST(JSON_PARSE(...) AS ARRAY/MAP/ROW), solves https://github.com/prestodb/presto/issues/12829
                    CallExpression innerCall = (CallExpression) callExpression.getArguments().get(0);
                    if (functionAndTypeManager.getFunctionMetadata(innerCall.getFunctionHandle()).getName().getObjectName().equals("json_parse")) {
                        checkArgument(innerCall.getType().equals(JSON));
                        checkArgument(innerCall.getArguments().size() == 1);
                        TypeSignature returnType = functionAndTypeManager.getFunctionMetadata(callExpression.getFunctionHandle()).getReturnType();
                        if (returnType.getBase().equals(ARRAY)) {
                            return changed(call(
                                    JSON_TO_ARRAY_CAST.name(),
                                    functionAndTypeManager.lookupCast(
                                            JSON_TO_ARRAY_CAST,
                                            VARCHAR,
                                            functionAndTypeManager.getType(returnType)),
                                    callExpression.getType(),
                                    innerCall.getArguments()));
                        }
                        if (returnType.getBase().equals(MAP)) {
                            return changed(call(
                                    JSON_TO_MAP_CAST.name(),
                                    functionAndTypeManager.lookupCast(
                                            JSON_TO_MAP_CAST,
                                            VARCHAR,
                                            functionAndTypeManager.getType(returnType)),
                                    callExpression.getType(),
                                    innerCall.getArguments()));
                        }
                        if (returnType.getBase().equals(ROW)) {
                            return changed(call(
                                    JSON_TO_ROW_CAST.name(),
                                    functionAndTypeManager.lookupCast(
                                            JSON_TO_ROW_CAST,
                                            VARCHAR,
                                            functionAndTypeManager.getType(returnType)),
                                    callExpression.getType(),
                                    innerCall.getArguments()));
                        }
                    }
                }
                return changed(call(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), toRowExpression(value, source)));
            }

            // TODO: still there is limitation for RowExpression. Example types could be Regex
            if (optimizationLevel.ordinal() <= SERIALIZABLE.ordinal() && !isSupportedLiteralType(targetType)) {
                // Otherwise, cast will be evaluated through invoke later and generates unserializable constant expression.
                return changed(call(callExpression.getSourceLocation(), callExpression.getDisplayName(), callExpression.getFunctionHandle(), callExpression.getType(), toRowExpression(value, source)));
            }

            if (functionAndTypeManager.isTypeOnlyCoercion(sourceType, targetType)) {
                return changed(value);
            }
            return notChanged();
        }

        private SpecialCallResult tryHandleLike(CallExpression callExpression, List<Object> argumentValues, List<Type> argumentTypes, Object context)
        {
            FunctionResolution resolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
            checkArgument(resolution.isLikeFunction(callExpression.getFunctionHandle()));
            checkArgument(callExpression.getArguments().size() == 2);
            RowExpression likePatternExpression = callExpression.getArguments().get(1);
            if (!(likePatternExpression instanceof CallExpression &&
                    (((CallExpression) likePatternExpression).getFunctionHandle().equals(resolution.likePatternFunction()) ||
                            (resolution.isCastFunction(((CallExpression) likePatternExpression).getFunctionHandle()))))) {
                // expression was already optimized
                return notChanged();
            }
            Object value = argumentValues.get(0);
            Object possibleCompiledPattern = argumentValues.get(1);

            if (value == null) {
                return changed(null);
            }

            CallExpression likePatternCall = (CallExpression) likePatternExpression;

            Object nonCompiledPattern = likePatternCall.getArguments().get(0).accept(this, context);
            if (nonCompiledPattern == null) {
                return changed(null);
            }

            boolean hasEscape = false;  // We cannot use Optional given escape could exist and its value is null
            Object escape = null;
            if (likePatternCall.getArguments().size() == 2) {
                hasEscape = true;
                escape = likePatternCall.getArguments().get(1).accept(this, context);
            }

            if (hasEscape && escape == null) {
                return changed(null);
            }

            if (!hasUnresolvedValue(value) && !hasUnresolvedValue(nonCompiledPattern) && (!hasEscape || !hasUnresolvedValue(escape))) {
                // fast path when we know the pattern and escape are constants
                if (possibleCompiledPattern instanceof Regex) {
                    return changed(interpretLikePredicate(argumentTypes.get(0), (Slice) value, (Regex) possibleCompiledPattern));
                }
                if (possibleCompiledPattern == null) {
                    return changed(null);
                }

                checkState(possibleCompiledPattern instanceof CallExpression);
                // this corresponds to ExpressionInterpreter::getConstantPattern
                if (hasEscape) {
                    // like_pattern(pattern, escape)
                    possibleCompiledPattern = functionInvoker.invoke(((CallExpression) possibleCompiledPattern).getFunctionHandle(), session.getSqlFunctionProperties(), nonCompiledPattern, escape);
                }
                else {
                    // like_pattern(pattern)
                    possibleCompiledPattern = functionInvoker.invoke(((CallExpression) possibleCompiledPattern).getFunctionHandle(), session.getSqlFunctionProperties(), nonCompiledPattern);
                }

                checkState(possibleCompiledPattern instanceof Regex, "unexpected like pattern type " + possibleCompiledPattern.getClass());
                return changed(interpretLikePredicate(argumentTypes.get(0), (Slice) value, (Regex) possibleCompiledPattern));
            }

            // if pattern is a constant without % or _ replace with a comparison
            if (nonCompiledPattern instanceof Slice && (escape == null || escape instanceof Slice) && !isLikePattern((Slice) nonCompiledPattern, (Slice) escape)) {
                Slice unescapedPattern = unescapeLiteralLikePattern((Slice) nonCompiledPattern, (Slice) escape);
                Type valueType = argumentTypes.get(0);
                Type patternType = createVarcharType(unescapedPattern.length());
                Optional<Type> commonSuperType = functionAndTypeManager.getCommonSuperType(valueType, patternType);
                checkArgument(commonSuperType.isPresent(), "Missing super type when optimizing %s", callExpression);
                RowExpression valueExpression = LiteralEncoder.toRowExpression(callExpression.getSourceLocation(), value, valueType);
                RowExpression patternExpression = LiteralEncoder.toRowExpression(callExpression.getSourceLocation(), unescapedPattern, patternType);
                Type superType = commonSuperType.get();
                if (!valueType.equals(superType)) {
                    FunctionHandle cast = functionAndTypeManager.lookupCast(CAST, valueType, superType);
                    valueExpression = call(CAST.name(), cast, superType, valueExpression);
                }
                if (!patternType.equals(superType)) {
                    FunctionHandle cast = functionAndTypeManager.lookupCast(CAST, patternType, superType);
                    patternExpression = call(CAST.name(), cast, superType, patternExpression);
                }
                FunctionHandle equal = functionAndTypeManager.resolveOperator(EQUAL, fromTypes(superType, superType));
                return changed(call(EQUAL.name(), equal, BOOLEAN, valueExpression, patternExpression).accept(this, context));
            }
            return notChanged();
        }
    }

    static final class SpecialCallResult
    {
        private final Object value;
        private final boolean changed;

        private SpecialCallResult(Object value, boolean changed)
        {
            this.value = value;
            this.changed = changed;
        }

        public static SpecialCallResult notChanged()
        {
            return new SpecialCallResult(null, false);
        }

        public static SpecialCallResult changed(Object value)
        {
            return new SpecialCallResult(value, true);
        }

        public Object getValue()
        {
            return value;
        }

        public boolean isChanged()
        {
            return changed;
        }
    }
}