RemoveRedundantCastToVarcharInJoinClause.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.SystemSessionProperties.isRemoveRedundantCastToVarcharInJoinEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.sql.planner.PlannerUtils.addProjections;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.facebook.presto.sql.relational.Expressions.castToBigInt;

/**
 * Remove redundant cast to varchar in join condition for queries like `select select * from orders o join customer c on cast(o.custkey as varchar) = cast(c.custkey as varchar)`
 * Transform from
 * <pre>
 *     - Join
 *          left_cast = right_cast
 *          - Project
 *              left_cast := cast(lkey as varchar)
 *              - TableScan
 *                  lkey BIGINT
 *          - Project
 *              right_cast := cast(rkey as varchar)
 *              - TableScan
 *                  rkey BIGINT
 *
 * </pre>
 * into
 * <pre>
 *     - Join
 *          new_lkey = new_rkey
 *          - Project
 *              left_cast := cast(lkey as varchar)
 *              new_lkey := lkey
 *              - TableScan
 *                  lkey BIGINT
 *          - Project
 *              right_cast := cast(rkey as varchar)
 *              new_rkey := rkey
 *              - TableScan
 *                  rkey BIGINT
 * </pre>
 * We will rely on optimizations later to remove unnecessary cast (if not used) and identity projection here.
 * <p>
 * Notice that we do not apply similar optimizations to queries with similar join condition like `cast(bigint as varchar) = varchar`. In general it can be converted to
 * `bigint = try_cast(varchar as bigint)` as if the varchar here cannot be converted to bigint, try_cast will return null and will not match anyway. However, a special case is
 * varchar begins with 0. `select cast(92 as varchar) = '092'` is false, but `select 92 = try_cast('092' as bigint)` returns true.
 */
public class RemoveRedundantCastToVarcharInJoinClause
        implements Rule<JoinNode>
{
    private static final List<Type> TYPE_SUPPORTED = ImmutableList.of(INTEGER, BIGINT);
    private final FunctionAndTypeManager functionAndTypeManager;
    private final FunctionResolution functionResolution;

    public RemoveRedundantCastToVarcharInJoinClause(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = functionAndTypeManager;
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return isRemoveRedundantCastToVarcharInJoinEnabled(session);
    }

    @Override
    public Pattern<JoinNode> getPattern()
    {
        return join();
    }

    @Override
    public Result apply(JoinNode node, Captures captures, Context context)
    {
        PlanNode leftInput = context.getLookup().resolve(node.getLeft());
        PlanNode rightInput = context.getLookup().resolve(node.getRight());
        if (!(leftInput instanceof ProjectNode) || !(rightInput instanceof ProjectNode)) {
            return Result.empty();
        }
        ProjectNode leftProject = (ProjectNode) leftInput;
        ProjectNode rightProject = (ProjectNode) rightInput;

        ImmutableList.Builder<EquiJoinClause> joinClauseBuilder = ImmutableList.builder();
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newLeftAssignmentsBuilder = ImmutableMap.builder();
        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newRightAssignmentsBuilder = ImmutableMap.builder();
        boolean isChanged = false;
        for (EquiJoinClause equiJoinClause : node.getCriteria()) {
            RowExpression leftProjectAssignment = leftProject.getAssignments().getMap().get(equiJoinClause.getLeft());
            RowExpression rightProjectAssignment = rightProject.getAssignments().getMap().get(equiJoinClause.getRight());
            if (!isSupportedCast(leftProjectAssignment) || !isSupportedCast(rightProjectAssignment)) {
                joinClauseBuilder.add(equiJoinClause);
                continue;
            }

            RowExpression leftAssignment = ((CallExpression) leftProjectAssignment).getArguments().get(0);
            RowExpression rightAssignment = ((CallExpression) rightProjectAssignment).getArguments().get(0);

            if (!leftAssignment.getType().equals(rightAssignment.getType())) {
                leftAssignment = castToBigInt(functionAndTypeManager, leftAssignment);
                rightAssignment = castToBigInt(functionAndTypeManager, rightAssignment);
            }

            VariableReferenceExpression newLeft = context.getVariableAllocator().newVariable(leftAssignment);
            newLeftAssignmentsBuilder.put(newLeft, leftAssignment);

            VariableReferenceExpression newRight = context.getVariableAllocator().newVariable(rightAssignment);
            newRightAssignmentsBuilder.put(newRight, rightAssignment);

            joinClauseBuilder.add(new EquiJoinClause(newLeft, newRight));
            isChanged = true;
        }

        if (!isChanged) {
            return Result.empty();
        }

        newLeftAssignmentsBuilder.putAll(leftProject.getAssignments().getMap());
        Map<VariableReferenceExpression, RowExpression> newLeftAssignments = newLeftAssignmentsBuilder.build();
        newRightAssignmentsBuilder.putAll(rightProject.getAssignments().getMap());
        Map<VariableReferenceExpression, RowExpression> newRightAssignments = newRightAssignmentsBuilder.build();

        PlanNode newLeftProject = addProjections(leftProject.getSource(), context.getIdAllocator(), newLeftAssignments);
        PlanNode newRightProject = addProjections(rightProject.getSource(), context.getIdAllocator(), newRightAssignments);

        return Result.ofPlanNode(new JoinNode(node.getSourceLocation(), context.getIdAllocator().getNextId(), node.getType(), newLeftProject, newRightProject, joinClauseBuilder.build(), node.getOutputVariables(), node.getFilter(), Optional.empty(), Optional.empty(), node.getDistributionType(), node.getDynamicFilters()));
    }

    private boolean isSupportedCast(RowExpression rowExpression)
    {
        if (rowExpression instanceof CallExpression && functionResolution.isCastFunction(((CallExpression) rowExpression).getFunctionHandle())) {
            CallExpression cast = (CallExpression) rowExpression;
            return TYPE_SUPPORTED.contains(cast.getArguments().get(0).getType()) && cast.getType() instanceof VarcharType;
        }
        return false;
    }
}