PageFunctionCompiler.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.CallSiteBinder;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.ParameterizedType;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.ForLoop;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.Work;
import com.facebook.presto.operator.project.ConstantPageProjection;
import com.facebook.presto.operator.project.GeneratedPageProjection;
import com.facebook.presto.operator.project.InputChannels;
import com.facebook.presto.operator.project.InputPageProjection;
import com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter;
import com.facebook.presto.operator.project.PageFilter;
import com.facebook.presto.operator.project.PageProjection;
import com.facebook.presto.operator.project.PageProjectionWithOutputs;
import com.facebook.presto.operator.project.SelectedPositions;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.gen.LambdaBytecodeGenerator.CompiledLambda;
import com.facebook.presto.sql.planner.CompilerConfig;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Joiner;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Primitives;
import com.google.common.util.concurrent.UncheckedExecutionException;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

import javax.annotation.Nullable;
import javax.inject.Inject;

import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.add;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.and;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeStatic;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.lessThan;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.newArray;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.not;
import static com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters;
import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR;
import static com.facebook.presto.sql.gen.BytecodeUtils.boxPrimitiveIfNecessary;
import static com.facebook.presto.sql.gen.BytecodeUtils.invoke;
import static com.facebook.presto.sql.gen.BytecodeUtils.unboxPrimitiveIfNecessary;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.initializeCommonSubExpressionFields;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.collectCSEByLevel;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.getExpressionsPartitionedByCSE;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.rewriteExpressionWithCSE;
import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateMethodsForLambda;
import static com.facebook.presto.sql.relational.Expressions.subExpressions;
import static com.facebook.presto.util.CompilerUtils.defineClass;
import static com.facebook.presto.util.CompilerUtils.makeClassName;
import static com.facebook.presto.util.Reflection.constructorMethodHandle;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Collections.emptyMap;
import static java.util.Objects.requireNonNull;

public class PageFunctionCompiler
{
    // Class-wide logger. It is never reassigned, so it is declared final like the other class constants.
    private static final Logger log = Logger.get(PageFunctionCompiler.class);
    // Benchmarks and experiments showed that performance degrades when too many projections are put into a single PageProjection. Flame charts show that a lot of CPU time is
    // spent in evaluate. The root cause is not well understood; maybe the JIT has problems optimizing the function when it is too large. Empirical evidence shows that with
    // fewer than 10 projections performance is generally better with common sub-expressions, so we set an upper limit on how many projections are compiled together here.
    private static final int MAX_PROJECTION_GROUP_SIZE = 10;

    // Metadata handed to the lambda generator and the row expression compiler
    private final Metadata metadata;
    // Built from metadata.getFunctionAndTypeManager(); used to flag generated projections as deterministic or not
    private final DeterminismEvaluator determinismEvaluator;

    // Caches of compiled suppliers keyed by CacheKey; both are null when the configured expression cache size is not positive
    private final LoadingCache<CacheKey, Supplier<PageProjection>> projectionCache;
    private final LoadingCache<CacheKey, Supplier<PageFilter>> filterCache;

    // Statistics beans wrapping the caches above; null whenever the corresponding cache is null
    private final CacheStatsMBean projectionCacheStats;
    private final CacheStatsMBean filterCacheStats;

    /**
     * Injection constructor: takes the expression cache size from the given compiler configuration
     * and delegates to {@link #PageFunctionCompiler(Metadata, int)}.
     *
     * @throws NullPointerException if {@code config} is null
     */
    @Inject
    public PageFunctionCompiler(Metadata metadata, CompilerConfig config)
    {
        this(metadata, requireNonNull(config, "config is null").getExpressionCacheSize());
    }

    /**
     * Creates a compiler whose projection and filter caches each hold at most {@code expressionCacheSize} entries.
     * When the size is not positive, no caches are created and the cache and statistics fields stay {@code null}.
     *
     * @param metadata metadata used for code generation; must not be null
     * @param expressionCacheSize maximum number of entries per cache; values {@code <= 0} disable caching
     */
    public PageFunctionCompiler(Metadata metadata, int expressionCacheSize)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager());

        if (expressionCacheSize <= 0) {
            // caching disabled: every compile request goes straight to the compiler
            this.projectionCache = null;
            this.projectionCacheStats = null;
            this.filterCache = null;
            this.filterCacheStats = null;
        }
        else {
            // cache loaders compile without a class name suffix because the suffix is not part of the cache key
            this.projectionCache = CacheBuilder.newBuilder()
                    .recordStats()
                    .maximumSize(expressionCacheSize)
                    .build(CacheLoader.from(key -> compileProjectionInternal(key.sqlFunctionProperties, key.sessionFunctions, key.rowExpressions, key.isOptimizeCommonSubExpression, Optional.empty())));
            this.projectionCacheStats = new CacheStatsMBean(this.projectionCache);

            this.filterCache = CacheBuilder.newBuilder()
                    .recordStats()
                    .maximumSize(expressionCacheSize)
                    .build(CacheLoader.from(key -> compileFilterInternal(key.sqlFunctionProperties, key.sessionFunctions, key.rowExpressions.get(0), key.isOptimizeCommonSubExpression, Optional.empty())));
            this.filterCacheStats = new CacheStatsMBean(this.filterCache);
        }
    }

    /**
     * Managed, nested accessor for the projection cache statistics.
     *
     * @return the statistics bean, or {@code null} when the projection cache is disabled
     */
    @Nullable
    @Managed
    @Nested
    public CacheStatsMBean getProjectionCache()
    {
        return this.projectionCacheStats;
    }

    /**
     * Managed, nested accessor for the filter cache statistics.
     *
     * @return the statistics bean, or {@code null} when the filter cache is disabled
     */
    @Nullable
    @Managed
    @Nested
    public CacheStatsMBean getFilterCache()
    {
        return this.filterCacheStats;
    }

    /**
     * Compiles the given projections into suppliers of page projections, each paired with the output channels it fills.
     * <p>
     * Without common sub-expression optimization every projection is compiled on its own and writes to the output
     * channel matching its index. With the optimization enabled, constants, input references and repeated expressions
     * are still compiled individually; the remaining expressions are partitioned into groups of at most
     * {@code MAX_PROJECTION_GROUP_SIZE}, and groups flagged as sharing sub-expressions are compiled together.
     *
     * @param projections expressions to compile; the list index is the output channel
     * @param isOptimizeCommonSubExpression whether to group projections that share sub-expressions
     * @param classNameSuffix optional suffix forwarded to the class name generator
     * @return suppliers in emission order (individually compiled trivial/duplicate projections first when optimizing)
     */
    public List<Supplier<PageProjectionWithOutputs>> compileProjections(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            List<? extends RowExpression> projections,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        ImmutableList.Builder<Supplier<PageProjectionWithOutputs>> compiled = ImmutableList.builder();

        if (!isOptimizeCommonSubExpression) {
            // simple path: one projection per output channel, in order
            for (int channel = 0; channel < projections.size(); channel++) {
                compiled.add(toPageProjectionWithOutputs(
                        compileProjection(sqlFunctionProperties, sessionFunctions, projections.get(channel), classNameSuffix),
                        new int[] {channel}));
            }
            return compiled.build();
        }

        ImmutableMap.Builder<RowExpression, Integer> channelByExpressionBuilder = ImmutableMap.builder();
        Set<RowExpression> seenCandidates = new HashSet<>();
        for (int channel = 0; channel < projections.size(); channel++) {
            RowExpression projection = projections.get(channel);
            boolean trivial = projection instanceof ConstantExpression || projection instanceof InputReferenceExpression;
            // Duplicate expressions are not expected in general because duplicate assignments are pruned during query optimization, so CSE is skipped for them to keep
            // the implementation simple (a duplicate key would also make channelByExpressionBuilder.build() throw). Set.add returns false for such a repeat.
            if (trivial || !seenCandidates.add(projection)) {
                compiled.add(toPageProjectionWithOutputs(compileProjection(sqlFunctionProperties, sessionFunctions, projection, classNameSuffix), new int[] {channel}));
                continue;
            }
            channelByExpressionBuilder.put(projection, channel);
        }
        Map<RowExpression, Integer> channelByExpression = channelByExpressionBuilder.build();

        Map<List<RowExpression>, Boolean> partitions = getExpressionsPartitionedByCSE(channelByExpression.keySet(), MAX_PROJECTION_GROUP_SIZE);

        for (Map.Entry<List<RowExpression>, Boolean> partition : partitions.entrySet()) {
            List<RowExpression> group = partition.getKey();
            if (!partition.getValue()) {
                // no shared sub-expression: the group is a single expression compiled on its own
                verify(group.size() == 1, "Expect non-cse expression list to only have one element");
                RowExpression single = group.get(0);
                compiled.add(toPageProjectionWithOutputs(
                        compileProjection(sqlFunctionProperties, sessionFunctions, single, classNameSuffix),
                        new int[] {channelByExpression.get(single)}));
                continue;
            }
            // shared sub-expressions: compile the whole group into one projection with several outputs
            compiled.add(toPageProjectionWithOutputs(
                    compileProjectionCached(sqlFunctionProperties, sessionFunctions, group, true, classNameSuffix),
                    toIntArray(group.stream().map(channelByExpression::get).collect(toImmutableList()))));
        }
        return compiled.build();
    }

    /**
     * Test convenience overload: compiles the projections with an empty session function map.
     */
    @VisibleForTesting
    public List<Supplier<PageProjectionWithOutputs>> compileProjections(
            SqlFunctionProperties sqlFunctionProperties,
            List<? extends RowExpression> projections,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        Map<SqlFunctionId, SqlInvokedFunction> noSessionFunctions = emptyMap();
        return compileProjections(sqlFunctionProperties, noSessionFunctions, projections, isOptimizeCommonSubExpression, classNameSuffix);
    }

    /**
     * Test convenience overload: compiles a single projection with an empty session function map.
     */
    @VisibleForTesting
    public Supplier<PageProjection> compileProjection(
            SqlFunctionProperties sqlFunctionProperties,
            RowExpression projection,
            Optional<String> classNameSuffix)
    {
        Map<SqlFunctionId, SqlInvokedFunction> noSessionFunctions = emptyMap();
        return compileProjection(sqlFunctionProperties, noSessionFunctions, projection, classNameSuffix);
    }

    /**
     * Compiles a single projection.
     * <p>
     * Input references and constants need no generated code: a single pre-built projection instance is created
     * eagerly and returned by every call of the supplier. Any other expression goes through the (possibly cached)
     * compiler with common sub-expression optimization turned off.
     */
    private Supplier<PageProjection> compileProjection(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            RowExpression projection,
            Optional<String> classNameSuffix)
    {
        if (projection instanceof InputReferenceExpression) {
            int field = ((InputReferenceExpression) projection).getField();
            InputPageProjection inputProjection = new InputPageProjection(field);
            return () -> inputProjection;
        }

        if (projection instanceof ConstantExpression) {
            ConstantExpression constantExpression = (ConstantExpression) projection;
            ConstantPageProjection constantProjection = new ConstantPageProjection(constantExpression.getValue(), constantExpression.getType());
            return () -> constantProjection;
        }

        List<RowExpression> singleProjection = ImmutableList.of(projection);
        return compileProjectionCached(sqlFunctionProperties, sessionFunctions, singleProjection, false, classNameSuffix);
    }

    /**
     * Returns a projection supplier for the given expressions, using the projection cache when it is enabled.
     * <p>
     * On the cached path {@code classNameSuffix} is not used: it is not part of the cache key, and the cache loader
     * compiles with an empty suffix. A {@link PrestoException} raised while loading is rethrown unwrapped; any other
     * load failure propagates as the {@link UncheckedExecutionException} thrown by the cache.
     */
    private Supplier<PageProjection> compileProjectionCached(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            List<RowExpression> projections,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        if (projectionCache != null) {
            CacheKey key = new CacheKey(sqlFunctionProperties, sessionFunctions, projections, isOptimizeCommonSubExpression);
            try {
                return projectionCache.getUnchecked(key);
            }
            catch (UncheckedExecutionException e) {
                throwIfInstanceOf(e.getCause(), PrestoException.class);
                throw e;
            }
        }
        // caching disabled: compile directly
        return compileProjectionInternal(sqlFunctionProperties, sessionFunctions, projections, isOptimizeCommonSubExpression, classNameSuffix);
    }

    /**
     * Wraps a projection supplier so that each obtained projection is paired with the given output channels.
     * The wrapped supplier is invoked once per call of the returned supplier; the channel array is shared, not copied.
     */
    private Supplier<PageProjectionWithOutputs> toPageProjectionWithOutputs(Supplier<PageProjection> pageProjection, int[] outputChannels)
    {
        return () -> {
            PageProjection projection = pageProjection.get();
            return new PageProjectionWithOutputs(projection, outputChannels);
        };
    }

    /**
     * Generates and loads a {@link Work} class evaluating the given projections and returns a supplier that creates
     * a new {@link GeneratedPageProjection} backed by that class on every call.
     * <p>
     * The determinism flag and the constructor method handle of the generated class do not change between supplier
     * calls, so they are computed once here instead of on every {@code get()}.
     *
     * @param projections non-empty list containing only call or special form expressions
     * @param isOptimizeCommonSubExpression whether common sub-expressions are extracted in the generated class
     * @param classNameSuffix optional suffix forwarded to the class name generator
     * @throws NullPointerException if {@code projections} is null
     * @throws IllegalArgumentException if {@code projections} is empty or contains any other expression kind
     * @throws PrestoException with {@code COMPILER_ERROR} when defining the class fails with a non-Presto exception
     */
    private Supplier<PageProjection> compileProjectionInternal(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            List<RowExpression> projections,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        requireNonNull(projections, "projections is null");
        checkArgument(!projections.isEmpty() && projections.stream().allMatch(projection -> projection instanceof CallExpression || projection instanceof SpecialFormExpression));

        PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(projections);
        List<RowExpression> rewrittenExpression = result.getRewrittenExpressions();

        CallSiteBinder callSiteBinder = new CallSiteBinder();

        // generate Work
        ClassDefinition pageProjectionWorkDefinition = definePageProjectWorkClass(
                sqlFunctionProperties,
                sessionFunctions,
                rewrittenExpression,
                callSiteBinder,
                isOptimizeCommonSubExpression,
                classNameSuffix);

        Class<? extends Work> pageProjectionWorkClass;
        try {
            pageProjectionWorkClass = defineClass(pageProjectionWorkDefinition, Work.class, callSiteBinder.getBindings(), getClass().getClassLoader());
        }
        catch (PrestoException prestoException) {
            // already carries a meaningful error code; do not re-wrap
            throw prestoException;
        }
        catch (Exception e) {
            throw new PrestoException(COMPILER_ERROR, e);
        }

        // Invariant across supplier calls: evaluate once rather than per created projection.
        boolean deterministic = rewrittenExpression.stream().allMatch(determinismEvaluator::isDeterministic);
        // The parameter types mirror the constructor declared in definePageProjectWorkClass.
        MethodHandle pageProjectionWorkFactory = constructorMethodHandle(pageProjectionWorkClass, List.class, SqlFunctionProperties.class, Page.class, SelectedPositions.class);

        return () -> new GeneratedPageProjection(
                rewrittenExpression,
                deterministic,
                result.getInputChannels(),
                pageProjectionWorkFactory);
    }

    /**
     * Produces the class name for a generated projection work class, using the {@code "PageProjectionWork"}
     * base name and the optional suffix.
     */
    private static ParameterizedType generateProjectionWorkClassName(Optional<String> classNameSuffix)
    {
        ParameterizedType workClassName = makeClassName("PageProjectionWork", classNameSuffix);
        return workClassName;
    }

    /**
     * Builds the definition of a generated class implementing {@link Work} that evaluates the given projections.
     * <p>
     * The generated class gets: fields for the block builders, function properties, page, selected positions,
     * the next index/position and the result; a {@code process} method; a {@code getResult} method returning the
     * result field; methods for lambdas found in the projections; optionally fields and getter methods for common
     * sub-expressions; an {@code evaluate} method; and a public constructor taking
     * {@code (List blockBuilders, SqlFunctionProperties, Page, SelectedPositions)}.
     *
     * @param projections expressions to evaluate; one block builder is expected per projection
     * @param callSiteBinder binder collecting call site bindings for the generated class
     * @param isOptimizeCommonSubExpression whether to extract common sub-expressions into cached getter methods
     * @param classNameSuffix optional suffix forwarded to the class name generator
     * @return the class definition; it is not loaded here
     */
    private ClassDefinition definePageProjectWorkClass(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            List<RowExpression> projections,
            CallSiteBinder callSiteBinder,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        // public final class <generated name> extends Object implements Work
        ClassDefinition classDefinition = new ClassDefinition(
                a(PUBLIC, FINAL),
                generateProjectionWorkClassName(classNameSuffix),
                type(Object.class),
                type(Work.class));

        // instance state of the generated class
        FieldDefinition blockBuilderFields = classDefinition.declareField(a(PRIVATE), "blockBuilders", type(List.class, BlockBuilder.class));
        FieldDefinition propertiesField = classDefinition.declareField(a(PRIVATE), "properties", SqlFunctionProperties.class);
        FieldDefinition pageField = classDefinition.declareField(a(PRIVATE), "page", Page.class);
        FieldDefinition selectedPositionsField = classDefinition.declareField(a(PRIVATE), "selectedPositions", SelectedPositions.class);
        FieldDefinition nextIndexOrPositionField = classDefinition.declareField(a(PRIVATE), "nextIndexOrPosition", int.class);
        FieldDefinition resultField = classDefinition.declareField(a(PRIVATE), "result", List.class);

        CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);

        // process: loops over the selected positions and builds the result list
        generateProcessMethod(classDefinition, blockBuilderFields, projections.size(), propertiesField, pageField, selectedPositionsField, nextIndexOrPositionField, resultField);

        // getResult: returns the result field as Object
        MethodDefinition method = classDefinition.declareMethod(a(PUBLIC), "getResult", type(Object.class), ImmutableList.of());
        method.getBody().append(method.getThis().getField(resultField)).ret(Object.class);

        // counter shared by the lambda generator and both row expression compilers below
        AtomicInteger lambdaCounter = new AtomicInteger(0);

        Map<LambdaDefinitionExpression, CompiledLambda> compiledLambdaMap = generateMethodsForLambda(classDefinition,
                callSiteBinder,
                cachedInstanceBinder,
                projections,
                metadata,
                sqlFunctionProperties,
                sessionFunctions,
                "",
                lambdaCounter);

        // cse: starts empty; replaced below only when common sub-expressions are actually found
        Map<VariableReferenceExpression, CommonSubExpressionFields> cseFields = ImmutableMap.of();

        // default compiler, unaware of any common sub-expression fields
        RowExpressionCompiler compiler = new RowExpressionCompiler(
                classDefinition,
                callSiteBinder,
                cachedInstanceBinder,
                new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields),
                metadata,
                sqlFunctionProperties,
                sessionFunctions,
                compiledLambdaMap,
                lambdaCounter);

        if (isOptimizeCommonSubExpression) {
            Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel = collectCSEByLevel(projections);
            if (!commonSubExpressionsByLevel.isEmpty()) {
                cseFields = declareCommonSubExpressionFields(classDefinition, commonSubExpressionsByLevel);
                // rebuild the compiler so variable references resolve against the freshly declared CSE fields
                compiler = new RowExpressionCompiler(
                        classDefinition,
                        callSiteBinder,
                        cachedInstanceBinder,
                        new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields),
                        metadata,
                        sqlFunctionProperties,
                        sessionFunctions,
                        compiledLambdaMap,
                        lambdaCounter);

                generateCommonSubExpressionMethods(classDefinition, compiler, commonSubExpressionsByLevel, cseFields);
                // flatten all levels into one expression -> variable map and rewrite the projections with it
                Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream()
                        .flatMap(m -> m.entrySet().stream())
                        .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
                projections = projections.stream().map(projection -> rewriteExpressionWithCSE(projection, commonSubExpressions)).collect(toImmutableList());
                if (log.isDebugEnabled()) {
                    log.debug("Extracted %d common sub-expressions", commonSubExpressions.size());
                    commonSubExpressions.entrySet().forEach(entry -> log.debug("\t%s = %s", entry.getValue(), entry.getKey()));
                    log.debug("Rewrote %d projections: %s", projections.size(), Joiner.on(", ").join(projections));
                }
            }
        }

        // evaluate: uses the (possibly CSE-rewritten) projections and the matching compiler
        generateEvaluateMethod(classDefinition, compiler, projections, blockBuilderFields, cseFields);

        // constructor
        Parameter blockBuilders = arg("blockBuilders", type(List.class, BlockBuilder.class));
        Parameter properties = arg("properties", SqlFunctionProperties.class);
        Parameter page = arg("page", Page.class);
        Parameter selectedPositions = arg("selectedPositions", SelectedPositions.class);

        MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC), blockBuilders, properties, page, selectedPositions);

        BytecodeBlock body = constructorDefinition.getBody();
        Variable thisVariable = constructorDefinition.getThis();

        // The block builder list is copied with ImmutableList.copyOf and iteration starts at selectedPositions.getOffset().
        // NOTE(review): the result field is declared as List but is nulled with constantNull(Block.class); presumably harmless for a null constant - confirm.
        body.comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class)
                .append(thisVariable.setField(blockBuilderFields, invokeStatic(ImmutableList.class, "copyOf", ImmutableList.class, blockBuilders.cast(Collection.class))))
                .append(thisVariable.setField(propertiesField, properties))
                .append(thisVariable.setField(pageField, page))
                .append(thisVariable.setField(selectedPositionsField, selectedPositions))
                .append(thisVariable.setField(nextIndexOrPositionField, selectedPositions.invoke("getOffset", int.class)))
                .append(thisVariable.setField(resultField, constantNull(Block.class)));

        initializeCommonSubExpressionFields(cseFields.values(), thisVariable, body);

        // emitted last, after all compilation above has used the cached instance binder
        cachedInstanceBinder.generateInitializations(thisVariable, body);
        body.ret();

        return classDefinition;
    }

    /**
     * Declares the generated {@code boolean process()} method.
     * <p>
     * The generated code calls {@code evaluate} for every index from the {@code nextIndexOrPosition} field up to
     * (excluding) {@code selectedPositions.getOffset() + selectedPositions.size()}. When the selected positions are
     * a list, the index is looked up in the positions array; otherwise the index itself is the position. It then
     * builds every block builder, stores the resulting immutable list in the result field and returns {@code true}.
     *
     * @param blockBuilderSize number of block builders to build into the result
     * @return the declared method
     */
    private static MethodDefinition generateProcessMethod(
            ClassDefinition classDefinition,
            FieldDefinition blockBuilders,
            int blockBuilderSize,
            FieldDefinition properties,
            FieldDefinition page,
            FieldDefinition selectedPositions,
            FieldDefinition nextIndexOrPosition,
            FieldDefinition result)
    {
        MethodDefinition processMethod = classDefinition.declareMethod(a(PUBLIC), "process", type(boolean.class), ImmutableList.of());

        Variable self = processMethod.getThis();
        Scope methodScope = processMethod.getScope();
        BytecodeBlock code = processMethod.getBody();

        // local variables of the generated method; declaration order is kept as-is
        Variable from = methodScope.declareVariable("from", code, self.getField(nextIndexOrPosition));
        Variable to = methodScope.declareVariable(
                "to",
                code,
                add(self.getField(selectedPositions).invoke("getOffset", int.class), self.getField(selectedPositions).invoke("size", int.class)));
        Variable positions = methodScope.declareVariable(int[].class, "positions");
        Variable index = methodScope.declareVariable(int.class, "index");

        // list case: positions[index] is the position to evaluate
        ForLoop positionsLoop = new ForLoop("positions loop")
                .initialize(index.set(from))
                .condition(lessThan(index, to))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(self.invoke("evaluate", void.class, self.getField(properties), self.getField(page), positions.getElement(index))));
        BytecodeBlock listBranch = new BytecodeBlock()
                .append(positions.set(self.getField(selectedPositions).invoke("getPositions", int[].class)))
                .append(positionsLoop);

        // range case: the index itself is the position to evaluate
        ForLoop rangeLoop = new ForLoop("range based loop")
                .initialize(index.set(from))
                .condition(lessThan(index, to))
                .update(index.increment())
                .body(new BytecodeBlock()
                        .append(self.invoke("evaluate", void.class, self.getField(properties), self.getField(page), index)));

        code.append(new IfStatement()
                .condition(self.getField(selectedPositions).invoke("isList", boolean.class))
                .ifTrue(listBranch)
                .ifFalse(rangeLoop));

        // collect the built blocks into an immutable list
        Variable blocksBuilder = methodScope.declareVariable("blocksBuilder", code, invokeStatic(ImmutableList.class, "builder", ImmutableList.Builder.class));
        Variable builderIndex = methodScope.createTempVariable(int.class);
        ForLoop buildBlocksLoop = new ForLoop("for (iterator = 0; iterator < this.blockBuilders.size(); iterator ++) blockBuildersBuilder.add(this.blockBuilders.get(iterator).builder();")
                .initialize(builderIndex.set(constantInt(0)))
                .condition(lessThan(builderIndex, constantInt(blockBuilderSize)))
                .update(builderIndex.increment())
                .body(new BytecodeBlock()
                        .append(blocksBuilder.invoke(
                                "add",
                                ImmutableList.Builder.class,
                                self.getField(blockBuilders).invoke("get", Object.class, builderIndex).cast(BlockBuilder.class).invoke("build", Block.class).cast(Object.class)).pop()));

        code.append(buildBlocksLoop);
        code.comment("result = blockBuildersBuilder.build(); return true")
                .append(self.setField(result, blocksBuilder.invoke("build", ImmutableList.class)))
                .push(true)
                .retBoolean();

        return processMethod;
    }

    /**
     * Declares one private getter method per common sub-expression on the generated class.
     * <p>
     * Each generated method is named {@code "get" + <cse variable name>} and takes
     * {@code (properties, page, position)}. It evaluates the sub-expression only when the corresponding
     * "evaluated" field is false, stores the boxed result in the result field, sets the "evaluated" field to true,
     * and returns the stored result.
     * <p>
     * Levels are processed in ascending order. An empty map yields an empty list instead of failing on an
     * unchecked {@code Optional.get()}, and only levels actually present in the map are visited.
     *
     * @param compiler compiler used to generate the body of each sub-expression
     * @param commonSubExpressionsByLevel common sub-expressions and their variables, grouped by level
     * @param commonSubExpressionFieldsMap generated fields for each sub-expression variable
     * @return the declared methods, in declaration order
     */
    private List<MethodDefinition> generateCommonSubExpressionMethods(
            ClassDefinition classDefinition,
            RowExpressionCompiler compiler,
            Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel,
            Map<VariableReferenceExpression, CommonSubExpressionFields> commonSubExpressionFieldsMap)
    {
        ImmutableList.Builder<MethodDefinition> methods = ImmutableList.builder();

        Parameter properties = arg("properties", SqlFunctionProperties.class);
        Parameter page = arg("page", Page.class);
        Parameter position = arg("position", int.class);

        // Sorted copy of the level keys: ascending order, no probing of absent levels, and safe for an empty map.
        for (Integer level : new TreeSet<>(commonSubExpressionsByLevel.keySet())) {
            for (Map.Entry<RowExpression, VariableReferenceExpression> entry : commonSubExpressionsByLevel.get(level).entrySet()) {
                RowExpression cse = entry.getKey();
                // results are stored boxed
                Class<?> type = Primitives.wrap(cse.getType().getJavaType());
                VariableReferenceExpression cseVariable = entry.getValue();
                CommonSubExpressionFields cseFields = commonSubExpressionFieldsMap.get(cseVariable);
                MethodDefinition method = classDefinition.declareMethod(
                        a(PRIVATE),
                        "get" + cseVariable.getName(),
                        type(cseFields.getResultType()),
                        ImmutableList.<Parameter>builder()
                                .add(properties)
                                .add(page)
                                .add(position)
                                .build());

                method.comment("cse: %s", cse);

                Scope scope = method.getScope();
                BytecodeBlock body = method.getBody();
                Variable thisVariable = method.getThis();

                declareBlockVariables(ImmutableList.of(cse), page, scope, body);
                scope.declareVariable("wasNull", body, constantFalse());

                // if (!evaluated) { result = box(<cse>); evaluated = true; }
                IfStatement ifStatement = new IfStatement()
                        .condition(thisVariable.getField(cseFields.getEvaluatedField()))
                        .ifFalse(new BytecodeBlock()
                                .append(thisVariable)
                                .append(compiler.compile(cse, scope, Optional.empty()))
                                .append(boxPrimitiveIfNecessary(scope, type))
                                .putField(cseFields.getResultField())
                                .append(thisVariable.setField(cseFields.getEvaluatedField(), constantBoolean(true))));

                // return result;
                body.append(ifStatement)
                        .append(thisVariable)
                        .getField(cseFields.getResultField())
                        .retObject();

                methods.add(method);
            }
        }
        return methods.build();
    }

    /**
     * Declares the generated {@code void evaluate(properties, page, position)} method.
     * <p>
     * The generated code first sets the "evaluated" field of every common sub-expression to false. Then, for each
     * projection in order, it loads the block builder at the projection's index, emits the compiled projection
     * with that builder as output, and sets {@code wasNull} back to false.
     *
     * @param projections projections to emit; the list index selects the block builder
     * @param cseFields fields of the common sub-expressions whose "evaluated" flags are reset per call
     * @return the declared method
     */
    private MethodDefinition generateEvaluateMethod(
            ClassDefinition classDefinition,
            RowExpressionCompiler compiler,
            List<RowExpression> projections,
            FieldDefinition blockBuilders,
            Map<VariableReferenceExpression, CommonSubExpressionFields> cseFields)
    {
        Parameter properties = arg("properties", SqlFunctionProperties.class);
        Parameter page = arg("page", Page.class);
        Parameter position = arg("position", int.class);

        MethodDefinition evaluateMethod = classDefinition.declareMethod(
                a(PUBLIC),
                "evaluate",
                type(void.class),
                ImmutableList.of(properties, page, position));

        evaluateMethod.comment("Projections: %s", Joiner.on(", ").join(projections));

        Variable self = evaluateMethod.getThis();
        Scope methodScope = evaluateMethod.getScope();
        BytecodeBlock code = evaluateMethod.getBody();

        declareBlockVariables(projections, page, methodScope, code);
        Variable wasNull = methodScope.declareVariable("wasNull", code, constantFalse());

        // cached sub-expression results from the previous position are no longer valid
        for (CommonSubExpressionFields fields : cseFields.values()) {
            code.append(self.setField(fields.getEvaluatedField(), constantBoolean(false)));
        }

        Variable outputBlockVariable = methodScope.createTempVariable(BlockBuilder.class);
        for (int channel = 0; channel < projections.size(); channel++) {
            RowExpression projection = projections.get(channel);
            code.append(outputBlockVariable.set(self.getField(blockBuilders).invoke("get", Object.class, constantInt(channel))));
            code.append(compiler.compile(projection, methodScope, Optional.of(outputBlockVariable)));
            // wasNull = false, so the next projection starts from a clean state
            code.append(constantBoolean(false));
            code.putVariable(wasNull);
        }
        code.ret();
        return evaluateMethod;
    }

    /**
     * Test-oriented overload of {@code compileFilter} that compiles the filter without any session functions.
     *
     * @param sqlFunctionProperties function properties forwarded to the main overload
     * @param filter filter expression to compile
     * @param isOptimizeCommonSubExpression whether common sub-expression optimization is requested
     * @param classNameSuffix optional suffix forwarded to the main overload
     * @return supplier of page filter instances
     */
    @VisibleForTesting
    public Supplier<PageFilter> compileFilter(
            SqlFunctionProperties sqlFunctionProperties,
            RowExpression filter,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        Map<SqlFunctionId, SqlInvokedFunction> noSessionFunctions = emptyMap();
        return compileFilter(sqlFunctionProperties, noSessionFunctions, filter, isOptimizeCommonSubExpression, classNameSuffix);
    }

    /**
     * Returns a supplier of {@link PageFilter} instances for the given filter expression.
     *
     * <p>When a filter cache is configured, the supplier is looked up in the cache using a key built from the
     * function properties, session functions, filter expression and the common sub-expression flag; the
     * {@code classNameSuffix} is not used on that path in this method. Without a cache the filter is compiled directly.
     *
     * @throws PrestoException when the cache lookup fails with a {@code PrestoException} as its cause
     */
    public Supplier<PageFilter> compileFilter(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            RowExpression filter,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        if (filterCache != null) {
            CacheKey key = new CacheKey(sqlFunctionProperties, sessionFunctions, ImmutableList.of(filter), isOptimizeCommonSubExpression);
            try {
                return filterCache.getUnchecked(key);
            }
            catch (UncheckedExecutionException e) {
                // surface the original PrestoException rather than the cache wrapper
                throwIfInstanceOf(e.getCause(), PrestoException.class);
                throw e;
            }
        }
        return compileFilterInternal(sqlFunctionProperties, sessionFunctions, filter, isOptimizeCommonSubExpression, classNameSuffix);
    }

    /**
     * Compiles the filter expression into a generated {@link PageFilter} class and returns a supplier that
     * creates a new instance of that class on every call.
     *
     * @param sqlFunctionProperties function properties used while generating the class
     * @param sessionFunctions session-scoped functions made available to the expression compiler
     * @param filter filter expression; must not be null
     * @param isOptimizeCommonSubExpression whether common sub-expressions should be extracted
     * @param classNameSuffix optional suffix for the generated class name
     * @return supplier instantiating the generated filter class through its no-arg constructor
     * @throws PrestoException with {@code COMPILER_ERROR} when class definition or instantiation fails
     */
    private Supplier<PageFilter> compileFilterInternal(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            RowExpression filter,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        requireNonNull(filter, "filter is null");

        PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(filter);

        CallSiteBinder callSiteBinder = new CallSiteBinder();
        ClassDefinition classDefinition = defineFilterClass(
                sqlFunctionProperties,
                sessionFunctions,
                result.getRewrittenExpression(),
                result.getInputChannels(),
                callSiteBinder,
                isOptimizeCommonSubExpression,
                classNameSuffix);

        Class<? extends PageFilter> functionClass;
        try {
            functionClass = defineClass(classDefinition, PageFilter.class, callSiteBinder.getBindings(), getClass().getClassLoader());
        }
        catch (PrestoException prestoException) {
            throw prestoException;
        }
        catch (Exception e) {
            // Report the underlying cause when there is one, but never lose the failure:
            // an exception without a cause is attached itself instead of passing null.
            Throwable cause = e.getCause() != null ? e.getCause() : e;
            throw new PrestoException(COMPILER_ERROR, filter.toString(), cause);
        }

        return () -> {
            try {
                return functionClass.getConstructor().newInstance();
            }
            catch (ReflectiveOperationException e) {
                throw new PrestoException(COMPILER_ERROR, e);
            }
        };
    }

    /**
     * Builds the name of a generated filter class from the simple name of {@link PageFilter}
     * and the optional suffix.
     */
    private static ParameterizedType generateFilterClassName(Optional<String> classNameSuffix)
    {
        String baseName = PageFilter.class.getSimpleName();
        return makeClassName(baseName, classNameSuffix);
    }

    /**
     * Builds the definition of a generated filter class for the given filter expression.
     *
     * <p>The class is declared public and final, with {@code Object} and {@link PageFilter} as its declared
     * super types. It receives, in this order: the optional common sub-expression fields and methods, the
     * per-position {@code filter} method, the {@code selectedPositions} field with the page-level
     * {@code filter} method, {@code isDeterministic}, {@code getInputChannels}, {@code toString}, and a
     * public no-arg constructor.
     *
     * @param sqlFunctionProperties function properties handed to the expression compiler
     * @param sessionFunctions session functions handed to the expression compiler
     * @param filter filter expression to compile; replaced locally by its CSE-rewritten form when applicable
     * @param inputChannels input channels returned by the generated {@code getInputChannels}
     * @param callSiteBinder binder used for constants bound into the generated class
     * @param isOptimizeCommonSubExpression whether to look for common sub-expressions in the filter
     * @param classNameSuffix optional suffix for the generated class name
     * @return the populated class definition
     */
    private ClassDefinition defineFilterClass(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            RowExpression filter,
            InputChannels inputChannels,
            CallSiteBinder callSiteBinder,
            boolean isOptimizeCommonSubExpression,
            Optional<String> classNameSuffix)
    {
        ClassDefinition classDefinition = new ClassDefinition(
                a(PUBLIC, FINAL),
                generateFilterClassName(classNameSuffix),
                type(Object.class),
                type(PageFilter.class));

        CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);

        AtomicInteger lambdaCounter = new AtomicInteger(0);
        // cse: start with no common sub-expression fields; this compiler is used as-is unless CSEs are found below
        Map<VariableReferenceExpression, CommonSubExpressionFields> cseFields = ImmutableMap.of();
        RowExpressionCompiler compiler = new RowExpressionCompiler(
                classDefinition,
                callSiteBinder,
                cachedInstanceBinder,
                new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields),
                metadata,
                sqlFunctionProperties,
                sessionFunctions,
                ImmutableMap.of(),
                lambdaCounter);

        if (isOptimizeCommonSubExpression) {
            Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel = collectCSEByLevel(filter);
            if (!commonSubExpressionsByLevel.isEmpty()) {
                // declare the backing fields first, then rebuild the compiler so variable references resolve to them
                cseFields = declareCommonSubExpressionFields(classDefinition, commonSubExpressionsByLevel);
                compiler = new RowExpressionCompiler(
                        classDefinition,
                        callSiteBinder,
                        cachedInstanceBinder,
                        new FieldAndVariableReferenceCompiler(callSiteBinder, cseFields),
                        metadata,
                        sqlFunctionProperties,
                        sessionFunctions,
                        ImmutableMap.of(),
                        lambdaCounter);

                generateCommonSubExpressionMethods(classDefinition, compiler, commonSubExpressionsByLevel, cseFields);
                // flatten the per-level maps into a single expression -> variable map for the rewrite
                Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream()
                        .flatMap(m -> m.entrySet().stream())
                        .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));
                // from here on, "filter" refers to the rewritten expression
                filter = rewriteExpressionWithCSE(filter, commonSubExpressions);
                if (log.isDebugEnabled()) {
                    log.debug("Extracted %d common sub-expressions", commonSubExpressions.size());
                    commonSubExpressions.entrySet().forEach(entry -> log.debug("\t%s = %s", entry.getValue(), entry.getKey()));
                    log.debug("Rewrote filter: %s", filter);
                }
            }
        }

        // per-position filter(properties, page, position)
        generateFilterMethod(classDefinition, compiler, filter, cseFields);

        // page-level filter(properties, page), backed by a reusable boolean[] field
        FieldDefinition selectedPositions = classDefinition.declareField(a(PRIVATE), "selectedPositions", boolean[].class);
        generatePageFilterMethod(classDefinition, selectedPositions);

        // isDeterministic: the value is computed now and emitted as a constant
        // NOTE(review): this is evaluated on the possibly CSE-rewritten filter — confirm the evaluator
        // gives the same answer as it would for the original expression
        classDefinition.declareMethod(a(PUBLIC), "isDeterministic", type(boolean.class))
                .getBody()
                .append(constantBoolean(determinismEvaluator.isDeterministic(filter)))
                .retBoolean();

        // getInputChannels: returns the bound InputChannels instance
        classDefinition.declareMethod(a(PUBLIC), "getInputChannels", type(InputChannels.class))
                .getBody()
                .append(invoke(callSiteBinder.bind(inputChannels, InputChannels.class), "getInputChannels"))
                .retObject();

        // toString: precomputed from the class name and the (possibly rewritten) filter
        String toStringResult = toStringHelper(classDefinition.getType()
                .getJavaClassName())
                .add("filter", filter)
                .toString();
        classDefinition.declareMethod(a(PUBLIC), "toString", type(String.class))
                .getBody()
                // bind constant via invokedynamic to avoid constant pool issues due to large strings
                .append(invoke(callSiteBinder.bind(toStringResult, String.class), "toString"))
                .retObject();

        // constructor
        MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC));

        BytecodeBlock body = constructorDefinition.getBody();
        Variable thisVariable = constructorDefinition.getThis();

        // call Object's constructor, then start with an empty selectedPositions array (grown by the page-level filter)
        body.comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class)
                .append(thisVariable.setField(selectedPositions, newArray(type(boolean[].class), 0)));

        initializeCommonSubExpressionFields(cseFields.values(), thisVariable, body);

        cachedInstanceBinder.generateInitializations(thisVariable, body);
        body.ret();
        return classDefinition;
    }

    /**
     * Declares the public page-level {@code filter(properties, page)} method returning {@link SelectedPositions}.
     *
     * <p>The emitted body reads the page's position count, replaces the array held in
     * {@code selectedPositionsField} with a new one of that size when the current array is shorter, stores the
     * result of the per-position {@code filter(properties, page, position)} call into each array slot, and
     * returns {@code PageFilter.positionsArrayToSelectedPositions(selectedPositions, positionCount)}.
     *
     * @param classDefinition class receiving the new method
     * @param selectedPositionsField {@code boolean[]} field reused across invocations
     * @return the declared method
     */
    private static MethodDefinition generatePageFilterMethod(ClassDefinition classDefinition, FieldDefinition selectedPositionsField)
    {
        Parameter properties = arg("properties", SqlFunctionProperties.class);
        Parameter page = arg("page", Page.class);

        MethodDefinition pageFilter = classDefinition.declareMethod(
                a(PUBLIC),
                "filter",
                type(SelectedPositions.class),
                ImmutableList.of(properties, page));

        Scope scope = pageFilter.getScope();
        Variable self = pageFilter.getThis();
        BytecodeBlock body = pageFilter.getBody();

        Variable positionCount = scope.declareVariable("positionCount", body, page.invoke("getPositionCount", int.class));

        IfStatement growIfNeeded = new IfStatement("grow selectedPositions if necessary");
        growIfNeeded.condition(lessThan(self.getField(selectedPositionsField).length(), positionCount));
        growIfNeeded.ifTrue(self.setField(selectedPositionsField, newArray(type(boolean[].class), positionCount)));
        body.append(growIfNeeded);

        // read the field only after the potential growth so the local refers to the current array
        Variable selectedPositions = scope.declareVariable("selectedPositions", body, self.getField(selectedPositionsField));
        Variable position = scope.declareVariable(int.class, "position");

        ForLoop perPosition = new ForLoop();
        perPosition.initialize(position.set(constantInt(0)));
        perPosition.condition(lessThan(position, positionCount));
        perPosition.update(position.increment());
        perPosition.body(selectedPositions.setElement(position, self.invoke("filter", boolean.class, properties, page, position)));
        body.append(perPosition);

        BytecodeNode returnSelected = invokeStatic(
                PageFilter.class,
                "positionsArrayToSelectedPositions",
                SelectedPositions.class,
                selectedPositions,
                positionCount)
                .ret();
        body.append(returnSelected);

        return pageFilter;
    }

    /**
     * Declares the public per-position {@code boolean filter(properties, page, position)} method.
     *
     * <p>The emitted body declares the {@code block_N} variables used by the filter, clears the "evaluated"
     * flag of every common sub-expression, evaluates the compiled filter into {@code result}, and returns
     * {@code !wasNull && result}.
     *
     * @param classDefinition class receiving the new method
     * @param compiler compiler used to turn the filter into bytecode
     * @param filter filter expression to compile
     * @param cseFields fields backing the common sub-expressions whose evaluated flags must be reset
     * @return the declared method
     */
    private MethodDefinition generateFilterMethod(
            ClassDefinition classDefinition,
            RowExpressionCompiler compiler,
            RowExpression filter,
            Map<VariableReferenceExpression, CommonSubExpressionFields> cseFields)
    {
        Parameter properties = arg("properties", SqlFunctionProperties.class);
        Parameter page = arg("page", Page.class);
        Parameter position = arg("position", int.class);

        MethodDefinition positionFilter = classDefinition.declareMethod(
                a(PUBLIC),
                "filter",
                type(boolean.class),
                ImmutableList.of(properties, page, position));

        positionFilter.comment("Filter: %s", filter.toString());

        Scope scope = positionFilter.getScope();
        BytecodeBlock body = positionFilter.getBody();
        Variable self = scope.getThis();

        declareBlockVariables(ImmutableList.of(filter), page, scope, body);

        // every invocation starts with all common sub-expressions marked as not yet evaluated
        for (CommonSubExpressionFields fields : cseFields.values()) {
            body.append(self.setField(fields.getEvaluatedField(), constantBoolean(false)));
        }

        Variable wasNull = scope.declareVariable("wasNull", body, constantFalse());
        Variable result = scope.declareVariable(boolean.class, "result");

        BytecodeNode compiledFilter = compiler.compile(filter, scope, Optional.empty());
        body.append(compiledFilter);
        // store result so we can check for null
        body.putVariable(result);
        body.append(and(not(wasNull), result).ret());
        return positionFilter;
    }

    /**
     * Declares one {@code block_<channel>} variable in {@code scope} for every input channel referenced by the
     * expressions, each initialized in {@code body} with {@code page.getBlock(channel)}.
     */
    private static void declareBlockVariables(List<RowExpression> expressions, Parameter page, Scope scope, BytecodeBlock body)
    {
        List<Integer> channels = getInputChannels(expressions);
        for (int channel : channels) {
            String blockName = "block_" + channel;
            scope.declareVariable(blockName, body, page.invoke("getBlock", Block.class, constantInt(channel)));
        }
    }

    /**
     * Collects the distinct field indexes of all {@link InputReferenceExpression}s found among the
     * sub-expressions of the given expressions, in ascending order.
     */
    private static List<Integer> getInputChannels(Iterable<RowExpression> expressions)
    {
        TreeSet<Integer> sortedChannels = new TreeSet<>();
        for (RowExpression candidate : subExpressions(expressions)) {
            if (!(candidate instanceof InputReferenceExpression)) {
                continue;
            }
            InputReferenceExpression input = (InputReferenceExpression) candidate;
            sortedChannels.add(input.getField());
        }
        return ImmutableList.copyOf(sortedChannels);
    }

    /**
     * Copies the boxed integers of {@code list} into a new primitive array, preserving order.
     */
    private static int[] toIntArray(List<Integer> list)
    {
        return list.stream()
                .mapToInt(Integer::intValue)
                .toArray();
    }

    /**
     * Visitor that compiles the two leaf reference kinds handled here: input references and variable references.
     *
     * <p>Input references are delegated to an {@link InputReferenceCompiler} that reads from the
     * {@code block_<field>} and {@code position} variables of the current scope. Variable references are
     * resolved through the supplied map of common sub-expression fields and compiled into a call to the
     * corresponding method on {@code this}. All other expression kinds are rejected with
     * {@link UnsupportedOperationException}.
     */
    private static class FieldAndVariableReferenceCompiler
            implements RowExpressionVisitor<BytecodeNode, Scope>
    {
        private final InputReferenceCompiler inputReferenceCompiler;
        private final Map<VariableReferenceExpression, CommonSubExpressionFields> variableMap;

        /**
         * @param callSiteBinder binder handed to the input reference compiler
         * @param variableMap common sub-expression fields keyed by the variable that stands for them; copied defensively
         */
        public FieldAndVariableReferenceCompiler(CallSiteBinder callSiteBinder, Map<VariableReferenceExpression, CommonSubExpressionFields> variableMap)
        {
            this.inputReferenceCompiler = new InputReferenceCompiler(
                    (scope, field) -> scope.getVariable("block_" + field),
                    (scope, field) -> scope.getVariable("position"),
                    callSiteBinder);
            this.variableMap = ImmutableMap.copyOf(variableMap);
        }

        @Override
        public BytecodeNode visitCall(CallExpression call, Scope context)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitInputReference(InputReferenceExpression reference, Scope context)
        {
            return inputReferenceCompiler.visitInputReference(reference, context);
        }

        @Override
        public BytecodeNode visitConstant(ConstantExpression literal, Scope context)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public BytecodeNode visitLambda(LambdaDefinitionExpression lambda, Scope context)
        {
            throw new UnsupportedOperationException();
        }

        /**
         * Emits a call to the method computing the referenced common sub-expression, passing the
         * {@code properties}, {@code page} and {@code position} variables of the current scope, followed by
         * unboxing of the result when necessary.
         *
         * @throws NullPointerException if no fields are registered for {@code reference}
         */
        @Override
        public BytecodeNode visitVariableReference(VariableReferenceExpression reference, Scope context)
        {
            // fail with a descriptive message instead of an anonymous NPE on the dereference below
            CommonSubExpressionFields fields = requireNonNull(
                    variableMap.get(reference),
                    () -> "No common sub-expression fields registered for variable " + reference);
            return new BytecodeBlock()
                    .append(context.getThis().invoke(fields.getMethodName(), fields.getResultType(), context.getVariable("properties"), context.getVariable("page"), context.getVariable("position")))
                    .append(unboxPrimitiveIfNecessary(context, Primitives.wrap(reference.getType().getJavaType())));
        }

        @Override
        public BytecodeNode visitSpecialForm(SpecialFormExpression specialForm, Scope context)
        {
            throw new UnsupportedOperationException();
        }
    }

    /**
     * Key identifying one compilation request: the function properties, the session functions, the
     * expressions to compile and whether common sub-expression optimization is enabled.
     *
     * <p>All four components take part in {@link #equals(Object)} and {@link #hashCode()}, so requests that
     * differ in any of them never share an entry.
     */
    private static final class CacheKey
    {
        private final SqlFunctionProperties sqlFunctionProperties;
        private final Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions;
        private final List<RowExpression> rowExpressions;
        private final boolean isOptimizeCommonSubExpression;

        /**
         * @param sqlFunctionProperties must not be null
         * @param sessionFunctions must not be null
         * @param rowExpressions must not be null and must contain at least one expression; copied defensively
         * @param isOptimizeCommonSubExpression whether common sub-expression optimization is requested
         */
        private CacheKey(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, List<RowExpression> rowExpressions, boolean isOptimizeCommonSubExpression)
        {
            requireNonNull(rowExpressions, "rowExpressions is null");
            checkArgument(rowExpressions.size() >= 1, "Expect at least one RowExpression");
            this.sqlFunctionProperties = requireNonNull(sqlFunctionProperties, "sqlFunctionProperties is null");
            this.sessionFunctions = requireNonNull(sessionFunctions, "sessionFunctions is null");
            this.rowExpressions = ImmutableList.copyOf(rowExpressions);
            this.isOptimizeCommonSubExpression = isOptimizeCommonSubExpression;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (!(o instanceof CacheKey)) {
                return false;
            }
            CacheKey that = (CacheKey) o;
            // sessionFunctions is part of the key: the same expressions compiled against different
            // session functions must not be treated as the same request
            return Objects.equals(sqlFunctionProperties, that.sqlFunctionProperties) &&
                    Objects.equals(sessionFunctions, that.sessionFunctions) &&
                    Objects.equals(rowExpressions, that.rowExpressions) &&
                    isOptimizeCommonSubExpression == that.isOptimizeCommonSubExpression;
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(sqlFunctionProperties, sessionFunctions, rowExpressions, isOptimizeCommonSubExpression);
        }
    }
}