RemoveCrossJoinWithConstantInput.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.RowExpressionVariableInliner;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isRemoveCrossJoinWithConstantSingleRowInputEnabled;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableMap.toImmutableMap;

/**
 * When one side of a cross join produces exactly one row of deterministic constants, the cross join can be removed
 * and replaced with a project on top of the other side.
 * <pre>
 * - Cross Join
 *      - table scan
 *          left_field
 *      - values // only one row
 *          right_field := 1
 * </pre>
 * into
 * <pre>
 * - project
 *      left_field := left_field
 *      right_field := 1
 *      - table scan
 *          left_field
 * </pre>
 * If the join carries a filter, the filter is re-applied on top of the new project.
 */
public class RemoveCrossJoinWithConstantInput
        implements Rule<JoinNode>
{
    private final RowExpressionDeterminismEvaluator rowExpressionDeterminismEvaluator;

    public RemoveCrossJoinWithConstantInput(FunctionAndTypeManager functionAndTypeManager)
    {
        this.rowExpressionDeterminismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
    }

    /**
     * Matches inner joins without equi-join criteria (a join filter may still be present).
     */
    @Override
    public Pattern<JoinNode> getPattern()
    {
        return join().matching(x -> x.getType().equals(JoinType.INNER) && x.getCriteria().isEmpty());
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isRemoveCrossJoinWithConstantSingleRowInputEnabled(session);
    }

    /**
     * Tries to replace the join with a projection of the constant side's values on top of the other side.
     * The right side is tried as the constant input first; if it does not qualify (not a single-row values
     * input, or it contains a non-deterministic expression), the left side is tried.
     *
     * @return the rewritten plan, or an empty result when neither side qualifies
     */
    @Override
    public Result apply(JoinNode node, Captures captures, Context context)
    {
        PlanNode leftInput = context.getLookup().resolve(node.getLeft());
        PlanNode rightInput = context.getLookup().resolve(node.getRight());

        // Keep the original preference for the right side, but do not give up when only the left side qualifies.
        Optional<PlanNode> rewritten = replaceJoinWithProject(node, rightInput, leftInput, context);
        if (!rewritten.isPresent()) {
            rewritten = replaceJoinWithProject(node, leftInput, rightInput, context);
        }
        if (!rewritten.isPresent()) {
            return Result.empty();
        }
        return Result.ofPlanNode(rewritten.get());
    }

    /**
     * Builds the replacement plan using {@code singleValueInput} as the constant side and {@code joinInput} as the
     * side that is kept as the source of the new project.
     *
     * @return the replacement plan, or empty when {@code singleValueInput} is not a single-row values input or
     * any of its output expressions is non-deterministic
     */
    private Optional<PlanNode> replaceJoinWithProject(JoinNode node, PlanNode singleValueInput, PlanNode joinInput, Context context)
    {
        if (!isOutputSingleConstantRow(singleValueInput, context)) {
            return Optional.empty();
        }
        // Resolve the mapping before allocating any plan node id, so a failed attempt has no side effects.
        Optional<Map<VariableReferenceExpression, RowExpression>> mapping = getConstantAssignments(singleValueInput, context);
        if (!mapping.isPresent()) {
            return Optional.empty();
        }
        // NOTE(review): the replacement exposes the outputs of joinInput plus the constant side's outputs;
        // confirm this is acceptable when the join's own output list is narrower than that.
        PlanNode resultNode = addProjections(joinInput, context.getIdAllocator(), mapping.get());
        if (node.getFilter().isPresent()) {
            resultNode = new FilterNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), resultNode, node.getFilter().get());
        }
        return Optional.of(resultNode);
    }

    /**
     * Returns true when the node, after skipping any chain of {@link ProjectNode}s, is a {@link ValuesNode} with
     * exactly one row. Determinism of the expressions is not checked here; see {@link #getConstantAssignments}.
     */
    private boolean isOutputSingleConstantRow(PlanNode planNode, Context context)
    {
        while (planNode instanceof ProjectNode) {
            planNode = context.getLookup().resolve(((ProjectNode) planNode).getSource());
        }
        if (planNode instanceof ValuesNode) {
            return ((ValuesNode) planNode).getRows().size() == 1;
        }
        return false;
    }

    /**
     * Computes, for every output variable of {@code planNode}, the expression it evaluates to once all
     * intermediate {@link ProjectNode} assignments and the single row of the underlying {@link ValuesNode}
     * have been inlined.
     * <p>
     * The caller must have verified the shape with {@link #isOutputSingleConstantRow}; otherwise the
     * {@code checkState} below fails.
     *
     * @return the mapping, or empty if any resulting expression is non-deterministic
     */
    private Optional<Map<VariableReferenceExpression, RowExpression>> getConstantAssignments(PlanNode planNode, Context context)
    {
        List<VariableReferenceExpression> outputVariables = planNode.getOutputVariables();
        // Start from identity and substitute level by level while walking down the project chain.
        Map<VariableReferenceExpression, RowExpression> mapping = outputVariables.stream().collect(toImmutableMap(Function.identity(), Function.identity()));
        while (planNode instanceof ProjectNode) {
            Map<VariableReferenceExpression, RowExpression> assignments = ((ProjectNode) planNode).getAssignments().getMap();
            mapping = updateAssignments(mapping, assignments);
            planNode = context.getLookup().resolve(((ProjectNode) planNode).getSource());
        }

        checkState(planNode instanceof ValuesNode);
        ValuesNode valuesNode = (ValuesNode) planNode;
        if (!valuesNode.getOutputVariables().isEmpty()) {
            // Bind each values output variable to the expression at the same position of the only row.
            Map<VariableReferenceExpression, RowExpression> assignments = IntStream.range(0, valuesNode.getOutputVariables().size()).boxed()
                    .collect(toImmutableMap(idx -> valuesNode.getOutputVariables().get(idx), idx -> valuesNode.getRows().get(0).get(idx)));
            mapping = updateAssignments(mapping, assignments);
        }
        boolean allDeterministic = mapping.values().stream().allMatch(rowExpressionDeterminismEvaluator::isDeterministic);
        if (allDeterministic) {
            return Optional.of(mapping);
        }
        return Optional.empty();
    }

    /**
     * Rewrites every value of {@code mapping} by inlining the variables defined in {@code newAssignments};
     * the keys are left unchanged.
     */
    private static Map<VariableReferenceExpression, RowExpression> updateAssignments(Map<VariableReferenceExpression, RowExpression> mapping, Map<VariableReferenceExpression, RowExpression> newAssignments)
    {
        return mapping.entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, entry -> RowExpressionVariableInliner.inlineVariables(newAssignments, entry.getValue())));
    }
}