CommonSubExpressionRewriter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.bytecode.FieldDefinition;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import com.google.common.primitives.Primitives;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.bytecode.Access.PRIVATE;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantBoolean;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantNull;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.BIND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.relational.Expressions.subExpressions;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.function.Function.identity;

public class CommonSubExpressionRewriter
{
    /**
     * Not instantiable: this class only exposes static helpers and nested types.
     */
    private CommonSubExpressionRewriter()
    {
    }

    /**
     * Finds the sub-expressions that occur more than once in {@code expressions} and assigns
     * each retained one a fresh variable allocated with the "cse" prefix.
     * <p>
     * Levels are processed in ascending order. Before a level is processed, its expressions are
     * rewritten with the variables assigned at all lower levels, so the keys of the returned inner
     * maps are the rewritten forms rather than the original expressions.
     *
     * @param expressions the expressions to analyze
     * @return for each level, a map from rewritten common sub-expression to its variable; empty when
     * {@code expressions} is empty or contains no common sub-expression
     */
    public static Map<Integer, Map<RowExpression, VariableReferenceExpression>> collectCSEByLevel(List<? extends RowExpression> expressions)
    {
        if (expressions.isEmpty()) {
            return ImmutableMap.of();
        }

        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
        for (RowExpression expression : expressions) {
            expression.accept(collector, null);
        }
        if (collector.cseByLevel.isEmpty()) {
            return ImmutableMap.of();
        }

        Map<Integer, Map<RowExpression, Integer>> retainedByLevel = removeRedundantCSE(collector.cseByLevel, collector.expressionCount);

        VariableAllocator allocator = new VariableAllocator();
        ImmutableMap.Builder<Integer, Map<RowExpression, VariableReferenceExpression>> result = ImmutableMap.builder();
        // Every (rewritten expression -> variable) pair produced by the levels handled so far
        Map<RowExpression, VariableReferenceExpression> knownVariables = new HashMap<>();
        List<Integer> ascendingLevels = retainedByLevel.keySet().stream().sorted().collect(toImmutableList());
        for (Integer level : ascendingLevels) {
            // The rewriter snapshots the variables of the lower levels only
            ExpressionRewriter rewriter = new ExpressionRewriter(knownVariables);
            ImmutableMap.Builder<RowExpression, VariableReferenceExpression> levelVariables = ImmutableMap.builder();
            for (RowExpression candidate : retainedByLevel.get(level).keySet()) {
                RowExpression rewritten = candidate.accept(rewriter, null);
                levelVariables.put(rewritten, allocator.newVariable(rewritten, "cse"));
            }
            Map<RowExpression, VariableReferenceExpression> variablesAtLevel = levelVariables.build();
            result.put(level, variablesAtLevel);
            knownVariables.putAll(variablesAtLevel);
        }
        return result.build();
    }

    /**
     * Single-expression convenience overload of {@link #collectCSEByLevel(List)}.
     *
     * @param expression the only expression to analyze
     * @return the result of {@link #collectCSEByLevel(List)} for a one-element list
     */
    public static Map<Integer, Map<RowExpression, VariableReferenceExpression>> collectCSEByLevel(RowExpression expression)
    {
        List<RowExpression> singleExpression = ImmutableList.of(expression);
        return collectCSEByLevel(singleExpression);
    }

    /**
     * Splits {@code expressions} into groups according to the common sub-expressions they contain.
     * <p>
     * An expression containing no common sub-expression becomes a one-element group mapped to
     * {@code false}. Expressions that do contain one are merged greedily: a group is seeded with the
     * first unassigned expression and absorbs every later unassigned expression sharing at least one
     * common sub-expression with the group, until no candidate is left or the group holds
     * {@code expressionGroupSize} expressions. Those groups are mapped to {@code true}.
     *
     * @param expressions the expressions to partition
     * @param expressionGroupSize upper bound on how far a group may grow by merging; a seed expression
     * always forms a group of at least one element
     * @return the groups, each flagged with whether it involves common sub-expressions
     */
    public static Map<List<RowExpression>, Boolean> getExpressionsPartitionedByCSE(Collection<? extends RowExpression> expressions, int expressionGroupSize)
    {
        if (expressions.isEmpty()) {
            return ImmutableMap.of();
        }

        CommonSubExpressionCollector collector = new CommonSubExpressionCollector();
        for (RowExpression expression : expressions) {
            expression.accept(collector, null);
        }
        Set<RowExpression> commonSubExpressions = collector.cseByLevel.values().stream()
                .flatMap(Set::stream)
                .collect(toImmutableSet());

        if (commonSubExpressions.isEmpty()) {
            return expressions.stream().collect(toImmutableMap(ImmutableList::of, m -> false));
        }

        // Separate the expressions touching a common sub-expression from the others, keeping input order
        SubExpressionChecker checker = new SubExpressionChecker(commonSubExpressions);
        List<RowExpression> withoutCse = new ArrayList<>();
        List<RowExpression> withCse = new ArrayList<>();
        for (RowExpression expression : expressions) {
            if (expression.accept(checker, null)) {
                withCse.add(expression);
            }
            else {
                withoutCse.add(expression);
            }
        }

        ImmutableMap.Builder<List<RowExpression>, Boolean> partitions = ImmutableMap.builder();
        for (RowExpression expression : withoutCse) {
            partitions.put(ImmutableList.of(expression), false);
        }

        if (withCse.size() == 1) {
            partitions.put(ImmutableList.of(withCse.get(0)), true);
            return partitions.build();
        }

        // For each expression (by index), the common sub-expressions found among its sub-expressions
        List<Set<RowExpression>> dependenciesByIndex = withCse.stream()
                .map(expression -> subExpressions(expression).stream()
                        .filter(commonSubExpressions::contains)
                        .collect(toImmutableSet()))
                .collect(toImmutableList());

        boolean[] assigned = new boolean[withCse.size()];
        for (int seed = 0; seed < assigned.length; seed++) {
            if (assigned[seed]) {
                continue;
            }
            assigned[seed] = true;
            List<RowExpression> group = new ArrayList<>();
            group.add(withCse.get(seed));
            Set<RowExpression> groupDependencies = new HashSet<>(dependenciesByIndex.get(seed));

            int candidate = seed + 1;
            while (candidate < assigned.length && group.size() < expressionGroupSize) {
                if (assigned[candidate]) {
                    candidate++;
                    continue;
                }
                Set<RowExpression> candidateDependencies = dependenciesByIndex.get(candidate);
                if (Sets.intersection(groupDependencies, candidateDependencies).isEmpty()) {
                    candidate++;
                    continue;
                }
                group.add(withCse.get(candidate));
                groupDependencies.addAll(candidateDependencies);
                assigned[candidate] = true;
                // The group's dependencies grew, so candidates skipped earlier may now overlap: rescan
                candidate = seed + 1;
            }
            partitions.put(ImmutableList.copyOf(group), true);
        }

        return partitions.build();
    }

    /**
     * Rewrites {@code expression}, substituting the mapped variable for every rebuilt
     * sub-expression that is a key of {@code rewriteWith}.
     *
     * @param expression the expression to rewrite
     * @param rewriteWith mapping from sub-expression to the variable replacing it
     * @return the rewritten expression
     */
    public static RowExpression rewriteExpressionWithCSE(RowExpression expression, Map<RowExpression, VariableReferenceExpression> rewriteWith)
    {
        return expression.accept(new ExpressionRewriter(rewriteWith), null);
    }

    /**
     * Walks the collected common sub-expressions from the highest level down and keeps, per level,
     * only those whose remaining occurrence count is still positive.
     * <p>
     * Whenever an expression is kept, its occurrence count is subtracted from the count of each of
     * its tracked sub-expressions, so a sub-expression that only appears inside kept expressions
     * drops out at its own level. Note that {@code expressionCount} is modified in place.
     *
     * @param cseByLevel candidate common sub-expressions grouped by level; must not be empty
     * @param expressionCount occurrence count per tracked expression; updated by this method
     * @return for each level that still has candidates, the kept expressions and their counts
     */
    private static Map<Integer, Map<RowExpression, Integer>> removeRedundantCSE(Map<Integer, Set<RowExpression>> cseByLevel, Map<RowExpression, Integer> expressionCount)
    {
        int highestLevel = cseByLevel.keySet().stream().reduce(Math::max).get();
        int lowestLevel = cseByLevel.keySet().stream().reduce(Math::min).get();

        Map<Integer, Map<RowExpression, Integer>> retained = new HashMap<>();
        for (int level = highestLevel; level > lowestLevel; level--) {
            if (!cseByLevel.containsKey(level)) {
                continue;
            }
            Map<RowExpression, Integer> survivors = cseByLevel.get(level).stream()
                    .filter(expression -> expressionCount.get(expression) > 0)
                    .collect(toImmutableMap(identity(), expressionCount::get));
            if (!survivors.isEmpty()) {
                retained.put(level, survivors);
            }
            for (RowExpression survivor : survivors.keySet()) {
                int occurrences = expressionCount.get(survivor);
                for (RowExpression nested : subExpressions(survivor)) {
                    // The expression itself is skipped, and untracked sub-expressions are left alone
                    if (!nested.equals(survivor)) {
                        expressionCount.computeIfPresent(nested, (key, count) -> count - occurrences);
                    }
                }
            }
        }

        // The lowest level has nothing below it to discount, so it is handled outside the loop.
        // NOTE(review): the stored count is incremented by one here, unlike at the higher levels;
        // the reason is not evident from this file -- confirm against the consumers of this map.
        Map<RowExpression, Integer> lowestSurvivors = cseByLevel.get(lowestLevel).stream()
                .filter(expression -> expressionCount.get(expression) > 0)
                .collect(toImmutableMap(identity(), expression -> expressionCount.get(expression) + 1));
        if (!lowestSurvivors.isEmpty()) {
            retained.put(lowestLevel, lowestSurvivors);
        }
        return retained;
    }

    /**
     * Visitor answering whether an expression is, or contains, one of a given set of expressions.
     * Calls and special forms are searched recursively through their arguments; lambdas always
     * answer {@code false} and are not searched.
     */
    static class SubExpressionChecker
            implements RowExpressionVisitor<Boolean, Void>
    {
        private final Set<RowExpression> subExpressions;

        SubExpressionChecker(Set<RowExpression> subExpressions)
        {
            this.subExpressions = subExpressions;
        }

        @Override
        public Boolean visitCall(CallExpression call, Void context)
        {
            return subExpressions.contains(call) || anyArgumentMatches(call.getArguments());
        }

        @Override
        public Boolean visitInputReference(InputReferenceExpression reference, Void context)
        {
            return subExpressions.contains(reference);
        }

        @Override
        public Boolean visitConstant(ConstantExpression literal, Void context)
        {
            return subExpressions.contains(literal);
        }

        @Override
        public Boolean visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            return false;
        }

        @Override
        public Boolean visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return subExpressions.contains(reference);
        }

        @Override
        public Boolean visitSpecialForm(SpecialFormExpression specialForm, Void context)
        {
            return subExpressions.contains(specialForm) || anyArgumentMatches(specialForm.getArguments());
        }

        /**
         * Returns true as soon as one argument is, or contains, a tracked expression;
         * an empty argument list yields false.
         */
        private boolean anyArgumentMatches(List<? extends RowExpression> arguments)
        {
            for (RowExpression argument : arguments) {
                if (argument.accept(this, null)) {
                    return true;
                }
            }
            return false;
        }
    }

    /**
     * Visitor that rebuilds an expression bottom-up and substitutes a variable for every rebuilt
     * call or special form that is a key of the supplied map.
     * <p>
     * Input references, constants, variable references and lambdas are returned as is; lambda
     * bodies are not visited. The map is copied at construction, so later changes to the caller's
     * map do not affect an existing rewriter.
     */
    static class ExpressionRewriter
            implements RowExpressionVisitor<RowExpression, Void>
    {
        private final Map<RowExpression, VariableReferenceExpression> expressionMap;

        public ExpressionRewriter(Map<RowExpression, VariableReferenceExpression> expressionMap)
        {
            this.expressionMap = ImmutableMap.copyOf(expressionMap);
        }

        @Override
        public RowExpression visitCall(CallExpression call, Void context)
        {
            RowExpression rewritten = new CallExpression(
                    call.getSourceLocation(),
                    call.getDisplayName(),
                    call.getFunctionHandle(),
                    call.getType(),
                    rewriteArguments(call.getArguments()));
            return replaceIfMapped(rewritten);
        }

        @Override
        public RowExpression visitInputReference(InputReferenceExpression reference, Void context)
        {
            return reference;
        }

        @Override
        public RowExpression visitConstant(ConstantExpression literal, Void context)
        {
            return literal;
        }

        @Override
        public RowExpression visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            return lambda;
        }

        @Override
        public RowExpression visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return reference;
        }

        @Override
        public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Void context)
        {
            // Carry the source location over to the rebuilt node, exactly as visitCall does,
            // so the rewritten expression can still be traced back to its origin.
            RowExpression rewritten = new SpecialFormExpression(
                    specialForm.getSourceLocation(),
                    specialForm.getForm(),
                    specialForm.getType(),
                    rewriteArguments(specialForm.getArguments()));
            return replaceIfMapped(rewritten);
        }

        /**
         * Rewrites each argument with this visitor, preserving argument order.
         */
        private List<RowExpression> rewriteArguments(List<? extends RowExpression> arguments)
        {
            return arguments.stream()
                    .map(argument -> argument.accept(this, null))
                    .collect(toImmutableList());
        }

        /**
         * Returns the variable mapped to {@code rewritten}, or {@code rewritten} itself when unmapped.
         */
        private RowExpression replaceIfMapped(RowExpression rewritten)
        {
            // A single lookup is enough: the map is an ImmutableMap and therefore holds no null values
            VariableReferenceExpression variable = expressionMap.get(rewritten);
            return variable != null ? variable : rewritten;
        }
    }

    /**
     * Visitor that assigns every non-leaf expression a level (one more than the highest level among
     * its arguments) and records the expressions encountered more than once.
     * <p>
     * Input references, constants, variable references, lambdas and argument-less calls or special
     * forms are leaves: they report level 0 and are not tracked. After visiting, {@code cseByLevel}
     * holds the repeated expressions grouped by level and {@code expressionCount} holds the number
     * of times each tracked expression was encountered.
     */
    static class CommonSubExpressionCollector
            implements RowExpressionVisitor<Integer, Void>
    {
        private final Map<Integer, Set<RowExpression>> expressionsByLevel = new HashMap<>();
        private final Map<Integer, Set<RowExpression>> cseByLevel = new HashMap<>();
        private final Map<RowExpression, Integer> expressionCount = new HashMap<>();

        /**
         * Registers {@code expression} at {@code level}; a repeated registration marks it as a
         * common sub-expression and increments its occurrence count.
         *
         * @return {@code level}, for the caller's convenience
         */
        private int addAtLevel(int level, RowExpression expression)
        {
            expressionCount.putIfAbsent(expression, 1);
            // Set.add returns false when the expression was already registered at this level
            if (!getExpressionsAtLevel(level, expressionsByLevel).add(expression)) {
                getExpressionsAtLevel(level, cseByLevel).add(expression);
                expressionCount.merge(expression, 1, Integer::sum);
            }
            return level;
        }

        /**
         * Returns the set stored for {@code level}, creating and storing an empty one if absent.
         */
        private static Set<RowExpression> getExpressionsAtLevel(int level, Map<Integer, Set<RowExpression>> expressionsByLevel)
        {
            return expressionsByLevel.computeIfAbsent(level, key -> new HashSet<>());
        }

        /**
         * Visits every argument in order and returns the highest level reported, or 0 for no arguments.
         */
        private int maxArgumentLevel(List<? extends RowExpression> arguments)
        {
            int highest = 0;
            for (RowExpression argument : arguments) {
                highest = Math.max(highest, argument.accept(this, null));
            }
            return highest;
        }

        @Override
        public Integer visitCall(CallExpression call, Void collect)
        {
            if (call.getArguments().isEmpty()) {
                // Do not track leaf expression
                return 0;
            }
            return addAtLevel(maxArgumentLevel(call.getArguments()) + 1, call);
        }

        @Override
        public Integer visitInputReference(InputReferenceExpression reference, Void collect)
        {
            return 0;
        }

        @Override
        public Integer visitConstant(ConstantExpression literal, Void collect)
        {
            return 0;
        }

        @Override
        public Integer visitLambda(LambdaDefinitionExpression lambda, Void collect)
        {
            return 0;
        }

        @Override
        public Integer visitVariableReference(VariableReferenceExpression reference, Void collect)
        {
            return 0;
        }

        @Override
        public Integer visitSpecialForm(SpecialFormExpression specialForm, Void collect)
        {
            if (specialForm.getArguments().isEmpty()) {
                // Same treatment as an argument-less call: a leaf is not tracked. Without this guard
                // an empty argument list would fail with NoSuchElementException.
                return 0;
            }
            int level = maxArgumentLevel(specialForm.getArguments()) + 1;
            if (specialForm.getForm() != WHEN && specialForm.getForm() != BIND) {
                // BIND returns a function type rather than a value type
                // WHEN is part of CASE expression. We do not have a separate code generator to generate code for WHEN expression separately so do not consider them as CSE
                // TODO If we detect a whole WHEN statement as CSE we should probably only keep one
                addAtLevel(level, specialForm);
            }
            return level;
        }
    }

    /**
     * Describes the generated-class members backing one common sub-expression: a boolean
     * "evaluated" field, a "result" field, the Java type of the result, and a method name.
     */
    static class CommonSubExpressionFields
    {
        private final FieldDefinition evaluatedField;
        private final FieldDefinition resultField;
        private final Class<?> resultType;
        private final String methodName;

        public CommonSubExpressionFields(FieldDefinition evaluatedField, FieldDefinition resultField, Class<?> resultType, String methodName)
        {
            this.evaluatedField = evaluatedField;
            this.resultField = resultField;
            this.resultType = resultType;
            this.methodName = methodName;
        }

        public FieldDefinition getEvaluatedField()
        {
            return evaluatedField;
        }

        public FieldDefinition getResultField()
        {
            return resultField;
        }

        public String getMethodName()
        {
            return methodName;
        }

        public Class<?> getResultType()
        {
            return resultType;
        }

        /**
         * Declares, for every variable in {@code commonSubExpressionsByLevel}, two private fields on
         * {@code classDefinition}: {@code <name>Evaluated} of type boolean and {@code <name>Result}
         * of the wrapper (boxed) form of the variable's Java type. The associated method name is
         * {@code get<name>}.
         *
         * @return the declared fields keyed by variable
         */
        public static Map<VariableReferenceExpression, CommonSubExpressionFields> declareCommonSubExpressionFields(ClassDefinition classDefinition, Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel)
        {
            ImmutableMap.Builder<VariableReferenceExpression, CommonSubExpressionFields> fieldsByVariable = ImmutableMap.builder();
            for (Map<RowExpression, VariableReferenceExpression> variablesAtLevel : commonSubExpressionsByLevel.values()) {
                for (VariableReferenceExpression variable : variablesAtLevel.values()) {
                    Class<?> boxedType = Primitives.wrap(variable.getType().getJavaType());
                    FieldDefinition evaluated = classDefinition.declareField(a(PRIVATE), variable.getName() + "Evaluated", boolean.class);
                    FieldDefinition result = classDefinition.declareField(a(PRIVATE), variable.getName() + "Result", boxedType);
                    fieldsByVariable.put(variable, new CommonSubExpressionFields(evaluated, result, boxedType, "get" + variable.getName()));
                }
            }
            return fieldsByVariable.build();
        }

        /**
         * Appends to {@code body}, for each entry of {@code cseFields}, the bytecode that sets the
         * evaluated field to {@code false} and the result field to {@code null} on {@code thisVariable}.
         */
        public static void initializeCommonSubExpressionFields(Collection<CommonSubExpressionFields> cseFields, Variable thisVariable, BytecodeBlock body)
        {
            for (CommonSubExpressionFields fields : cseFields) {
                body.append(thisVariable.setField(fields.getEvaluatedField(), constantBoolean(false)));
                body.append(thisVariable.setField(fields.getResultField(), constantNull(fields.getResultType())));
            }
        }
    }
}