SortExpressionExtractor.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Collections.singletonList;
import static java.util.Comparator.comparing;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

/**
 * Extracts sort expression to be used for creating {@link com.facebook.presto.operator.SortedPositionLinks} from join filter expression.
 * Currently this class can extract sort and search expressions from filter function conjuncts of shape:
 * <p>
 * {@code A.a < f(B.x, B.y, B.z)} or {@code f(B.x, B.y, B.z) < A.a}
 * <p>
 * where {@code a} is the build side symbol reference and {@code x,y,z} are probe
 * side symbol references. Any of the inequality operators ({@code <,<=,>,>=}) can be used,
 * and the side opposite the build symbol must not reference any build side symbol.
 * <p>
 * Only deterministic conjuncts are considered. Qualifying conjuncts are grouped by their build side
 * symbol, and the symbol with the most conjuncts (search expressions) is chosen as the sort expression;
 * ties go to the symbol whose conjunct appears first in the filter.
 */
public final class SortExpressionExtractor
{
    /* TODO:
       This class could be extended to handle any expressions like:
       A.a * sin(A.b) / log(B.x) < cos(B.z)
       by transforming it to:
       f(A.a, A.b) < g(B.x, B.z)
       Where f(...) and g(...) would be some functions/expressions. That
       would allow us to perform binary search on arbitrary complex expressions
       by sorting position links according to the result of f(...) function.
     */
    private SortExpressionExtractor() {}

    /**
     * Extracts a sort expression context from the filter of {@code joinNode}, treating the output
     * variables of its right (build) source as the build side variables.
     *
     * @return the chosen context, or empty when the join has no filter or no conjunct qualifies
     */
    public static Optional<SortExpressionContext> getSortExpressionContext(JoinNode joinNode, FunctionAndTypeManager functionAndTypeManager)
    {
        return joinNode.getFilter()
                .flatMap(filter -> extractSortExpression(ImmutableSet.copyOf(joinNode.getRight().getOutputVariables()), filter, functionAndTypeManager));
    }

    /**
     * Extracts a sort expression context from the conjuncts of {@code filter}.
     *
     * @param buildVariables variables produced by the build side of the join
     * @param filter join filter expression
     * @param functionAndTypeManager used to resolve function metadata and determinism
     * @return the context whose sort expression has the most search expressions, or empty when no conjunct qualifies
     */
    public static Optional<SortExpressionContext> extractSortExpression(Set<VariableReferenceExpression> buildVariables, RowExpression filter, FunctionAndTypeManager functionAndTypeManager)
    {
        List<RowExpression> filterConjuncts = LogicalRowExpressions.extractConjuncts(filter);
        SortExpressionVisitor visitor = new SortExpressionVisitor(buildVariables, functionAndTypeManager);
        DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);

        // Group candidates by sort expression. LinkedHashMap keeps groups in the order in which their sort
        // expression first appears among the conjuncts, so the choice below does not depend on hash order.
        List<SortExpressionContext> sortExpressionCandidates = filterConjuncts.stream()
                .filter(determinismEvaluator::isDeterministic)
                .map(conjunct -> conjunct.accept(visitor, null))
                .filter(Optional::isPresent)
                .map(Optional::get)
                .collect(toMap(SortExpressionContext::getSortExpression, identity(), SortExpressionExtractor::merge, LinkedHashMap::new))
                .values()
                .stream()
                .collect(toImmutableList());

        // For now heuristically pick the sort expression which has most search expressions assigned to it.
        // Stream.sorted is stable for ordered streams, so ties resolve to the earliest group.
        // TODO: make it cost based decision based on symbol statistics
        return sortExpressionCandidates.stream()
                .sorted(comparing((SortExpressionContext context) -> context.getSearchExpressions().size()).reversed())
                .findFirst();
    }

    /**
     * Combines two contexts that share the same sort expression by concatenating their search expressions.
     */
    private static SortExpressionContext merge(SortExpressionContext left, SortExpressionContext right)
    {
        checkArgument(left.getSortExpression().equals(right.getSortExpression()));
        ImmutableList.Builder<RowExpression> searchExpressions = ImmutableList.builder();
        searchExpressions.addAll(left.getSearchExpressions());
        searchExpressions.addAll(right.getSearchExpressions());
        return new SortExpressionContext(left.getSortExpression(), searchExpressions.build());
    }

    /**
     * Produces a single-conjunct {@link SortExpressionContext} for a range comparison between a plain
     * build side variable and an expression free of build side references; every other expression yields empty.
     */
    private static class SortExpressionVisitor
            implements RowExpressionVisitor<Optional<SortExpressionContext>, Void>
    {
        private final Set<VariableReferenceExpression> buildVariables;
        private final FunctionAndTypeManager functionAndTypeManager;
        // Depends only on the build variable set, so one instance is reused for every conjunct.
        private final BuildVariableReferenceFinder buildVariableReferenceFinder;

        public SortExpressionVisitor(Set<VariableReferenceExpression> buildVariables, FunctionAndTypeManager functionAndTypeManager)
        {
            this.buildVariables = ImmutableSet.copyOf(requireNonNull(buildVariables, "buildVariables is null"));
            this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.buildVariableReferenceFinder = new BuildVariableReferenceFinder(this.buildVariables);
        }

        @Override
        public Optional<SortExpressionContext> visitCall(CallExpression call, Void context)
        {
            FunctionMetadata functionMetadata = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle());
            Optional<OperatorType> operatorType = functionMetadata.getOperatorType();
            if (!operatorType.isPresent()) {
                return Optional.empty();
            }

            switch (operatorType.get()) {
                case GREATER_THAN:
                case GREATER_THAN_OR_EQUAL:
                case LESS_THAN:
                case LESS_THAN_OR_EQUAL:
                    RowExpression left = call.getArguments().get(0);
                    RowExpression right = call.getArguments().get(1);
                    // Prefer the right-hand argument as the sort channel, then fall back to the left-hand one.
                    Optional<SortExpressionContext> fromRight = extractWithSortChannel(call, right, left);
                    return fromRight.isPresent() ? fromRight : extractWithSortChannel(call, left, right);
                default:
                    return Optional.empty();
            }
        }

        /**
         * Returns a context with {@code sortCandidate} as sort expression and {@code call} as its only search
         * expression when {@code sortCandidate} is a build side variable and {@code searchCandidate} references
         * no build side variable; otherwise returns empty.
         */
        private Optional<SortExpressionContext> extractWithSortChannel(CallExpression call, RowExpression sortCandidate, RowExpression searchCandidate)
        {
            Optional<VariableReferenceExpression> sortChannel = asBuildVariableReference(buildVariables, sortCandidate);
            if (!sortChannel.isPresent() || searchCandidate.accept(buildVariableReferenceFinder, null)) {
                return Optional.empty();
            }
            return Optional.of(new SortExpressionContext(sortChannel.get(), singletonList(call)));
        }

        @Override
        public Optional<SortExpressionContext> visitInputReference(InputReferenceExpression input, Void context)
        {
            return Optional.empty();
        }

        @Override
        public Optional<SortExpressionContext> visitConstant(ConstantExpression literal, Void context)
        {
            return Optional.empty();
        }

        @Override
        public Optional<SortExpressionContext> visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            return Optional.empty();
        }

        @Override
        public Optional<SortExpressionContext> visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return Optional.empty();
        }

        @Override
        public Optional<SortExpressionContext> visitSpecialForm(SpecialFormExpression specialForm, Void context)
        {
            return Optional.empty();
        }
    }

    /**
     * Returns {@code expression} as a variable reference if it is a variable contained in {@code buildLayout}.
     */
    private static Optional<VariableReferenceExpression> asBuildVariableReference(Set<VariableReferenceExpression> buildLayout, RowExpression expression)
    {
        // Currently we only support a plain variable reference as the sort expression on the build side
        if (expression instanceof VariableReferenceExpression) {
            VariableReferenceExpression reference = (VariableReferenceExpression) expression;
            if (buildLayout.contains(reference)) {
                return Optional.of(reference);
            }
        }
        return Optional.empty();
    }

    /**
     * Returns {@code true} if any variable reference in the visited expression tree
     * (including call arguments, special form arguments and lambda bodies) is a build side variable.
     */
    private static class BuildVariableReferenceFinder
            implements RowExpressionVisitor<Boolean, Void>
    {
        private final Set<VariableReferenceExpression> buildVariables;

        public BuildVariableReferenceFinder(Set<VariableReferenceExpression> buildVariables)
        {
            this.buildVariables = ImmutableSet.copyOf(requireNonNull(buildVariables, "buildVariables is null"));
        }

        @Override
        public Boolean visitInputReference(InputReferenceExpression input, Void context)
        {
            return false;
        }

        @Override
        public Boolean visitCall(CallExpression call, Void context)
        {
            for (RowExpression argument : call.getArguments()) {
                if (argument.accept(this, context)) {
                    return true;
                }
            }
            return false;
        }

        @Override
        public Boolean visitConstant(ConstantExpression literal, Void context)
        {
            return false;
        }

        @Override
        public Boolean visitLambda(LambdaDefinitionExpression lambda, Void context)
        {
            return lambda.getBody().accept(this, context);
        }

        @Override
        public Boolean visitVariableReference(VariableReferenceExpression reference, Void context)
        {
            return buildVariables.contains(reference);
        }

        @Override
        public Boolean visitSpecialForm(SpecialFormExpression specialForm, Void context)
        {
            for (RowExpression argument : specialForm.getArguments()) {
                if (argument.accept(this, context)) {
                    return true;
                }
            }
            return false;
        }
    }
}