RowExpressionPredicateCompiler.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.CallSiteBinder;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.MethodDefinition;
import com.facebook.presto.bytecode.Parameter;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.relation.Predicate;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.PredicateCompiler;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.sql.relational.Expressions;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;

import javax.inject.Inject;

import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.TreeSet;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.Parameter.arg;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.and;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantInt;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.not;
import static com.facebook.presto.operator.project.PageFieldsToInputParametersRewriter.rewritePageFieldsToInputParameters;
import static com.facebook.presto.spi.StandardErrorCode.COMPILER_ERROR;
import static com.facebook.presto.sql.gen.BytecodeUtils.invoke;
import static com.facebook.presto.sql.gen.LambdaBytecodeGenerator.generateMethodsForLambda;
import static com.facebook.presto.util.CompilerUtils.defineClass;
import static com.facebook.presto.util.CompilerUtils.makeClassName;
import static java.util.Objects.requireNonNull;

/**
 * Compiles a {@link RowExpression} predicate into a generated class implementing {@link Predicate}.
 * <p>
 * When constructed with a positive cache size, compiled results are kept in a size-bounded cache keyed by
 * the function properties, the session functions and the expression, so that an identical request reuses
 * the previously generated class.
 */
public class RowExpressionPredicateCompiler
        implements PredicateCompiler
{
    private final Metadata metadata;
    // null when caching is disabled (non-positive cache size)
    private final LoadingCache<CacheKey, Supplier<Predicate>> predicateCache;

    /**
     * Creates a compiler with a predicate cache holding at most 10,000 entries.
     *
     * @param metadata metadata handed to the expression and lambda bytecode generators; must not be null
     */
    @Inject
    public RowExpressionPredicateCompiler(Metadata metadata)
    {
        this(requireNonNull(metadata, "metadata is null"), 10_000);
    }

    /**
     * Creates a compiler with an explicit cache size.
     *
     * @param metadata metadata handed to the expression and lambda bytecode generators; must not be null
     * @param predicateCacheSize maximum number of cached predicates; a value of zero or less disables caching
     */
    public RowExpressionPredicateCompiler(Metadata metadata, int predicateCacheSize)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");

        if (predicateCacheSize > 0) {
            predicateCache = CacheBuilder.newBuilder()
                    .recordStats()
                    .maximumSize(predicateCacheSize)
                    .build(CacheLoader.from(cacheKey -> compilePredicateInternal(cacheKey.sqlFunctionProperties, cacheKey.sessionFunctions, cacheKey.rowExpression)));
        }
        else {
            predicateCache = null;
        }
    }

    /**
     * Returns a supplier of compiled predicates for the given expression, going through the cache when
     * caching is enabled.
     *
     * @param sqlFunctionProperties function properties used during compilation
     * @param sessionFunctions session-scoped functions used during compilation
     * @param predicate the expression to compile
     * @return a supplier that instantiates the generated predicate class on every call
     */
    @Override
    public Supplier<Predicate> compilePredicate(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, RowExpression predicate)
    {
        if (predicateCache == null) {
            return compilePredicateInternal(sqlFunctionProperties, sessionFunctions, predicate);
        }
        return predicateCache.getUnchecked(new CacheKey(sqlFunctionProperties, sessionFunctions, predicate));
    }

    /**
     * Generates and loads the predicate class for the expression, bypassing the cache.
     *
     * @throws PrestoException with {@code COMPILER_ERROR} if the class cannot be defined, or (from the
     * returned supplier) if the generated class cannot be instantiated
     */
    private Supplier<Predicate> compilePredicateInternal(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, RowExpression predicate)
    {
        requireNonNull(predicate, "predicate is null");

        // The generated code works on the rewritten expression; the original page channels it reads
        // are exposed through getInputChannels() on the generated class.
        PageFieldsToInputParametersRewriter.Result result = rewritePageFieldsToInputParameters(predicate);
        int[] inputChannels = result.getInputChannels().getInputChannels().stream()
                .mapToInt(Integer::intValue)
                .toArray();

        CallSiteBinder callSiteBinder = new CallSiteBinder();
        ClassDefinition classDefinition = definePredicateClass(sqlFunctionProperties, sessionFunctions, result.getRewrittenExpression(), inputChannels, callSiteBinder);

        Class<? extends Predicate> predicateClass;
        try {
            predicateClass = defineClass(classDefinition, Predicate.class, callSiteBinder.getBindings(), getClass().getClassLoader());
        }
        catch (Exception e) {
            // Chain the caught exception itself: using only e.getCause() loses the failure entirely
            // when the exception has no cause, and drops its message and stack frames otherwise.
            throw new PrestoException(COMPILER_ERROR, predicate.toString(), e);
        }

        return () -> {
            try {
                return predicateClass.getConstructor().newInstance();
            }
            catch (ReflectiveOperationException e) {
                throw new PrestoException(COMPILER_ERROR, e);
            }
        };
    }

    /**
     * Builds the definition of a public final class implementing {@link Predicate}, consisting of the
     * {@code evaluate} method, a {@code getInputChannels} method returning the bound channel array,
     * and a public no-argument constructor.
     */
    private ClassDefinition definePredicateClass(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            RowExpression predicate,
            int[] inputChannels,
            CallSiteBinder callSiteBinder)
    {
        ClassDefinition classDefinition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName(Predicate.class.getSimpleName(), Optional.empty()),
                type(Object.class),
                type(Predicate.class));

        CachedInstanceBinder cachedInstanceBinder = new CachedInstanceBinder(classDefinition, callSiteBinder);

        generatePredicateMethod(sqlFunctionProperties, sessionFunctions, classDefinition, callSiteBinder, cachedInstanceBinder, predicate);

        // getInputChannels
        classDefinition.declareMethod(a(PUBLIC), "getInputChannels", type(int[].class))
                .getBody()
                .append(invoke(callSiteBinder.bind(inputChannels, int[].class), "getInputChannels"))
                .retObject();

        // constructor: generated after the predicate method, and fed the same cachedInstanceBinder
        generateConstructor(classDefinition, cachedInstanceBinder);

        return classDefinition;
    }

    /**
     * Declares the {@code evaluate(properties, page, position)} method on the class definition.
     * The generated method returns {@code true} only when the compiled expression yields {@code true}
     * and the {@code wasNull} flag is not set.
     *
     * @return the declared method definition
     */
    private MethodDefinition generatePredicateMethod(
            SqlFunctionProperties sqlFunctionProperties,
            Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions,
            ClassDefinition classDefinition,
            CallSiteBinder callSiteBinder,
            CachedInstanceBinder cachedInstanceBinder,
            RowExpression predicate)
    {
        // The same counter is passed to the lambda generator and to the expression compiler below.
        AtomicInteger lambdaCounter = new AtomicInteger(0);
        Map<LambdaDefinitionExpression, LambdaBytecodeGenerator.CompiledLambda> compiledLambdaMap = generateMethodsForLambda(
                classDefinition,
                callSiteBinder,
                cachedInstanceBinder,
                predicate,
                metadata,
                sqlFunctionProperties,
                sessionFunctions,
                lambdaCounter);

        Parameter properties = arg("properties", SqlFunctionProperties.class);
        Parameter page = arg("page", Page.class);
        Parameter position = arg("position", int.class);

        MethodDefinition method = classDefinition.declareMethod(
                a(PUBLIC),
                "evaluate",
                type(boolean.class),
                ImmutableList.<Parameter>builder()
                        .add(properties)
                        .add(page)
                        .add(position)
                        .build());

        method.comment("Predicate: %s", predicate.toString());

        Scope scope = method.getScope();
        BytecodeBlock body = method.getBody();

        // One "block_<channel>" local per referenced channel; fieldReferenceCompiler looks them up by name.
        declareBlockVariables(predicate, page, scope, body);

        Variable wasNullVariable = scope.declareVariable("wasNull", body, constantFalse());
        RowExpressionCompiler compiler = new RowExpressionCompiler(
                classDefinition,
                callSiteBinder,
                cachedInstanceBinder,
                fieldReferenceCompiler(callSiteBinder),
                metadata,
                sqlFunctionProperties,
                sessionFunctions,
                compiledLambdaMap,
                lambdaCounter);

        Variable result = scope.declareVariable(boolean.class, "result");
        body.append(compiler.compile(predicate, scope, Optional.empty()))
                // store result so we can check for null
                .putVariable(result)
                .append(and(not(wasNullVariable), result).ret());
        return method;
    }

    /**
     * Declares a local variable {@code block_<channel>} for every input channel referenced by the
     * expression, initialized with {@code page.getBlock(channel)}.
     */
    private static void declareBlockVariables(RowExpression expression, Parameter page, Scope scope, BytecodeBlock body)
    {
        for (int channel : getInputChannels(expression)) {
            scope.declareVariable("block_" + channel, body, page.invoke("getBlock", Block.class, constantInt(channel)));
        }
    }

    /**
     * Collects the field indexes of all {@link InputReferenceExpression}s among the sub-expressions
     * of the given expressions.
     *
     * @return the distinct channel indexes in ascending order
     */
    private static List<Integer> getInputChannels(Iterable<RowExpression> expressions)
    {
        // TreeSet removes duplicates and yields ascending iteration order
        TreeSet<Integer> channels = new TreeSet<>();
        for (RowExpression expression : Expressions.subExpressions(expressions)) {
            if (expression instanceof InputReferenceExpression) {
                channels.add(((InputReferenceExpression) expression).getField());
            }
        }
        return ImmutableList.copyOf(channels);
    }

    /**
     * Single-expression convenience overload of {@link #getInputChannels(Iterable)}.
     */
    private static List<Integer> getInputChannels(RowExpression expression)
    {
        return getInputChannels(ImmutableList.of(expression));
    }

    /**
     * Creates the visitor used to compile input references: the block for a field is read from the
     * {@code block_<field>} variable and the position from the {@code position} variable of the scope.
     */
    private static RowExpressionVisitor<BytecodeNode, Scope> fieldReferenceCompiler(CallSiteBinder callSiteBinder)
    {
        return new InputReferenceCompiler(
                (scope, field) -> scope.getVariable("block_" + field),
                (scope, field) -> scope.getVariable("position"),
                callSiteBinder);
    }

    /**
     * Declares a public no-argument constructor that invokes {@code Object}'s constructor and then
     * appends the initializations produced by the cached instance binder.
     */
    private static void generateConstructor(
            ClassDefinition classDefinition,
            CachedInstanceBinder cachedInstanceBinder)
    {
        MethodDefinition constructorDefinition = classDefinition.declareConstructor(a(PUBLIC));

        BytecodeBlock body = constructorDefinition.getBody();
        Variable thisVariable = constructorDefinition.getThis();

        body.comment("super();")
                .append(thisVariable)
                .invokeConstructor(Object.class);

        cachedInstanceBinder.generateInitializations(thisVariable, body);
        body.ret();
    }

    /**
     * Cache key made of every input that is passed to compilation: the function properties, the
     * session functions and the expression.
     */
    private static final class CacheKey
    {
        private final SqlFunctionProperties sqlFunctionProperties;
        private final Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions;
        private final RowExpression rowExpression;

        private CacheKey(SqlFunctionProperties sqlFunctionProperties, Map<SqlFunctionId, SqlInvokedFunction> sessionFunctions, RowExpression rowExpression)
        {
            this.sqlFunctionProperties = requireNonNull(sqlFunctionProperties, "sqlFunctionProperties is null");
            this.sessionFunctions = requireNonNull(sessionFunctions, "sessionFunctions is null");
            this.rowExpression = requireNonNull(rowExpression, "rowExpression is null");
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (!(o instanceof CacheKey)) {
                return false;
            }
            CacheKey that = (CacheKey) o;
            // sessionFunctions is an input to compilation, so it must take part in equality;
            // otherwise a predicate compiled for one set of session functions could be served
            // to a request carrying a different set.
            return Objects.equals(sqlFunctionProperties, that.sqlFunctionProperties) &&
                    Objects.equals(sessionFunctions, that.sessionFunctions) &&
                    Objects.equals(rowExpression, that.rowExpression);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(sqlFunctionProperties, sessionFunctions, rowExpression);
        }
    }
}