AddNotNullFiltersToJoinNode.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.Session;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.IntermediateFormExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;

import java.util.Collection;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.getNotNullInferenceStrategy;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinNotNullInferenceStrategy.NONE;
import static com.facebook.presto.sql.planner.plan.Patterns.join;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Sets.intersection;
import static java.util.Collections.singletonList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Stream.concat;

/**
 * Optimizer rule that appends {@code not(IS_NULL(variable))} predicates to the filter of
 * LEFT, RIGHT and INNER joins for variables that the join criteria and join filter imply
 * must be non-NULL for a row to match.
 * <p>
 * The rule is enabled whenever the session's not-null inference strategy is anything other
 * than {@code NONE}. FULL joins (and any unrecognized join type) are left untouched.
 */
public class AddNotNullFiltersToJoinNode
        implements Rule<JoinNode>
{
    // One logger for the class rather than one per rule instance
    private static final Logger log = Logger.get(AddNotNullFiltersToJoinNode.class);
    private static final Pattern<JoinNode> PATTERN = join();

    private final FunctionAndTypeManager functionAndTypeManager;
    private final FunctionResolution functionResolution;

    /**
     * @param functionAndTypeManager used to look up function metadata during inference and to resolve the {@code not} function
     * @throws NullPointerException if {@code functionAndTypeManager} is null
     */
    public AddNotNullFiltersToJoinNode(FunctionAndTypeManager functionAndTypeManager)
    {
        this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    @Override
    public boolean isEnabled(Session session)
    {
        return getNotNullInferenceStrategy(session) != NONE;
    }

    @Override
    public Pattern<JoinNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Rewrites the join with a filter extended by NOT NULL predicates for every inferred variable
     * that the current filter does not already guard with an exact {@code not(IS_NULL(variable))} conjunct.
     *
     * @return the rewritten join, or {@link Result#empty()} when the join type does not allow inference
     * or when no new NOT NULL variable was found
     */
    @Override
    public Result apply(JoinNode joinNode, Captures captures, Context context)
    {
        JoinNotNullInferenceStrategy notNullInferenceStrategy = getNotNullInferenceStrategy(context.getSession());

        List<VariableReferenceExpression> candidates;
        switch (joinNode.getType()) {
            case LEFT:
                // NOT NULL can be inferred for the right-side variables
                candidates = joinNode.getRight().getOutputVariables();
                break;
            case RIGHT:
                // NOT NULL can be inferred for the left-side variables
                candidates = joinNode.getLeft().getOutputVariables();
                break;
            case INNER:
                // NOT NULL can be inferred for variables from both sides of the join
                candidates = concat(joinNode.getLeft().getOutputVariables().stream(), joinNode.getRight().getOutputVariables().stream())
                        .collect(toImmutableList());
                break;
            case FULL:
            default:
                // NOT NULL cannot be inferred
                return Result.empty();
        }

        Set<VariableReferenceExpression> inferredNotNullVariables = extractNotNullVariables(joinNode.getCriteria(), joinNode.getFilter(), candidates, notNullInferenceStrategy);
        if (inferredNotNullVariables.isEmpty()) {
            return Result.empty();
        }

        Set<VariableReferenceExpression> existingNotNullVariables = getExistingNotNullVariables(joinNode.getFilter());
        log.debug("NotNull filters :: Existing : %s, Inferred :%s", existingNotNullVariables, inferredNotNullVariables);

        // Only add predicates for variables the filter does not already guard, so that we never
        // append a duplicate not(IS_NULL(x)) conjunct next to an existing one
        List<VariableReferenceExpression> newNotNullVariables = inferredNotNullVariables.stream()
                .filter(variable -> !existingNotNullVariables.contains(variable))
                .collect(toImmutableList());

        if (newNotNullVariables.isEmpty()) {
            // No new NOT NULL variables were inferred
            return Result.empty();
        }

        RowExpression updatedJoinFilter = and(joinNode.getFilter().orElse(TRUE_CONSTANT), buildNotNullRowExpression(newNotNullVariables));
        return Result.ofPlanNode(
                new JoinNode(joinNode.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        joinNode.getType(),
                        joinNode.getLeft(),
                        joinNode.getRight(),
                        joinNode.getCriteria(),
                        joinNode.getOutputVariables(),
                        Optional.of(updatedJoinFilter),
                        joinNode.getLeftHashVariable(),
                        joinNode.getRightHashVariable(),
                        joinNode.getDistributionType(),
                        joinNode.getDynamicFilters()));
    }

    /**
     * Infers NOT NULL variables from the equi-join criteria and the optional join filter and
     * keeps only those that are present in {@code candidates}.
     */
    private Set<VariableReferenceExpression> extractNotNullVariables(List<EquiJoinClause> joinCriteria, Optional<RowExpression> joinFilter,
            List<VariableReferenceExpression> candidates, JoinNotNullInferenceStrategy notNullInferenceStrategy)
    {
        RowExpression combinedFilter = TRUE_CONSTANT;

        // Both sides of every equi-join clause are AND-ed into a single expression together with
        // the join filter, so that one traversal of the inference visitor reaches all of them
        for (EquiJoinClause criteria : joinCriteria) {
            combinedFilter = and(combinedFilter, criteria.getLeft());
            combinedFilter = and(combinedFilter, criteria.getRight());
        }

        combinedFilter = and(combinedFilter, joinFilter.orElse(TRUE_CONSTANT));

        return intersection(ImmutableSet.copyOf(candidates), inferNotNullVariables(combinedFilter, notNullInferenceStrategy));
    }

    /**
     * Collects the variables that {@code joinFilter} already guards with an exact
     * {@code not(IS_NULL(variable))} call. Only the filter root and operands of (nested) AND
     * expressions are inspected; every other special form and all intermediate-form expressions
     * are skipped.
     *
     * @return the guarded variables; empty when there is no filter
     */
    @VisibleForTesting
    Set<VariableReferenceExpression> getExistingNotNullVariables(Optional<RowExpression> joinFilter)
    {
        if (!joinFilter.isPresent()) {
            return ImmutableSet.of();
        }

        ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();

        DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>> isNotNullExtractingVisitor =
                new DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>>()
                {
                    @Override
                    public Void visitCall(CallExpression call, ImmutableSet.Builder<VariableReferenceExpression> context)
                    {
                        // Match a 'not(IS_NULL(VariableReferenceExpression))' call *exactly*; arguments of any other call are not traversed
                        if (!functionResolution.isNotFunction(call.getFunctionHandle()) || call.getArguments().size() != 1) {
                            return null;
                        }

                        RowExpression notArgument = call.getArguments().get(0);
                        if (!(notArgument instanceof SpecialFormExpression)) {
                            return null;
                        }

                        SpecialFormExpression isNullCandidate = (SpecialFormExpression) notArgument;
                        if (isNullCandidate.getForm() == IS_NULL &&
                                isNullCandidate.getArguments().size() == 1 &&
                                isNullCandidate.getArguments().get(0) instanceof VariableReferenceExpression) {
                            context.add((VariableReferenceExpression) isNullCandidate.getArguments().get(0));
                        }
                        return null;
                    }

                    @Override
                    public Void visitIntermediateFormExpression(IntermediateFormExpression expression, ImmutableSet.Builder<VariableReferenceExpression> context)
                    {
                        return null;
                    }

                    @Override
                    public Void visitSpecialForm(SpecialFormExpression specialForm, ImmutableSet.Builder<VariableReferenceExpression> context)
                    {
                        // Only descend through AND: a NOT NULL predicate under any other form is not guaranteed to hold
                        if (specialForm.getForm() == AND) {
                            return super.visitSpecialForm(specialForm, context);
                        }
                        return null;
                    }
                };

        joinFilter.get().accept(isNotNullExtractingVisitor, builder);
        return builder.build();
    }

    /**
     * Runs {@link ExtractInferredNotNullVariablesVisitor} over {@code expression} and returns the variables it collected.
     */
    private ImmutableSet<VariableReferenceExpression> inferNotNullVariables(RowExpression expression, JoinNotNullInferenceStrategy notNullInferenceStrategy)
    {
        ImmutableSet.Builder<VariableReferenceExpression> builder = ImmutableSet.builder();
        expression.accept(new ExtractInferredNotNullVariablesVisitor(functionAndTypeManager, notNullInferenceStrategy), builder);
        return builder.build();
    }

    /**
     * Builds the conjunction of {@code not(IS_NULL(variable))} for every given variable.
     */
    private RowExpression buildNotNullRowExpression(Collection<VariableReferenceExpression> expressions)
    {
        List<RowExpression> isNotNullExpressions = expressions.stream()
                .map(variable -> new CallExpression(
                        variable.getSourceLocation(),
                        "not",
                        functionResolution.notFunction(),
                        BOOLEAN,
                        singletonList(new SpecialFormExpression(variable.getSourceLocation(), IS_NULL, BOOLEAN, variable))))
                .collect(toImmutableList());

        return and(isNotNullExpressions);
    }

    /**
     * Visitor that collects every variable reference it reaches while descending only through
     * AND expressions and through calls that the configured strategy accepts for NOT NULL inference.
     */
    @VisibleForTesting
    public static class ExtractInferredNotNullVariablesVisitor
            extends DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>>
    {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final JoinNotNullInferenceStrategy notNullInferenceStrategy;

        /**
         * @throws NullPointerException if either argument is null
         */
        public ExtractInferredNotNullVariablesVisitor(FunctionAndTypeManager functionAndTypeManager, JoinNotNullInferenceStrategy notNullInferenceStrategy)
        {
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.notNullInferenceStrategy = requireNonNull(notNullInferenceStrategy, "notNullInferenceStrategy is null");
        }

        @Override
        public Void visitCall(CallExpression call, ImmutableSet.Builder<VariableReferenceExpression> context)
        {
            final FunctionHandle functionHandle = call.getFunctionHandle();
            final FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(functionHandle);

            switch (notNullInferenceStrategy) {
                case INFER_FROM_STANDARD_OPERATORS:
                    if (!functionMetadata.getOperatorType().isPresent() || functionMetadata.getOperatorType().get().isCalledOnNullInput()) {
                        // We can't map this CallExpression to an OperatorType OR
                        // this OperatorType can be called on NULL inputs, so we can't make NOT NULL inferences on its arguments
                        return null;
                    }
                    break;
                case USE_FUNCTION_METADATA:
                    if (functionMetadata.isCalledOnNullInput()) {
                        // Since this function can operate on NULL inputs and return a valid value, we can't make NOT NULL inferences on its arguments
                        return null;
                    }
                    break;
                default:
                    // Any other strategy: make no inference from calls
                    return null;
            }

            return super.visitCall(call, context);
        }

        @Override
        public Void visitSpecialForm(SpecialFormExpression specialForm, ImmutableSet.Builder<VariableReferenceExpression> context)
        {
            SpecialFormExpression.Form form = specialForm.getForm();
            if (form == AND) {
                // All arguments of an AND expression must be NOT NULL for the expression to be true
                // Hence, we can proceed with extracting candidates for NOT NULL inference on its arguments
                return super.visitSpecialForm(specialForm, context);
            }
            // For all other SpecialForms e.g. OR, COALESCE, IS_NULL, CASE, DEREFERENCE we abstain from making NOT NULL inferences
            return null;
        }

        @Override
        public Void visitIntermediateFormExpression(IntermediateFormExpression expression, ImmutableSet.Builder<VariableReferenceExpression> context)
        {
            // TODO : For now, we are not traversing any IntermediateFormExpressions. For some cases, such as InSubqueryExpression,
            // we may be able to do some NOT NULL inference
            return null;
        }

        @Override
        public Void visitVariableReference(VariableReferenceExpression variableReferenceExpression, ImmutableSet.Builder<VariableReferenceExpression> context)
        {
            // Any variable reached by the traversal sits on a path where NULL inputs cannot satisfy the predicate
            context.add(variableReferenceExpression);
            return super.visitVariableReference(variableReferenceExpression, context);
        }
    }
}