TestRewriteAggregationIfToFilter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.assertions.ExpressionMatcher;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.function.Function;

import static com.facebook.presto.SystemSessionProperties.AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.functionCall;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.globalAggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;

public class TestRewriteAggregationIfToFilter
        extends BaseRuleTest
{
    @Test
    public void testDoesNotFireForNonIf()
    {
        // The aggregation expression is not an if expression.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BooleanType.BOOLEAN);
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("ds > '2021-07-01'")),
                                    p.values(ds))));
                }).doesNotFire();
    }

    @Test
    public void testDoesNotFireForIfWithElse()
    {
        // The if expression has an else branch. We cannot rewrite it.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1, 2)")),
                                    p.values(ds))));
                }).doesNotFire();
    }

    @Test
    public void testDoesNotFireForNonDeterministicFunction()
    {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", DOUBLE);
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("IF(ds > '2021-07-01', random())")),
                                    p.values(ds))));
                }).doesNotFire();
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BIGINT);
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr"), p.rowExpression("sum(a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("IF(random() > DOUBLE '0.1', 1)")),
                                    p.values(ds))));
                }).doesNotFire();
    }

    @Test
    public void testFireCount()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")),
                                        p.values(ds))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr"), functionCall("count", ImmutableList.of("expr_0"))),
                                    ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(ImmutableMap.of(
                                                            "a", expression("IF(ds > '2021-07-01', 1)"),
                                                            "greater_than", expression("ds > '2021-07-01'"),
                                                            "expr_0", expression("1")),
                                                    values("ds")))));
        }
    }

    @Test
    public void testUnwrapIf()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr"), p.rowExpression("count(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-07-01', 1)")),
                                        p.values(ds))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr"), functionCall("count", ImmutableList.of("expr0"))),
                                    ImmutableMap.of(new Symbol("expr"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(ImmutableMap.of(
                                                            "a", expression("IF(ds > '2021-07-01', 1)"),
                                                            "greater_than", expression("ds > '2021-07-01'"),
                                                            "expr0", expression("1")),
                                                    values("ds")))));
        }
    }

    @Test
    public void testFireMin()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MIN(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("min", ImmutableList.of("column0_0"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-06-01', column0)"))
                                                            .put("greater_than", expression("ds > '2021-06-01'"))
                                                            .put("column0_0", expression("column0"))
                                                            .build(),
                                                    values("ds", "column0")))));
        }
    }

    @Test
    public void testFireMax()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MAX(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("max", ImmutableList.of("column0_0"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-06-01', column0)"))
                                                            .put("greater_than", expression("ds > '2021-06-01'"))
                                                            .put("column0_0", expression("column0"))
                                                            .build(),
                                                    values("ds", "column0")))));
        }
    }

    @Test
    public void testFireArbitrary()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("ARBITRARY(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("arbitrary", ImmutableList.of("column0_0"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-06-01', column0)"))
                                                            .put("greater_than", expression("ds > '2021-06-01'"))
                                                            .put("column0_0", expression("column0"))
                                                            .build(),
                                                    values("ds", "column0")))));
        }
    }

    @Test
    public void testFireSum()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("sum", ImmutableList.of("column0_0"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-06-01', column0)"))
                                                            .put("greater_than", expression("ds > '2021-06-01'"))
                                                            .put("column0_0", expression("column0"))
                                                            .build(),
                                                    values("ds", "column0")))));
        }
    }

    @Test
    public void testDoesNotFireForMaxBy()
    {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("MAX_BY(a, a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                    p.values(ds, column0))));
                }).doesNotFire();
    }

    @Test
    public void testDoesNotFireForMinBy()
    {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("MIN_BY(a, a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                    p.values(ds, column0))));
                }).doesNotFire();
    }

    @Test
    public void testFireTwoAggregations()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression b = p.variable("b");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("count(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("count(b)"))
                                .source(p.project(
                                        assignment(
                                                a, p.rowExpression("IF(ds > '2021-07-01', 1)"),
                                                b, p.rowExpression("IF(ds > '2021-06-01', 2)")),
                                        p.values(ds))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(
                                            Optional.of("expr0"), functionCall("count", ImmutableList.of("expr")),
                                            Optional.of("expr1"), functionCall("count", ImmutableList.of("expr_1"))),
                                    ImmutableMap.of(
                                            new Symbol("expr0"), new Symbol("greater_than"),
                                            new Symbol("expr1"), new Symbol("greater_than_0")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than or greater_than_0",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-07-01', 1)"))
                                                            .put("b", expression("IF(ds > '2021-06-01', 2)"))
                                                            .put("greater_than", expression("ds > '2021-07-01'"))
                                                            .put("expr", expression("1"))
                                                            .put("greater_than_0", expression("ds > '2021-06-01'"))
                                                            .put("expr_1", expression("2"))
                                                            .build(),
                                                    values("ds")))));
        }
    }

    @Test
    public void testFireTwoAggregationsWithSharedInput()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("MIN(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("MAX(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("IF(ds > '2021-06-01', column0)")),
                                        p.values(ds, column0))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(
                                            Optional.of("expr0"), functionCall("min", ImmutableList.of("column0_0")),
                                            Optional.of("expr1"), functionCall("max", ImmutableList.of("column0_0"))),
                                    ImmutableMap.of(
                                            new Symbol("expr0"), new Symbol("greater_than"),
                                            new Symbol("expr1"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-06-01', column0)"))
                                                            .put("greater_than", expression("ds > '2021-06-01'"))
                                                            .put("column0_0", expression("column0"))
                                                            .build(),
                                                    values("ds", "column0")))));
        }
    }

    @Test
    public void testFireForOneOfTwoAggregations()
    {
        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression b = p.variable("b");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("count(a)"))
                                .addAggregation(p.variable("expr1"), p.rowExpression("count(b)"))
                                .source(p.project(
                                        assignment(
                                                a, p.rowExpression("IF(ds > '2021-07-01', 1)"),
                                                b, p.rowExpression("ds")),
                                        p.values(ds))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(
                                            Optional.of("expr0"), functionCall("count", ImmutableList.of("expr")),
                                            Optional.of("expr1"), functionCall("count", ImmutableList.of("b"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "true",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("IF(ds > '2021-07-01', 1)"))
                                                            .put("b", expression("ds"))
                                                            .put("greater_than", expression("ds > '2021-07-01'"))
                                                            .put("expr", expression("1"))
                                                            .build(),
                                                    values("ds")))));
        }
    }

    @Test
    public void testArrayOffset()
    {
        for (String strategy : new String[] {"filter_with_if", "unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression arrayColumn = p.variable("arrayColumn", new ArrayType(BIGINT));
                        VariableReferenceExpression arrayElement = p.variable("arrayElement", BIGINT);
                        return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(arrayElement)"))
                                .source(p.project(
                                        assignment(arrayElement, p.rowExpression("IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])")),
                                        p.values(arrayColumn))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("SUM", ImmutableList.of("arrayElement"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("arrayElement", expression("IF(CARDINALITY(arrayColumn) > 0, arrayColumn[1])"))
                                                            .put("greater_than", expression("CARDINALITY(arrayColumn) > 0"))
                                                            .build(),
                                                    values("arrayColumn")))));
        }
    }

    @Test
    public void testDivide()
    {
        for (String strategy : new String[] {"filter_with_if", "unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a", BIGINT);
                        VariableReferenceExpression b = p.variable("b", BIGINT);
                        VariableReferenceExpression result = p.variable("result", BIGINT);
                        return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(result)"))
                                .source(p.project(
                                        assignment(result, p.rowExpression("IF(b != 0, a / b)")),
                                        p.values(a, b))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("SUM", ImmutableList.of("result"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("not_equal")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "not_equal",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("result", expression("IF(b != 0, a / b)"))
                                                            .put("not_equal", expression("b != 0"))
                                                            .build(),
                                                    values("a", "b")))));
        }

        // The condition expression doesn't reference the variables in the true branch. The IF can be unwrapped.
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "unwrap_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a", BIGINT);
                    VariableReferenceExpression b = p.variable("b", BIGINT);
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    VariableReferenceExpression result = p.variable("result", BIGINT);
                    return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("SUM(result)"))
                            .source(p.project(
                                    assignment(result, p.rowExpression("IF(ds > '2021-07-01', a / b)")),
                                    p.values(ds, a, b))));
                })
                .matches(
                        aggregation(
                                globalAggregation(),
                                ImmutableMap.of(Optional.of("expr0"), functionCall("SUM", ImmutableList.of("result"))),
                                ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                Optional.empty(),
                                AggregationNode.Step.FINAL,
                                filter(
                                        "greater_than",
                                        project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                        .put("result", expression("a / b"))
                                                        .put("greater_than", expression("ds > '2021-07-01'"))
                                                        .build(),
                                                values("ds", "a", "b")))));
    }

    @Test
    public void testUnwrapIfForOneOfTwoAggregations()
    {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "unwrap_if")
                .on(p -> {
                    VariableReferenceExpression result0 = p.variable("result0", BIGINT);
                    VariableReferenceExpression result1 = p.variable("result1", BIGINT);
                    VariableReferenceExpression a = p.variable("a", BIGINT);
                    VariableReferenceExpression b = p.variable("b", BIGINT);
                    return p.aggregation(aggregationBuilder -> aggregationBuilder.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("count(result0)"))
                            .addAggregation(p.variable("expr1"), p.rowExpression("count(result1)"))
                            .source(p.project(
                                    assignment(
                                            result0, p.rowExpression("IF(b != 0, a / b)"),
                                            result1, p.rowExpression("IF(b > 0, b)")),
                                    p.values(a, b))));
                })
                .matches(
                        aggregation(
                                globalAggregation(),
                                ImmutableMap.of(
                                        Optional.of("expr0"), functionCall("count", ImmutableList.of("result0")),
                                        Optional.of("expr1"), functionCall("count", ImmutableList.of("b_0"))),
                                ImmutableMap.of(new Symbol("expr0"), new Symbol("not_equal"),
                                        new Symbol("expr1"), new Symbol("greater_than")),
                                Optional.empty(),
                                AggregationNode.Step.FINAL,
                                filter(
                                        "greater_than or not_equal",
                                        project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                        .put("result0", expression("IF(b != 0, a / b)"))
                                                        .put("result1", expression("IF(b > 0, b)"))
                                                        .put("b_0", expression("b"))
                                                        .put("not_equal", expression("b != 0"))
                                                        .put("greater_than", expression("b > 0"))
                                                        .build(),
                                                values("a", "b")))));
    }

    @Test
    public void testRewriteStrategies()
    {
        Function<PlanBuilder, PlanNode> planProvider = p -> {
            VariableReferenceExpression a = p.variable("a");
            VariableReferenceExpression column0 = p.variable("column0", BIGINT);
            return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                    .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                    .source(p.project(
                            assignment(a, p.rowExpression("IF(column0 > 1, column0)")),
                            p.values(column0))));
        };

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "disabled")
                .on(planProvider)
                .doesNotFire();

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(planProvider)
                .matches(
                        aggregation(
                                globalAggregation(),
                                ImmutableMap.of(Optional.of("expr0"), functionCall("sum", ImmutableList.of("a"))),
                                ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                Optional.empty(),
                                AggregationNode.Step.FINAL,
                                filter(
                                        "greater_than",
                                        project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                        .put("a", expression("IF(column0 > 1, column0)"))
                                                        .put("greater_than", expression("column0 > 1"))
                                                        .build(),
                                                values("column0")))));

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "unwrap_if_safe")
                .on(planProvider)
                .matches(
                        aggregation(
                                globalAggregation(),
                                ImmutableMap.of(Optional.of("expr0"), functionCall("sum", ImmutableList.of("a"))),
                                ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                Optional.empty(),
                                AggregationNode.Step.FINAL,
                                filter(
                                        "greater_than",
                                        project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                        .put("a", expression("IF(column0 > 1, column0)"))
                                                        .put("greater_than", expression("column0 > 1"))
                                                        .build(),
                                                values("column0")))));

        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "unwrap_if")
                .on(planProvider)
                .matches(
                        aggregation(
                                globalAggregation(),
                                ImmutableMap.of(Optional.of("expr0"), functionCall("sum", ImmutableList.of("column0_0"))),
                                ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                Optional.empty(),
                                AggregationNode.Step.FINAL,
                                filter(
                                        "greater_than",
                                        project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                        .put("a", expression("IF(column0 > 1, column0)"))
                                                        .put("greater_than", expression("column0 > 1"))
                                                        .put("column0_0", expression("column0"))
                                                        .build(),
                                                values("column0")))));
    }

    @Test
    public void testCast()
    {
        tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, "filter_with_if")
                .on(p -> {
                    VariableReferenceExpression a = p.variable("a");
                    VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                    VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                    return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                            .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                            .source(p.project(
                                    assignment(a, p.rowExpression("CAST(IF(ds > '2021-06-01', column0) AS bigint)")),
                                    p.values(ds, column0))));
                })
                .matches(
                        aggregation(
                                globalAggregation(),
                                ImmutableMap.of(Optional.of("expr0"), functionCall("sum", ImmutableList.of("a"))),
                                ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                Optional.empty(),
                                AggregationNode.Step.FINAL,
                                filter(
                                        "greater_than",
                                        project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                        .put("a", expression("CAST(IF(ds > '2021-06-01', column0) as bigint)"))
                                                        .put("greater_than", expression("ds > '2021-06-01'"))
                                                        .build(),
                                                values("ds", "column0")))));

        for (String strategy : new String[] {"unwrap_if_safe", "unwrap_if"}) {
            tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
                    .setSystemProperty(AGGREGATION_IF_TO_FILTER_REWRITE_STRATEGY, strategy)
                    .on(p -> {
                        VariableReferenceExpression a = p.variable("a");
                        VariableReferenceExpression ds = p.variable("ds", VARCHAR);
                        VariableReferenceExpression column0 = p.variable("column0", BIGINT);
                        return p.aggregation(ap -> ap.globalGrouping().step(AggregationNode.Step.FINAL)
                                .addAggregation(p.variable("expr0"), p.rowExpression("SUM(a)"))
                                .source(p.project(
                                        assignment(a, p.rowExpression("CAST(IF(ds > '2021-06-01', column0) AS bigint)")),
                                        p.values(ds, column0))));
                    })
                    .matches(
                            aggregation(
                                    globalAggregation(),
                                    ImmutableMap.of(Optional.of("expr0"), functionCall("sum", ImmutableList.of("cast"))),
                                    ImmutableMap.of(new Symbol("expr0"), new Symbol("greater_than")),
                                    Optional.empty(),
                                    AggregationNode.Step.FINAL,
                                    filter(
                                            "greater_than",
                                            project(new ImmutableMap.Builder<String, ExpressionMatcher>()
                                                            .put("a", expression("CAST(IF(ds > '2021-06-01', column0) as bigint)"))
                                                            .put("greater_than", expression("ds > '2021-06-01'"))
                                                            .put("cast", expression("CAST(column0 AS bigint)"))
                                                            .build(),
                                                    values("ds", "column0")))));
        }
    }
}