CrossJoinWithArrayNotContainsToAntiJoin.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.PlannerUtils;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.SystemSessionProperties.isRewriteCrossJoinArrayNotContainsToAntiJoinEnabled;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.matching.Capture.newCapture;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.sql.planner.PlannerUtils.isSupportedArrayContainsFilter;
import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput;
import static com.facebook.presto.sql.planner.plan.Patterns.filter;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.planner.plan.Patterns.sources;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

/**
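 * A cross join followed by a filter of the form NOT contains(array, element) can be rewritten as an anti join
 * when the array comes from a single-row (scalar subquery) input: the array is deduplicated and unnested, the other
 * input is left outer joined to the unnested elements, and only the rows without a match are kept. For example: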
 * <pre>
 * - Filter not contains(r_array, l_key)
 *      - Cross join
 *          - scan l
 *          - single_row_table_scan r
 * </pre>
 * into:
 * <pre>
 *     - Filter (l_array_elem IS NULL)
 *       - LOJ (l_key = l_array_elem)
 *          - scan l
 *          - Unnest
 *              l_array_elem <- unnest distinct_array
 *              - project
 *                  distinct_array := array_distinct(remove_nulls(r_array))
 *                  - single_row_table_scan r
 * </pre>
 */
public class CrossJoinWithArrayNotContainsToAntiJoin
        implements Rule<FilterNode>
{
    private static final Capture<JoinNode> JOIN = newCapture();

    private static final Capture<List<PlanNode>> JOIN_CHILDREN = Capture.newCapture();

    // Matches a filter whose source is a cross join; the join node and its children are captured.
    private static final Pattern<FilterNode> PATTERN = filter()
            .with(source().matching(join().matching(candidate -> candidate.isCrossJoin()).capturedAs(JOIN).with(sources().capturedAs(JOIN_CHILDREN))));

    private final Metadata metadata;
    private final FunctionAndTypeManager functionAndTypeManager;

    /**
     * @param metadata metadata handle, must not be null
     * @param functionAndTypeManager used to resolve the functions referenced by the rewritten plan, must not be null
     * @throws NullPointerException if either argument is null
     */
    public CrossJoinWithArrayNotContainsToAntiJoin(Metadata metadata, FunctionAndTypeManager functionAndTypeManager)
    {
        this.metadata = requireNonNull(metadata, "metadata is null");
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    }

    /**
     * Finds the first conjunct of {@code filterPredicate} that is a negation whose operand is accepted by
     * {@link PlannerUtils#isSupportedArrayContainsFilter} for the given join inputs.
     *
     * @param functionResolution resolver used to recognise the negation and the contains call
     * @param filterPredicate predicate of the filter sitting on top of the cross join
     * @param leftInput output variables of the left join input
     * @param rightInput output variables of the right join input
     * @return the matching conjunct (the negation itself, not its operand), or {@code null} when there is none
     */
    public static RowExpression getCandidateArrayNotContainsExpression(FunctionResolution functionResolution, RowExpression filterPredicate, List<VariableReferenceExpression> leftInput, List<VariableReferenceExpression> rightInput)
    {
        for (RowExpression conjunct : extractConjuncts(filterPredicate)) {
            if (PlannerUtils.isNegationExpression(functionResolution, conjunct) &&
                    isSupportedArrayContainsFilter(functionResolution, conjunct.getChildren().get(0), leftInput, rightInput)) {
                return conjunct;
            }
        }
        return null;
    }

    @Override
    public Pattern<FilterNode> getPattern()
    {
        return PATTERN;
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isRewriteCrossJoinArrayNotContainsToAntiJoinEnabled(session);
    }

    /**
     * Rewrites {@code Filter(NOT contains(array, element)) -> CrossJoin} into a left outer join between the
     * element side and the unnested, de-duplicated array, followed by an IS NULL filter on the unnested element.
     *
     * @return the rewritten plan restricted to the original join outputs, or an empty result when the filter has no
     * supported NOT contains conjunct or the array side is not recognised as a scalar subquery
     */
    @Override
    public Result apply(FilterNode node, Captures captures, Context context)
    {
        JoinNode joinNode = captures.get(JOIN);

        if (!(joinNode.getType().equals(JoinType.INNER) && joinNode.getCriteria().isEmpty())) {
            return Result.empty();
        }

        List<VariableReferenceExpression> leftColumns = joinNode.getLeft().getOutputVariables();
        List<VariableReferenceExpression> rightColumns = joinNode.getRight().getOutputVariables();
        RowExpression filterExpression = node.getPredicate();
        FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());

        RowExpression arrayNotContainsExpression = getCandidateArrayNotContainsExpression(functionResolution, filterExpression, leftColumns, rightColumns);
        if (arrayNotContainsExpression == null) {
            return Result.empty();
        }

        // The negation wraps the contains(array, element) call.
        CallExpression arrayContainsExpression = (CallExpression) arrayNotContainsExpression.getChildren().get(0);
        RowExpression array = arrayContainsExpression.getArguments().get(0);
        RowExpression element = arrayContainsExpression.getArguments().get(1);

        // Only plain columns are handled: the element becomes an equi-join key and the array is located by variable.
        checkState(element instanceof VariableReferenceExpression, "Argument to CONTAINS is not a column");
        checkState(array instanceof VariableReferenceExpression, "Argument to CONTAINS is not a column");

        boolean arrayAtLeftInput = leftColumns.contains(array);
        PlanNode inputWithArray = arrayAtLeftInput ? joinNode.getLeft() : joinNode.getRight();
        // The input that supplies the element always becomes the preserved (left) side of the outer join.
        PlanNode newLeftNode = arrayAtLeftInput ? joinNode.getRight() : joinNode.getLeft();

        if (!isFromScalarSubquery(context, inputWithArray)) {
            // rewrite is incorrect if array input has more than 1 row or columns included in the output: if the source of CROSS JOIN was a subquery these conditions are guaranteed
            return Result.empty();
        }

        // NOTE(review): the rewritten plan keeps every left row that has no matching array element. If contains()
        // evaluates to NULL (e.g. for a NULL element, a NULL array, or an array holding NULLs), the original predicate
        // would drop such rows while the rewritten plan keeps them - confirm this difference is acceptable.
        Type type = element.getType();
        CallExpression arrayDistinct = call(functionAndTypeManager, "array_distinct", new ArrayType(type),
                call(functionAndTypeManager, "remove_nulls", new ArrayType(type), array));
        VariableReferenceExpression arrayDistinctVariable = context.getVariableAllocator().newVariable(arrayDistinct);
        PlanNode project = PlannerUtils.addProjections(inputWithArray, context.getIdAllocator(), ImmutableMap.of(arrayDistinctVariable, arrayDistinct));
        VariableReferenceExpression unnestVariable = context.getVariableAllocator().newVariable("field", type);
        UnnestNode unnest = new UnnestNode(inputWithArray.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                project,
                project.getOutputVariables(),
                ImmutableMap.of(arrayDistinctVariable, ImmutableList.of(unnestVariable)),
                Optional.empty());

        EquiJoinClause equiJoinClause = new EquiJoinClause((VariableReferenceExpression) element, unnestVariable);

        List<VariableReferenceExpression> newOutputColumns = Stream.concat(newLeftNode.getOutputVariables().stream(), unnest.getOutputVariables().stream()).collect(toImmutableList());

        JoinNode newJoinNode = new JoinNode(joinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                JoinType.LEFT,
                newLeftNode,
                unnest,
                ImmutableList.of(equiJoinClause),
                newOutputColumns,
                joinNode.getFilter(),
                Optional.empty(),
                Optional.empty(),
                joinNode.getDistributionType(),
                joinNode.getDynamicFilters());

        // Keep every conjunct except the rewritten NOT contains, then require that the outer join found no match.
        List<RowExpression> remainingConjuncts = extractConjuncts(filterExpression).stream()
                .filter(conjunct -> !conjunct.equals(arrayNotContainsExpression))
                .collect(Collectors.toList());
        RowExpression isNull = specialForm(IS_NULL, BOOLEAN, ImmutableList.of(unnestVariable));
        // Collectors.toList() makes no mutability guarantee, so assemble the final conjunct list in a fresh builder.
        List<RowExpression> newConjuncts = ImmutableList.<RowExpression>builder()
                .addAll(remainingConjuncts)
                .add(isNull)
                .build();

        FilterNode filterNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), newJoinNode, and(newConjuncts));
        PlanNode result = restrictOutput(filterNode, context.getIdAllocator(), joinNode.getOutputVariables());
        return Result.ofPlanNode(result);
    }

    /**
     * Returns true when {@code node}, after resolving it through the rule's lookup and skipping any chain of
     * {@link ProjectNode}s, is an {@link EnforceSingleRowNode}.
     */
    private boolean isFromScalarSubquery(Context context, PlanNode node)
    {
        // TODO: currently we only support EnforceSingleRow which guarantees the filter+cross join was generated from a subquery (so no other columns needed from the array side of the cross join)
        PlanNode current = context.getLookup().resolve(node);
        while (current instanceof ProjectNode) {
            current = context.getLookup().resolve(current.getSources().get(0));
        }
        return current instanceof EnforceSingleRowNode;
    }
}