DynamicFilters.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.expressions;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.predicate.Domain;
import com.facebook.presto.common.predicate.Range;
import com.facebook.presto.common.predicate.ValueSet;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;

import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.common.type.StandardTypes.BOOLEAN;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.expressions.RowExpressionTreeRewriter.rewriteWith;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.OR;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public final class DynamicFilters
{
    // Private constructor: this class is not meant to be instantiated from outside.
    private DynamicFilters() {}

    /**
     * Splits the conjuncts of {@code expression} into dynamic filter placeholders
     * and everything else.
     * <p>
     * The expression is decomposed with {@code extractConjuncts}; each conjunct for
     * which {@link #getPlaceholder(RowExpression)} yields a placeholder is collected
     * as a dynamic conjunct, every other conjunct is kept as a static conjunct.
     * Relative order within each group follows the order of the extracted conjuncts.
     *
     * @param expression the predicate to split
     * @return the static conjuncts and the decoded dynamic filter placeholders
     */
    public static DynamicFilterExtractResult extractDynamicFilters(RowExpression expression)
    {
        ImmutableList.Builder<RowExpression> staticPart = ImmutableList.builder();
        ImmutableList.Builder<DynamicFilterPlaceholder> dynamicPart = ImmutableList.builder();

        for (RowExpression candidate : extractConjuncts(expression)) {
            Optional<DynamicFilterPlaceholder> decoded = getPlaceholder(candidate);
            if (!decoded.isPresent()) {
                // Not a placeholder call: keep the conjunct as a regular predicate.
                staticPart.add(candidate);
                continue;
            }
            dynamicPart.add(decoded.get());
        }

        List<RowExpression> staticConjuncts = staticPart.build();
        List<DynamicFilterPlaceholder> dynamicConjuncts = dynamicPart.build();
        return new DynamicFilterExtractResult(staticConjuncts, dynamicConjuncts);
    }

    /**
     * Tells whether {@code expression} decodes to a dynamic filter placeholder.
     *
     * @param expression the expression to inspect
     * @return {@code true} when {@link #getPlaceholder(RowExpression)} returns a value
     */
    public static boolean isDynamicFilter(RowExpression expression)
    {
        Optional<DynamicFilterPlaceholder> decoded = getPlaceholder(expression);
        return decoded.isPresent();
    }

    /**
     * Decodes a call to the dynamic filter placeholder function into a
     * {@link DynamicFilterPlaceholder}.
     * <p>
     * The call is expected to carry exactly three arguments: the probe-side
     * expression, the operator name as a varchar constant, and the filter id as a
     * varchar constant.
     *
     * @param expression the expression to decode
     * @return the placeholder, or empty when {@code expression} is not a call to
     *         {@link DynamicFilterPlaceholderFunction#NAME} or when the decoded
     *         operator is not a comparison operator
     * @throws IllegalArgumentException if the call does not have three arguments,
     *         if the operator or id argument is not a varchar constant, or if the
     *         operator name does not match an {@link OperatorType} constant
     */
    public static Optional<DynamicFilterPlaceholder> getPlaceholder(RowExpression expression)
    {
        if (!(expression instanceof CallExpression)) {
            return Optional.empty();
        }

        CallExpression placeholderCall = (CallExpression) expression;
        boolean isPlaceholderCall = placeholderCall.getDisplayName().equals(DynamicFilterPlaceholderFunction.NAME);
        if (!isPlaceholderCall) {
            return Optional.empty();
        }

        List<RowExpression> callArguments = placeholderCall.getArguments();
        checkArgument(callArguments.size() == 3, "invalid arguments count: %s", callArguments.size());

        RowExpression probe = callArguments.get(0);
        // Operator is decoded before the id so validation failures surface in argument order.
        String operatorName = varcharConstantValue(callArguments.get(1));
        String filterId = varcharConstantValue(callArguments.get(2));

        OperatorType comparison = OperatorType.valueOf(operatorName);
        return comparison.isComparisonOperator()
                ? Optional.of(new DynamicFilterPlaceholder(filterId, probe, comparison))
                : Optional.empty();
    }

    // Validates that the expression is a varchar-typed constant and returns its value as a UTF-8 string.
    private static String varcharConstantValue(RowExpression expression)
    {
        checkArgument(expression instanceof ConstantExpression);
        checkArgument(expression.getType() instanceof VarcharType);
        Slice rawValue = (Slice) ((ConstantExpression) expression).getValue();
        return rawValue.toStringUtf8();
    }

    /**
     * Rewrites {@code expression} so that dynamic filter placeholders appearing as
     * operands of AND/OR are replaced with {@code TRUE_CONSTANT}, simplifying the
     * surrounding AND/OR where possible.
     * <p>
     * Only AND/OR special forms are processed; any other special form, and any
     * node handed to {@code rewriteRowExpression}, is returned unchanged. The
     * {@link AtomicBoolean} context is shared by the whole rewrite and records
     * whether any placeholder has been replaced so far.
     *
     * @param expression the expression to rewrite
     * @return the rewritten expression
     */
    public static RowExpression removeNestedDynamicFilters(RowExpression expression)
    {
        RowExpressionRewriter<AtomicBoolean> nestedFilterRemover = new RowExpressionRewriter<AtomicBoolean>()
        {
            @Override
            public RowExpression rewriteRowExpression(RowExpression node, AtomicBoolean context, RowExpressionTreeRewriter<AtomicBoolean> treeRewriter)
            {
                return node;
            }

            @Override
            public RowExpression rewriteSpecialForm(SpecialFormExpression node, AtomicBoolean modified, RowExpressionTreeRewriter<AtomicBoolean> treeRewriter)
            {
                SpecialFormExpression.Form form = node.getForm();
                if (form != AND && form != OR) {
                    return node;
                }

                checkState(BooleanType.BOOLEAN.equals(node.getType()), "AND/OR must be boolean function");

                // Rewrite every operand first so nested AND/OR forms are already simplified.
                ImmutableList.Builder<RowExpression> rewrittenOperands = ImmutableList.builder();
                for (RowExpression operand : node.getArguments()) {
                    rewrittenOperands.add(rewriteWith(this, operand, modified));
                }
                List<RowExpression> operands = rewrittenOperands.build();

                // Only the first two operands are considered from here on.
                RowExpression left = operands.get(0);
                if (isDynamicFilter(left)) {
                    left = TRUE_CONSTANT;
                    modified.set(true);
                }

                RowExpression right = operands.get(1);
                if (isDynamicFilter(right)) {
                    right = TRUE_CONSTANT;
                    modified.set(true);
                }

                // The flag is shared across the rewrite: when nothing has been replaced so far, keep the original node.
                if (!modified.get()) {
                    return node;
                }

                boolean leftIsTrue = left.equals(TRUE_CONSTANT);
                boolean rightIsTrue = right.equals(TRUE_CONSTANT);

                if (form == AND) {
                    if (leftIsTrue && rightIsTrue) {
                        return TRUE_CONSTANT;
                    }
                    if (leftIsTrue) {
                        return right;
                    }
                    if (rightIsTrue) {
                        return left;
                    }
                }
                else if (leftIsTrue || rightIsTrue) {
                    // form is OR here: a TRUE operand makes the whole disjunction TRUE.
                    return TRUE_CONSTANT;
                }

                return new SpecialFormExpression(form, node.getType(), ImmutableList.of(left, right));
            }
        };

        return rewriteWith(nestedFilterRemover, expression, new AtomicBoolean(false));
    }

    /**
     * Immutable pair of conjunct lists: the static conjuncts and the dynamic
     * filter placeholders. Both lists are copied into immutable lists on
     * construction.
     */
    public static class DynamicFilterExtractResult
    {
        private final List<RowExpression> staticConjuncts;
        private final List<DynamicFilterPlaceholder> dynamicConjuncts;

        /**
         * @param staticConjuncts conjuncts that are not dynamic filter placeholders; must not be null
         * @param dynamicConjuncts decoded dynamic filter placeholders; must not be null
         * @throws NullPointerException if either list is null
         */
        public DynamicFilterExtractResult(List<RowExpression> staticConjuncts, List<DynamicFilterPlaceholder> dynamicConjuncts)
        {
            List<RowExpression> checkedStatic = requireNonNull(staticConjuncts, "staticConjuncts is null");
            this.staticConjuncts = ImmutableList.copyOf(checkedStatic);

            List<DynamicFilterPlaceholder> checkedDynamic = requireNonNull(dynamicConjuncts, "dynamicConjuncts is null");
            this.dynamicConjuncts = ImmutableList.copyOf(checkedDynamic);
        }

        /**
         * @return the immutable list of decoded dynamic filter placeholders
         */
        public List<DynamicFilterPlaceholder> getDynamicConjuncts()
        {
            return this.dynamicConjuncts;
        }

        /**
         * @return the immutable list of conjuncts that are not dynamic filter placeholders
         */
        public List<RowExpression> getStaticConjuncts()
        {
            return this.staticConjuncts;
        }
    }

    /**
     * Decoded form of a dynamic filter placeholder call: the filter id, the input
     * expression the filter applies to, and the comparison operator relating the
     * input to the filter values.
     * <p>
     * Instances are immutable; equality and hash code are based on all three fields.
     */
    public static final class DynamicFilterPlaceholder
    {
        private final String id;
        private final RowExpression input;
        private final OperatorType operator;

        /**
         * @param id the dynamic filter id; must not be null
         * @param input the expression the filter is applied to; must not be null
         * @param operator the comparison operator; must not be null
         * @throws NullPointerException if any argument is null
         */
        public DynamicFilterPlaceholder(String id, RowExpression input, OperatorType operator)
        {
            this.id = requireNonNull(id, "id is null");
            this.input = requireNonNull(input, "input is null");
            this.operator = requireNonNull(operator, "operator is null");
        }

        /**
         * Creates a placeholder using the {@code EQUAL} operator.
         *
         * @param id the dynamic filter id; must not be null
         * @param input the expression the filter is applied to; must not be null
         */
        public DynamicFilterPlaceholder(String id, RowExpression input)
        {
            this(id, input, EQUAL);
        }

        public String getId()
        {
            return id;
        }

        public RowExpression getInput()
        {
            return input;
        }

        public OperatorType getOperator()
        {
            return operator;
        }

        @Override
        public boolean equals(Object o)
        {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            DynamicFilterPlaceholder that = (DynamicFilterPlaceholder) o;
            return Objects.equals(id, that.id) &&
                    Objects.equals(input, that.input) &&
                    Objects.equals(operator, that.operator);
        }

        @Override
        public int hashCode()
        {
            return Objects.hash(id, input, operator);
        }

        @Override
        public String toString()
        {
            return toStringHelper(this)
                    .add("id", id)
                    .add("input", input)
                    .add("operator", operator)
                    .toString();
        }

        /**
         * Adapts {@code domain} to this placeholder's comparison operator.
         * <ul>
         * <li>A domain that is none or all is returned unchanged.</li>
         * <li>{@code EQUAL} returns the domain unchanged.</li>
         * <li>{@code LESS_THAN} / {@code LESS_THAN_OR_EQUAL} return a single range
         * bounded above by the high end of the domain's span.</li>
         * <li>{@code GREATER_THAN} / {@code GREATER_THAN_OR_EQUAL} return a single
         * range bounded below by the low end of the domain's span.</li>
         * </ul>
         * Domains produced for the ordering operators are created with
         * {@code false} as the second argument of {@code Domain.create}.
         *
         * @param domain the domain of collected values
         * @return the domain to apply to {@link #getInput()}
         * @throws IllegalArgumentException if the operator is not one of the five
         *         operators listed above
         */
        public Domain applyComparison(Domain domain)
        {
            if (domain.isNone() || domain.isAll()) {
                return domain;
            }
            // The span is computed only in the ordering cases below. EQUAL never needs it,
            // so an EQUAL placeholder must not fail merely because the span of the domain
            // cannot be computed.
            switch (operator) {
                case EQUAL:
                    return domain;
                case LESS_THAN: {
                    Range span = spanOf(domain);
                    // NOTE(review): getValue() is called unconditionally on the span's high marker;
                    // presumably the domain is always bounded here — confirm against callers.
                    Range range = Range.lessThan(span.getType(), span.getHigh().getValue());
                    return Domain.create(ValueSet.ofRanges(range), false);
                }
                case LESS_THAN_OR_EQUAL: {
                    Range span = spanOf(domain);
                    Range range = Range.lessThanOrEqual(span.getType(), span.getHigh().getValue());
                    return Domain.create(ValueSet.ofRanges(range), false);
                }
                case GREATER_THAN: {
                    Range span = spanOf(domain);
                    Range range = Range.greaterThan(span.getType(), span.getLow().getValue());
                    return Domain.create(ValueSet.ofRanges(range), false);
                }
                case GREATER_THAN_OR_EQUAL: {
                    Range span = spanOf(domain);
                    Range range = Range.greaterThanOrEqual(span.getType(), span.getLow().getValue());
                    return Domain.create(ValueSet.ofRanges(range), false);
                }
                default:
                    throw new IllegalArgumentException("Unsupported dynamic filtering comparison operator: " + operator);
            }
        }

        // Returns the single range spanning all ranges of the domain's value set.
        private static Range spanOf(Domain domain)
        {
            return domain.getValues().getRanges().getSpan();
        }
    }

    /**
     * Declaration of the hidden scalar function used as the dynamic filter
     * placeholder. It is registered under {@link #NAME} with {@code HIDDEN}
     * visibility and takes three arguments: an input of generic type {@code T},
     * the operator as a varchar, and the filter id as a varchar.
     * <p>
     * One overload is declared per Java representation of the input
     * ({@code Block}, {@code Slice}, {@code long}, {@code boolean},
     * {@code double}). Every overload throws
     * {@link UnsupportedOperationException} when invoked, so the function only
     * serves as a marker inside expressions.
     */
    @ScalarFunction(value = DynamicFilterPlaceholderFunction.NAME, visibility = HIDDEN)
    public static final class DynamicFilterPlaceholderFunction
    {
        // Holder for static function declarations only; not instantiable.
        private DynamicFilterPlaceholderFunction() {}

        // Function name matched against CallExpression#getDisplayName() in getPlaceholder.
        public static final String NAME = "$internal$dynamic_filter_function";

        // Overload for inputs represented as Block; always throws.
        @TypeParameter("T")
        @SqlType(BOOLEAN)
        public static boolean dynamicFilter(@SqlType("T") Block input, @SqlType(VARCHAR) Slice operator, @SqlType(VARCHAR) Slice id)
        {
            throw new UnsupportedOperationException();
        }

        // Overload for inputs represented as Slice; always throws.
        @TypeParameter("T")
        @SqlType(BOOLEAN)
        public static boolean dynamicFilter(@SqlType("T") Slice input, @SqlType(VARCHAR) Slice operator, @SqlType(VARCHAR) Slice id)
        {
            throw new UnsupportedOperationException();
        }

        // Overload for inputs represented as long; always throws.
        @TypeParameter("T")
        @SqlType(BOOLEAN)
        public static boolean dynamicFilter(@SqlType("T") long input, @SqlType(VARCHAR) Slice operator, @SqlType(VARCHAR) Slice id)
        {
            throw new UnsupportedOperationException();
        }

        // Overload for inputs represented as boolean; always throws.
        @TypeParameter("T")
        @SqlType(BOOLEAN)
        public static boolean dynamicFilter(@SqlType("T") boolean input, @SqlType(VARCHAR) Slice operator, @SqlType(VARCHAR) Slice id)
        {
            throw new UnsupportedOperationException();
        }

        // Overload for inputs represented as double; always throws.
        @TypeParameter("T")
        @SqlType(BOOLEAN)
        public static boolean dynamicFilter(@SqlType("T") double input, @SqlType(VARCHAR) Slice operator, @SqlType(VARCHAR) Slice id)
        {
            throw new UnsupportedOperationException();
        }
    }
}