TestLeftJoinNullFilterToSemiJoin.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.testng.annotations.Test;

import java.util.Optional;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values;

/**
 * Rule tests for {@link LeftJoinNullFilterToSemiJoin}.
 *
 * <p>The firing cases check that a filter containing {@code right_k1 is null} on top of a LEFT join
 * (equi-joined on {@code left_k1 = right_k1}) is rewritten into a semi join. In that rewrite:
 * <ul>
 * <li>the right side is grouped on {@code right_k1};</li>
 * <li>the semi join's match flag is used in {@code not(COALESCE(semijoinoutput, false))};</li>
 * <li>everything is wrapped in a projection.</li>
 * </ul>
 *
 * <p>The non-firing cases cover the following shapes:
 * <ul>
 * <li>a join with an extra join filter;</li>
 * <li>an {@code is not null} predicate;</li>
 * <li>a right side that produces an additional column;</li>
 * <li>filters that use the right key beyond the null check;</li>
 * <li>a filter that ORs the null check with another condition.</li>
 * </ul>
 */
public class TestLeftJoinNullFilterToSemiJoin
        extends BaseRuleTest
{
    @Test
    public void testTrigger()
    {
        // A bare "right key is null" filter over an unfiltered LEFT join is rewritten.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .matches(
                        project(
                                filter(
                                        "not(COALESCE(semijoinoutput, false))",
                                        semiJoin(
                                                "left_k1",
                                                "right_k1",
                                                "semijoinoutput",
                                                values("left_k1", "left_k2"),
                                                aggregation(
                                                        singleGroupingSet(ImmutableList.of("right_k1")),
                                                        ImmutableMap.of(),
                                                        ImmutableMap.of(),
                                                        Optional.empty(),
                                                        AggregationNode.Step.SINGLE,
                                                        values("right_k1"))))));
    }

    @Test
    public void testNotTriggerWithFilter()
    {
        // The LEFT join carries its own join filter ("left_k2 > 10"), so the rule must not fire.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    builder.rowExpression("left_k2 > 10"),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .doesNotFire();
    }

    @Test
    public void testNotTriggerNotNull()
    {
        // "is not null" keeps matched rows rather than unmatched ones, so the rule must not fire.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is not null"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .doesNotFire();
    }

    @Test
    public void testNotTriggerOtherOutputUsed()
    {
        // The right side also produces right_k2 in addition to the join key, so the rule must not fire.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    builder.variable("right_k2", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1"), builder.variable("right_k2")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .doesNotFire();
    }

    @Test
    public void testNotTriggerOutputUsedInFilter()
    {
        // The right key is referenced again in an OR branch, so the rule must not fire.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null or right_k1 > 2"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .doesNotFire();
    }

    @Test
    public void testNotTriggerOutputUsedInFilter2()
    {
        // The right key is referenced again in an AND conjunct, so the rule must not fire.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null and right_k1 > 2"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .doesNotFire();
    }

    @Test
    public void testNotTriggerOtherOutputUsedInFilter()
    {
        // The null check is OR-ed with a left-side condition, so it is not a required conjunct.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null or left_k2 > 2"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .doesNotFire();
    }

    @Test
    public void testTriggerForFilterWithAnd()
    {
        // The null check is AND-ed with a left-only predicate. The rule fires, and the leftover
        // conjunct "left_k2 > 2" remains as a filter above the rewritten plan.
        tester().assertThat(newRule())
                .on(builder ->
                {
                    builder.variable("left_k1", BIGINT);
                    builder.variable("left_k2", BIGINT);
                    builder.variable("right_k1", BIGINT);
                    return builder.filter(
                            builder.rowExpression("right_k1 is null and left_k2 > 2"),
                            builder.join(
                                    JoinType.LEFT,
                                    builder.values(builder.variable("left_k1"), builder.variable("left_k2")),
                                    builder.values(builder.variable("right_k1")),
                                    new EquiJoinClause(builder.variable("left_k1"), builder.variable("right_k1"))));
                })
                .matches(
                        filter(
                                "left_k2 > 2",
                                project(
                                        filter(
                                                "not(COALESCE(semijoinoutput, false))",
                                                semiJoin(
                                                        "left_k1",
                                                        "right_k1",
                                                        "semijoinoutput",
                                                        values("left_k1", "left_k2"),
                                                        aggregation(
                                                                singleGroupingSet(ImmutableList.of("right_k1")),
                                                                ImmutableMap.of(),
                                                                ImmutableMap.of(),
                                                                Optional.empty(),
                                                                AggregationNode.Step.SINGLE,
                                                                values("right_k1")))))));
    }

    /**
     * Creates a fresh instance of the rule under test.
     *
     * <p>The instance is backed by the test metadata's function and type manager.
     */
    private LeftJoinNullFilterToSemiJoin newRule()
    {
        return new LeftJoinNullFilterToSemiJoin(getMetadata().getFunctionAndTypeManager());
    }
}