TransformCorrelatedInPredicateToJoin.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InSubqueryExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.VariablesExtractor;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.AssignmentUtils;
import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.graph.SuccessorsFunction;
import com.google.common.graph.Traverser;

import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.FALSE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.or;
import static com.facebook.presto.matching.Pattern.nonEmpty;
import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.Patterns.Apply.correlation;
import static com.facebook.presto.sql.planner.plan.Patterns.applyNode;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.searchedCaseExpression;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.Streams.stream;
import static java.util.Objects.requireNonNull;

/**
 * Decorrelates an {@link ApplyNode} whose only subquery assignment is an IN predicate
 * ({@code a IN (subquery producing b)}) by rewriting it into a LEFT join followed by an
 * aggregation, instead of leaving a correlated apply in the plan.
 * <p>
 * Transforms:
 * <pre>
 * - Apply (output: a in B.b)
 *    - input: some plan A producing symbol a
 *    - subquery: some plan B producing symbol b, using symbols from A
 * </pre>
 * Into:
 * <pre>
 * - Project (output: CASE WHEN (countmatches > 0) THEN true WHEN (countnullmatches > 0) THEN null ELSE false END)
 *   - Aggregate (countmatches = count(*) where a and b are not null;
 *                countnullmatches = count(*) where a or b is null but buildSideKnownNonNull is not null)
 *     grouping by (A'.*)
 *     - LeftJoin on ((a IS NULL OR a = b OR b IS NULL) AND correlation condition between A and B)
 *       - AssignUniqueId (A')
 *         - A
 *       - Project (B'.*, buildSideKnownNonNull := 0)
 *         - B' (B with its correlated filters removed and moved into the join condition)
 * </pre>
 * The unique id added to the probe side puts every input row of A into its own aggregation group,
 * so the IN predicate is evaluated once per original row. {@code buildSideKnownNonNull} is a
 * constant on the build side, so after the LEFT join it is null only on rows padded because no
 * build row matched; this separates "nothing matched" from "matched, but a null was involved".
 * <p>
 * The rule only fires when correlation symbols are referenced solely by filter predicates that are
 * reachable from the subquery root through filters and uncorrelated projections; otherwise it
 * returns no result.
 *
 * @see TransformCorrelatedScalarAggregationToJoin
 */
public class TransformCorrelatedInPredicateToJoin
        implements Rule<ApplyNode>
{
    // Only applies that actually carry correlation symbols are candidates for this rule.
    private static final Pattern<ApplyNode> PATTERN = applyNode()
            .with(nonEmpty(correlation()));

    private final FunctionResolution functionResolution;

    public TransformCorrelatedInPredicateToJoin(FunctionAndTypeManager functionAndTypeManager)
    {
        requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    @Override
    public Pattern<ApplyNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Fires only when the apply has exactly one subquery assignment and that assignment is an
     * {@link InSubqueryExpression}; every other shape is left untouched.
     */
    @Override
    public Result apply(ApplyNode apply, Captures captures, Context context)
    {
        Assignments subqueryAssignments = apply.getSubqueryAssignments();
        if (subqueryAssignments.size() != 1) {
            return Result.empty();
        }
        RowExpression assignmentExpression = getOnlyElement(subqueryAssignments.getExpressions());
        if (!(assignmentExpression instanceof InSubqueryExpression)) {
            return Result.empty();
        }
        InSubqueryExpression inPredicate = (InSubqueryExpression) assignmentExpression;
        VariableReferenceExpression inPredicateOutputVariable = getOnlyElement(subqueryAssignments.getVariables());

        return apply(apply, inPredicate, inPredicateOutputVariable, context.getLookup(), context.getIdAllocator(), context.getVariableAllocator());
    }

    /**
     * Decorrelates the subquery of {@code apply} and, on success, replaces the apply with the
     * join/aggregation equivalent of the IN predicate.
     *
     * @return the rewritten plan, or {@link Result#empty()} when the subquery cannot be decorrelated
     */
    private Result apply(
            ApplyNode apply,
            InSubqueryExpression inPredicate,
            VariableReferenceExpression inPredicateOutputVariable,
            Lookup lookup,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator)
    {
        Optional<Decorrelated> decorrelated = new DecorrelatingVisitor(lookup, apply.getCorrelation(), TypeProvider.viewOf(variableAllocator.getVariables()))
                .decorrelate(apply.getSubquery());

        if (!decorrelated.isPresent()) {
            return Result.empty();
        }

        PlanNode projection = buildInPredicateEquivalent(
                apply,
                inPredicate,
                inPredicateOutputVariable,
                decorrelated.get(),
                idAllocator,
                variableAllocator);

        return Result.ofPlanNode(projection);
    }

    /**
     * Builds the Project / Aggregate / LeftJoin plan described in the class documentation.
     *
     * @param apply the apply node being replaced; its input becomes the probe side of the join
     * @param inPredicate the IN predicate; its value is the probe-side symbol and its subquery symbol is the build-side symbol
     * @param inPredicateOutputVariable the variable that receives the IN predicate's result
     * @param decorrelated the decorrelated subquery plus the correlated predicates pulled out of it
     * @param idAllocator allocator for the ids of the new plan nodes
     * @param variableAllocator allocator for the helper variables introduced by the rewrite
     * @return a projection producing the apply input's variables plus {@code inPredicateOutputVariable}
     */
    private PlanNode buildInPredicateEquivalent(
            ApplyNode apply,
            InSubqueryExpression inPredicate,
            VariableReferenceExpression inPredicateOutputVariable,
            Decorrelated decorrelated,
            PlanNodeIdAllocator idAllocator,
            VariableAllocator variableAllocator)
    {
        RowExpression correlationCondition = and(decorrelated.getCorrelatedPredicates());
        PlanNode decorrelatedBuildSource = decorrelated.getDecorrelatedNode();

        // The unique id is part of the grouping key below, so every probe row forms its own group.
        AssignUniqueId probeSide = new AssignUniqueId(
                apply.getSourceLocation(),
                idAllocator.getNextId(),
                apply.getInput(),
                variableAllocator.newVariable("unique", BIGINT));

        // Constant marker column: after the LEFT join it is null only on padding rows (no build row matched).
        VariableReferenceExpression buildSideKnownNonNull = variableAllocator.newVariable(inPredicateOutputVariable.getSourceLocation(), "buildSideKnownNonNull", BIGINT);
        ProjectNode buildSide = new ProjectNode(
                idAllocator.getNextId(),
                decorrelatedBuildSource,
                Assignments.builder()
                        .putAll(identityAssignments(decorrelatedBuildSource.getOutputVariables()))
                        .put(buildSideKnownNonNull, constant(0L, BIGINT))
                        .build());

        VariableReferenceExpression probeSideSymbolReference = inPredicate.getValue();
        VariableReferenceExpression buildSideSymbolReference = inPredicate.getSubquery();

        RowExpression isProbeSideNull = specialForm(probeSideSymbolReference.getSourceLocation(), IS_NULL, BOOLEAN, probeSideSymbolReference);
        RowExpression isBuildSideNull = specialForm(buildSideSymbolReference.getSourceLocation(), IS_NULL, BOOLEAN, buildSideSymbolReference);
        RowExpression comparison = call(
                ComparisonExpression.Operator.EQUAL.name(),
                functionResolution.comparisonFunction(ComparisonExpression.Operator.EQUAL, probeSideSymbolReference.getType(), buildSideSymbolReference.getType()),
                BOOLEAN,
                probeSideSymbolReference,
                buildSideSymbolReference);

        // Keep every pair that could make the IN predicate TRUE (a = b) or NULL (either side null),
        // restricted to pairs that satisfy the pulled-up correlation condition.
        RowExpression joinExpression = and(
                or(isProbeSideNull, comparison, isBuildSideNull),
                correlationCondition);

        // TODO since we care only about "some count > 0", we could have specialized node instead of leftOuterJoin that does the job without materializing join results
        JoinNode leftOuterJoin = leftOuterJoin(idAllocator, probeSide, buildSide, joinExpression);

        VariableReferenceExpression countMatchesVariable = variableAllocator.newVariable(buildSideSymbolReference.getSourceLocation(), "countMatches", BIGINT);
        VariableReferenceExpression countNullMatchesVariable = variableAllocator.newVariable(buildSideSymbolReference.getSourceLocation(), "countNullMatches", BIGINT);

        // Definite match: both sides non-null, so the join condition can only have held through a = b.
        // Padding rows have a null build symbol and therefore never count here.
        RowExpression matchCondition = and(
                isNotNull(probeSideSymbolReference),
                isNotNull(buildSideSymbolReference));

        // Null match: a genuinely joined row (not LEFT join padding) in which a or b is null.
        RowExpression nullMatchCondition = and(
                isNotNull(buildSideKnownNonNull),
                not(matchCondition));

        AggregationNode aggregation = new AggregationNode(
                apply.getSourceLocation(),
                idAllocator.getNextId(),
                leftOuterJoin,
                ImmutableMap.<VariableReferenceExpression, AggregationNode.Aggregation>builder()
                        .put(countMatchesVariable, countWithFilter(matchCondition))
                        .put(countNullMatchesVariable, countWithFilter(nullMatchCondition))
                        .build(),
                singleGroupingSet(probeSide.getOutputVariables()),
                ImmutableList.of(),
                AggregationNode.Step.SINGLE,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());

        // TRUE if any definite match, otherwise NULL if any null match, otherwise FALSE.
        RowExpression inPredicateEquivalent = searchedCaseExpression(
                ImmutableList.of(
                        specialForm(WHEN, BOOLEAN, isGreaterThan(countMatchesVariable, 0), TRUE_CONSTANT),
                        specialForm(WHEN, BOOLEAN, isGreaterThan(countNullMatchesVariable, 0), new ConstantExpression(null, BOOLEAN))),
                Optional.of(FALSE_CONSTANT));
        return new ProjectNode(
                idAllocator.getNextId(),
                aggregation,
                Assignments.builder()
                        .putAll(identityAssignments(apply.getInput().getOutputVariables()))
                        .put(inPredicateOutputVariable, inPredicateEquivalent)
                        .build());
    }

    /**
     * Returns {@code NOT (expression IS NULL)}.
     */
    private RowExpression isNotNull(RowExpression expression)
    {
        return not(specialForm(IS_NULL, BOOLEAN, ImmutableList.of(expression)));
    }

    /**
     * Returns a call to the {@code not} function over {@code expression}, keeping its source location.
     */
    private RowExpression not(RowExpression expression)
    {
        return call(
                expression.getSourceLocation(),
                "not",
                functionResolution.notFunction(),
                BOOLEAN,
                expression);
    }

    /**
     * Builds a LEFT join of {@code probeSide} with {@code buildSide} that uses no equi-join clauses
     * (the whole condition is {@code joinExpression}) and outputs all probe variables followed by
     * all build variables.
     */
    private static JoinNode leftOuterJoin(PlanNodeIdAllocator idAllocator, AssignUniqueId probeSide, ProjectNode buildSide, RowExpression joinExpression)
    {
        return new JoinNode(
                probeSide.getSourceLocation(),
                idAllocator.getNextId(),
                JoinType.LEFT,
                probeSide,
                buildSide,
                ImmutableList.of(),
                ImmutableList.<VariableReferenceExpression>builder()
                        .addAll(probeSide.getOutputVariables())
                        .addAll(buildSide.getOutputVariables())
                        .build(),
                Optional.of(joinExpression),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of());
    }

    /**
     * Returns a {@code count(*)} aggregation that counts only the rows satisfying {@code condition}.
     */
    private AggregationNode.Aggregation countWithFilter(RowExpression condition)
    {
        return new AggregationNode.Aggregation(
                new CallExpression(
                        condition.getSourceLocation(),
                        "count",
                        functionResolution.countFunction(),
                        BIGINT,
                        ImmutableList.of()),
                Optional.of(condition), // filter
                Optional.empty(),
                false,
                Optional.empty()); // mask
    }

    /**
     * Returns the BIGINT comparison {@code variable > value}.
     */
    private RowExpression isGreaterThan(VariableReferenceExpression variable, long value)
    {
        return call(
                ComparisonExpression.Operator.GREATER_THAN.name(),
                functionResolution.comparisonFunction(ComparisonExpression.Operator.GREATER_THAN, BIGINT, BIGINT),
                BOOLEAN,
                variable,
                constant(value, BIGINT));
    }

    /**
     * Walks the subquery and lifts correlated filter predicates out of it. A result is produced only
     * when correlation symbols appear exclusively in filter predicates reachable through filters and
     * uncorrelated projections; any other use of a correlation symbol yields {@link Optional#empty()}.
     */
    private static class DecorrelatingVisitor
            extends InternalPlanVisitor<Optional<Decorrelated>, PlanNode>
    {
        private final Lookup lookup;
        private final Set<VariableReferenceExpression> correlation;
        // NOTE(review): not read anywhere in this visitor.
        private final TypeProvider types;

        public DecorrelatingVisitor(Lookup lookup, Iterable<VariableReferenceExpression> correlation, TypeProvider types)
        {
            this.lookup = requireNonNull(lookup, "lookup is null");
            this.correlation = ImmutableSet.copyOf(requireNonNull(correlation, "correlation is null"));
            this.types = requireNonNull(types, "types is null");
        }

        /**
         * Resolves {@code reference} through the lookup and visits the resolved node. The unresolved
         * reference is passed along so uncorrelated subtrees can be returned unchanged.
         */
        public Optional<Decorrelated> decorrelate(PlanNode reference)
        {
            return lookup.resolve(reference).accept(this, reference);
        }

        /**
         * Re-applies an uncorrelated projection on top of its decorrelated source. The projection also
         * passes through every non-correlation variable referenced by the lifted predicates, because
         * those predicates are evaluated above it (in the join condition). A projection that itself
         * references a correlation symbol is not supported.
         */
        @Override
        public Optional<Decorrelated> visitProject(ProjectNode node, PlanNode reference)
        {
            if (isCorrelatedShallowly(node)) {
                // TODO: handle correlated projection
                return Optional.empty();
            }

            Optional<Decorrelated> result = decorrelate(node.getSource());
            return result.map(decorrelated -> {
                Assignments.Builder assignments = Assignments.builder()
                        .putAll(node.getAssignments());

                // Pull up all symbols used by a filter (except correlation)
                decorrelated.getCorrelatedPredicates().stream()
                        .flatMap(expression -> stream(
                                Traverser.forTree((SuccessorsFunction<RowExpression>) RowExpression::getChildren)
                                        .depthFirstPreOrder(expression)))
                        .filter(VariableReferenceExpression.class::isInstance)
                        .map(VariableReferenceExpression.class::cast)
                        .filter(variable -> !correlation.contains(variable))
                        .map(AssignmentUtils::identityAssignments)
                        .forEach(assignments::putAll);

                return new Decorrelated(
                        decorrelated.getCorrelatedPredicates(),
                        new ProjectNode(
                                node.getId(), // FIXME should I reuse or not?
                                decorrelated.getDecorrelatedNode(),
                                assignments.build()));
            });
        }

        /**
         * Drops the filter from the subquery and records its entire predicate as a lifted predicate,
         * to be evaluated later as part of the join condition.
         */
        @Override
        public Optional<Decorrelated> visitFilter(FilterNode node, PlanNode reference)
        {
            Optional<Decorrelated> result = decorrelate(node.getSource());
            return result.map(decorrelated ->
                    new Decorrelated(
                            ImmutableList.<RowExpression>builder()
                                    .addAll(decorrelated.getCorrelatedPredicates())
                                    // No need to retain uncorrelated conditions, predicate push down will push them back
                                    .add(node.getPredicate())
                                    .build(),
                            decorrelated.getDecorrelatedNode()));
        }

        /**
         * Any other node is kept as-is when neither it nor anything beneath it references a
         * correlation symbol; otherwise decorrelation fails.
         */
        @Override
        public Optional<Decorrelated> visitPlan(PlanNode node, PlanNode reference)
        {
            if (isCorrelatedRecursively(node)) {
                return Optional.empty();
            }
            else {
                return Optional.of(new Decorrelated(ImmutableList.of(), reference));
            }
        }

        /**
         * Returns whether {@code node} or any node below it (resolved through the lookup) references a
         * correlation symbol.
         */
        private boolean isCorrelatedRecursively(PlanNode node)
        {
            if (isCorrelatedShallowly(node)) {
                return true;
            }
            return node.getSources().stream()
                    .map(lookup::resolve)
                    .anyMatch(this::isCorrelatedRecursively);
        }

        /**
         * Returns whether {@code node} itself (ignoring its sources) references a correlation symbol.
         */
        private boolean isCorrelatedShallowly(PlanNode node)
        {
            return VariablesExtractor.extractUniqueNonRecursive(node).stream().anyMatch(correlation::contains);
        }
    }

    /**
     * Result of decorrelation: the subquery plan with its correlated filters removed, plus the
     * predicates that were lifted out of those filters.
     */
    private static class Decorrelated
    {
        private final List<RowExpression> correlatedPredicates;
        private final PlanNode decorrelatedNode;

        public Decorrelated(List<RowExpression> correlatedPredicates, PlanNode decorrelatedNode)
        {
            this.correlatedPredicates = ImmutableList.copyOf(requireNonNull(correlatedPredicates, "correlatedPredicates is null"));
            this.decorrelatedNode = requireNonNull(decorrelatedNode, "decorrelatedNode is null");
        }

        public List<RowExpression> getCorrelatedPredicates()
        {
            return correlatedPredicates;
        }

        public PlanNode getDecorrelatedNode()
        {
            return decorrelatedNode;
        }
    }
}