ReplaceConditionalApproxDistinct.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.Map.Entry;

import static com.facebook.presto.SystemSessionProperties.isOptimizeConditionalApproxDistinctEnabled;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.planner.plan.Patterns.aggregation;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.isNull;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

/**
 * Eliminates approx_distinct over conditional expressions whose branches are constants.
 * <p>
 * Depending on the inner conditional, the aggregation is converted
 * to its equivalent arbitrary() aggregation:
 *
 *     - approx_distinct(if(..., non-null)) -> arbitrary(if(..., 1, NULL))
 *     - approx_distinct(if(..., null, non-null)) -> arbitrary(if(..., NULL, 1))
 *     - approx_distinct(if(..., null, null)) -> arbitrary(NULL)
 *
 * A projection is inserted above the rewritten aggregation to convert any
 * NULL arbitrary() output to zero.
 */
public class ReplaceConditionalApproxDistinct
        implements Rule<AggregationNode>
{
    private static final Capture<ProjectNode> SOURCE = Capture.newCapture();

    // Matches an aggregation whose direct source is a projection. The projection is captured
    // so that apply() can look up the expression behind each aggregate argument.
    private static final Pattern<AggregationNode> PATTERN = aggregation()
            .with(source().matching(project().capturedAs(SOURCE)));

    private static final String ARBITRARY = "arbitrary";

    private final StandardFunctionResolution functionResolution;

    /**
     * @param functionAndTypeManager used to build the function resolution that identifies
     * approx_distinct calls and resolves the arbitrary() function handle; must not be null
     */
    public ReplaceConditionalApproxDistinct(FunctionAndTypeManager functionAndTypeManager)
    {
        requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
        this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
    }

    /**
     * The rule is gated by the session property read through
     * {@code isOptimizeConditionalApproxDistinctEnabled}.
     */
    @Override
    public boolean isEnabled(Session session)
    {
        return isOptimizeConditionalApproxDistinctEnabled(session);
    }

    @Override
    public Pattern<AggregationNode> getPattern()
    {
        return PATTERN;
    }

    /**
     * Rewrites every replaceable approx_distinct aggregation of {@code parent} into an
     * arbitrary() aggregation over a BIGINT expression that is either 1 or NULL.
     * <p>
     * The resulting plan has three layers:
     * <ol>
     * <li>a projection carrying the captured projection's assignments plus one new
     * BIGINT expression per rewritten aggregate,</li>
     * <li>an aggregation with the untouched aggregates and the new arbitrary() aggregates,</li>
     * <li>a projection that restores the original output variables, wrapping each rewritten
     * aggregate in {@code COALESCE(..., 0)}.</li>
     * </ol>
     *
     * @return the rewritten plan, or an empty result when no aggregate was replaceable
     */
    @Override
    public Result apply(AggregationNode parent, Captures captures, Context context)
    {
        VariableAllocator variableAllocator = context.getVariableAllocator();
        ProjectNode project = captures.get(SOURCE);

        // Assignments of the projection placed above the new aggregation.
        Assignments.Builder outputs = Assignments.builder();
        // Extra assignments added to the projection placed below the new aggregation.
        Assignments.Builder inputs = Assignments.builder();
        ImmutableMap.Builder<VariableReferenceExpression, AggregationNode.Aggregation> aggregations = ImmutableMap.builder();
        boolean changed = false;

        for (Entry<VariableReferenceExpression, AggregationNode.Aggregation> entry : parent.getAggregations().entrySet()) {
            VariableReferenceExpression variable = entry.getKey();
            AggregationNode.Aggregation aggregation = entry.getValue();

            if (!isApproxDistinct(aggregation) || !aggregationIsReplaceable(aggregation, project.getAssignments())) {
                // Keep the aggregate as is and pass its output straight through.
                aggregations.put(variable, aggregation);
                outputs.put(variable, variable);
                continue;
            }
            changed = true;

            // Both casts are safe: aggregationIsReplaceable() verified that the first argument is a
            // variable assigned to an IF special form in the captured projection.
            SpecialFormExpression replaced = (SpecialFormExpression) project.getAssignments().get(
                    (VariableReferenceExpression) aggregation.getArguments().get(0));

            VariableReferenceExpression expression = variableAllocator.newVariable("expression", BIGINT);
            inputs.put(expression, replaceIfExpression(replaced));

            VariableReferenceExpression intermediate = variableAllocator.newVariable("intermediate", BIGINT);
            aggregations.put(intermediate, new AggregationNode.Aggregation(
                    new CallExpression(
                            aggregation.getCall().getSourceLocation(),
                            ARBITRARY,
                            functionResolution.arbitraryFunction(BIGINT),
                            BIGINT,
                            ImmutableList.of(expression)),
                    aggregation.getFilter(),
                    aggregation.getOrderBy(),
                    aggregation.isDistinct(),
                    aggregation.getMask()));

            // The new aggregate's input is only ever 1 or NULL, so a NULL result is mapped to 0
            // under the original output variable.
            outputs.put(variable, new SpecialFormExpression(
                    COALESCE,
                    BIGINT,
                    ImmutableList.of(
                            intermediate,
                            constant(0L, BIGINT))));
        }

        if (!changed) {
            return Result.empty();
        }

        ProjectNode child = new ProjectNode(
                project.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                project.getSource(),
                inputs.putAll(project.getAssignments()).build(),
                project.getLocality());

        AggregationNode aggregation = new AggregationNode(
                parent.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                child,
                aggregations.build(),
                parent.getGroupingSets(),
                ImmutableList.of(),
                parent.getStep(),
                parent.getHashVariable(),
                parent.getGroupIdVariable(),
                parent.getAggregationId());

        // The top projection must also forward the hash variable and the grouping keys.
        aggregation.getHashVariable().ifPresent(hashvariable -> outputs.put(hashvariable, hashvariable));
        aggregation.getGroupingSets().getGroupingKeys().forEach(groupingKey -> outputs.put(groupingKey, groupingKey));
        return Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                aggregation,
                outputs.build()));
    }

    /**
     * Tells whether the aggregation's function handle is the approximate count distinct function.
     */
    private boolean isApproxDistinct(AggregationNode.Aggregation aggregation)
    {
        return functionResolution.isApproximateCountDistinctFunction(aggregation.getFunctionHandle());
    }

    /**
     * Maps a constant to its BIGINT stand-in: a null constant stays NULL, anything else becomes 1.
     */
    private ConstantExpression convertConstant(ConstantExpression expression)
    {
        return isNull(expression) ? constantNull(BIGINT) : constant(1L, BIGINT);
    }

    /**
     * Builds the BIGINT expression that replaces the original IF expression.
     * <p>
     * When exactly one branch is null, the result is an IF with the same condition whose branches
     * are converted with {@link #convertConstant}. When both branches are null, the result is a
     * NULL BIGINT constant.
     *
     * @param ifCondition an IF special form whose second and third arguments are constants,
     * at least one of them null (as verified by {@link #aggregationIsReplaceable})
     * @throws IllegalStateException if neither branch is null
     */
    private RowExpression replaceIfExpression(SpecialFormExpression ifCondition)
    {
        ConstantExpression trueThen = (ConstantExpression) ifCondition.getArguments().get(1);
        ConstantExpression falseThen = (ConstantExpression) ifCondition.getArguments().get(2);

        if (isNull(trueThen) != isNull(falseThen)) {
            // if(..., null, non-null) or if(..., non-null, null)
            return new SpecialFormExpression(
                    ifCondition.getSourceLocation(),
                    IF,
                    BIGINT,
                    ImmutableList.of(
                            ifCondition.getArguments().get(0),
                            convertConstant(trueThen),
                            convertConstant(falseThen)));
        }

        // if(..., null, null)
        checkState(isNull(trueThen) && isNull(falseThen),
                "expected true (%s) and false (%s) predicates to be null",
                trueThen, falseThen);
        return convertConstant(trueThen);
    }

    /**
     * Tells whether the aggregation can be rewritten: its first argument must be a variable that
     * {@code inputs} assigns to an IF special form whose two branches are both constants, at least
     * one of which is null.
     *
     * @param aggregation the aggregation to inspect
     * @param inputs the assignments of the projection feeding the aggregation
     */
    private boolean aggregationIsReplaceable(AggregationNode.Aggregation aggregation, Assignments inputs)
    {
        if (aggregation.getArguments().isEmpty()) {
            return false;
        }

        RowExpression argument = aggregation.getArguments().get(0);
        if (!(argument instanceof VariableReferenceExpression)) {
            return false;
        }

        // A variable that the projection does not assign yields null, which fails the instanceof test.
        RowExpression assigned = inputs.get((VariableReferenceExpression) argument);
        if (!(assigned instanceof SpecialFormExpression)) {
            return false;
        }

        SpecialFormExpression ifCondition = (SpecialFormExpression) assigned;
        // Guard the arity so that the branch lookups below cannot go out of bounds.
        if (ifCondition.getForm() != IF || ifCondition.getArguments().size() < 3) {
            return false;
        }

        RowExpression trueThen = ifCondition.getArguments().get(1);
        RowExpression falseThen = ifCondition.getArguments().get(2);
        return trueThen instanceof ConstantExpression &&
                falseThen instanceof ConstantExpression &&
                (isNull(trueThen) || isNull(falseThen));
    }
}