TestJoinNodeFlattener.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.plan.MultiJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.testing.LocalQueryRunner;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.LinkedHashSet;
import java.util.Optional;

import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.spi.plan.JoinType.FULL;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.spi.plan.JoinType.LEFT;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.toMultiJoinNode;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static org.testng.Assert.assertEquals;

/**
 * Tests for {@code ReorderJoins.toMultiJoinNode}, which flattens a tree of inner {@link JoinNode}s
 * (and, when complex equi-joins are handled, deterministic {@link ProjectNode}s between them)
 * into a single {@link MultiJoinNode} whose filter is the conjunction of all join criteria and filters.
 */
public class TestJoinNodeFlattener
{
    private static final int DEFAULT_JOIN_LIMIT = 10;
    private DeterminismEvaluator determinismEvaluator;
    private FunctionResolution functionResolution;

    private LocalQueryRunner queryRunner;
    private FunctionAndTypeResolver functionAndTypeResolver;

    @BeforeClass
    public void setUp()
    {
        queryRunner = new LocalQueryRunner(testSessionBuilder().build());
        determinismEvaluator = new RowExpressionDeterminismEvaluator(queryRunner.getMetadata());
        functionAndTypeResolver = queryRunner.getMetadata().getFunctionAndTypeManager().getFunctionAndTypeResolver();
        functionResolution = new FunctionResolution(functionAndTypeResolver);
    }

    @AfterClass(alwaysRun = true)
    public void tearDown()
    {
        closeAllRuntimeException(queryRunner);
        // The test instance may outlive the class run, so drop the runner and every helper
        // derived from its metadata; otherwise they stay reachable through these fields.
        queryRunner = null;
        functionAndTypeResolver = null;
        functionResolution = null;
        determinismEvaluator = null;
    }

    /**
     * Flattening must reject a root join that is not INNER (here FULL OUTER).
     */
    @Test(expectedExceptions = IllegalStateException.class)
    public void testDoesNotAllowOuterJoin()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        JoinNode outerJoin = p.join(
                FULL,
                p.values(a1),
                p.values(b1),
                ImmutableList.of(equiJoinClause(a1, b1)),
                ImmutableList.of(a1, b1),
                Optional.empty());
        toMultiJoinNode(outerJoin, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator);
    }

    /**
     * A LEFT join nested under an INNER join is kept as an opaque source rather than flattened.
     */
    @Test
    public void testDoesNotConvertNestedOuterJoins()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression c1 = p.variable("C1");
        JoinNode leftJoin = p.join(
                LEFT,
                p.values(a1),
                p.values(b1),
                ImmutableList.of(equiJoinClause(a1, b1)),
                ImmutableList.of(a1, b1),
                Optional.empty());
        ValuesNode valuesC = p.values(c1);
        JoinNode joinNode = p.join(
                INNER,
                leftJoin,
                valuesC,
                ImmutableList.of(equiJoinClause(a1, c1)),
                ImmutableList.of(a1, b1, c1),
                Optional.empty());

        MultiJoinNode expected = MultiJoinNode.builder()
                .setSources(leftJoin, valuesC).setFilter(createEqualsExpression(a1, c1))
                .setOutputVariables(a1, b1, c1)
                .build();
        assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator), expected);
    }

    /**
     * The flattened node exposes only the root join's outputs, not those of the inner joins.
     */
    @Test
    public void testRetainsOutputSymbols()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression b2 = p.variable("B2");
        VariableReferenceExpression c1 = p.variable("C1");
        VariableReferenceExpression c2 = p.variable("C2");
        ValuesNode valuesA = p.values(a1);
        ValuesNode valuesB = p.values(b1, b2);
        ValuesNode valuesC = p.values(c1, c2);
        JoinNode joinNode = p.join(
                INNER,
                valuesA,
                p.join(
                        INNER,
                        valuesB,
                        valuesC,
                        ImmutableList.of(equiJoinClause(b1, c1)),
                        ImmutableList.of(b1, b2, c1, c2),
                        Optional.empty()),
                ImmutableList.of(equiJoinClause(a1, b1)),
                ImmutableList.of(a1, b1),
                Optional.empty());
        MultiJoinNode expected = MultiJoinNode.builder()
                .setSources(valuesA, valuesB, valuesC)
                .setFilter(and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1)))
                .setOutputVariables(a1, b1)
                .build();
        assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator), expected);
    }

    /**
     * Equi-join criteria and extra join filters from every level end up in one conjunctive filter.
     */
    @Test
    public void testCombinesCriteriaAndFilters()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression b2 = p.variable("B2");
        VariableReferenceExpression c1 = p.variable("C1");
        VariableReferenceExpression c2 = p.variable("C2");
        ValuesNode valuesA = p.values(a1);
        ValuesNode valuesB = p.values(b1, b2);
        ValuesNode valuesC = p.values(c1, c2);
        // C2 > 0 AND C2 <> 7 AND B2 > C2
        RowExpression bcFilter = and(
                createComparisonExpression(OperatorType.GREATER_THAN, c2, constant(0L, BIGINT)),
                createComparisonExpression(OperatorType.NOT_EQUAL, c2, constant(7L, BIGINT)),
                createComparisonExpression(OperatorType.GREATER_THAN, b2, c2));
        // A1 + C1 < B1
        RowExpression abcFilter = createComparisonExpression(OperatorType.LESS_THAN, createAddExpression(a1, c1), b1);
        JoinNode joinNode = p.join(
                INNER,
                valuesA,
                p.join(
                        INNER,
                        valuesB,
                        valuesC,
                        ImmutableList.of(equiJoinClause(b1, c1)),
                        ImmutableList.of(b1, b2, c1, c2),
                        Optional.of(bcFilter)),
                ImmutableList.of(equiJoinClause(a1, b1)),
                ImmutableList.of(a1, b1, b2, c1, c2),
                Optional.of(abcFilter));
        MultiJoinNode expected = new MultiJoinNode(
                new LinkedHashSet<>(ImmutableList.of(valuesA, valuesB, valuesC)),
                and(createEqualsExpression(b1, c1), createEqualsExpression(a1, b1), bcFilter, abcFilter),
                ImmutableList.of(a1, b1, b2, c1, c2),
                Assignments.builder().build(),
                false,
                Optional.empty());
        assertEquals(toMultiJoinNode(joinNode, noLookup(), DEFAULT_JOIN_LIMIT, false, functionResolution, determinismEvaluator), expected);
    }

    /**
     * A bushy tree ((A join B) join C) join (D join E) flattens to all five leaf sources
     * when the join limit (5) admits them all.
     */
    @Test
    public void testConvertsBushyTrees()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression c1 = p.variable("C1");
        VariableReferenceExpression d1 = p.variable("D1");
        VariableReferenceExpression d2 = p.variable("D2");
        VariableReferenceExpression e1 = p.variable("E1");
        VariableReferenceExpression e2 = p.variable("E2");
        ValuesNode valuesA = p.values(a1);
        ValuesNode valuesB = p.values(b1);
        ValuesNode valuesC = p.values(c1);
        ValuesNode valuesD = p.values(d1, d2);
        ValuesNode valuesE = p.values(e1, e2);
        JoinNode joinNode = p.join(
                INNER,
                p.join(
                        INNER,
                        p.join(
                                INNER,
                                valuesA,
                                valuesB,
                                ImmutableList.of(equiJoinClause(a1, b1)),
                                ImmutableList.of(a1, b1),
                                Optional.empty()),
                        valuesC,
                        ImmutableList.of(equiJoinClause(a1, c1)),
                        ImmutableList.of(a1, b1, c1),
                        Optional.empty()),
                p.join(
                        INNER,
                        valuesD,
                        valuesE,
                        ImmutableList.of(
                                equiJoinClause(d1, e1),
                                equiJoinClause(d2, e2)),
                        ImmutableList.of(d1, d2, e1, e2),
                        Optional.empty()),
                ImmutableList.of(equiJoinClause(b1, e1)),
                ImmutableList.of(a1, b1, c1, d1, d2, e1, e2),
                Optional.empty());
        MultiJoinNode expected = MultiJoinNode.builder()
                .setSources(valuesA, valuesB, valuesC, valuesD, valuesE)
                .setFilter(and(createEqualsExpression(a1, b1), createEqualsExpression(a1, c1), createEqualsExpression(d1, e1), createEqualsExpression(d2, e2), createEqualsExpression(b1, e1)))
                .setOutputVariables(a1, b1, c1, d1, d2, e1, e2)
                .build();
        assertEquals(toMultiJoinNode(joinNode, noLookup(), 5, false, functionResolution, determinismEvaluator), expected);
    }

    /**
     * With a join limit of 2, the innermost joins (join1, join2) are kept as whole sources
     * instead of being expanded into their leaves.
     */
    @Test
    public void testMoreThanJoinLimit()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression c1 = p.variable("C1");
        VariableReferenceExpression d1 = p.variable("D1");
        VariableReferenceExpression d2 = p.variable("D2");
        VariableReferenceExpression e1 = p.variable("E1");
        VariableReferenceExpression e2 = p.variable("E2");
        ValuesNode valuesA = p.values(a1);
        ValuesNode valuesB = p.values(b1);
        ValuesNode valuesC = p.values(c1);
        ValuesNode valuesD = p.values(d1, d2);
        ValuesNode valuesE = p.values(e1, e2);
        JoinNode join1 = p.join(
                INNER,
                valuesA,
                valuesB,
                ImmutableList.of(equiJoinClause(a1, b1)),
                ImmutableList.of(a1, b1),
                Optional.empty());
        JoinNode join2 = p.join(
                INNER,
                valuesD,
                valuesE,
                ImmutableList.of(
                        equiJoinClause(d1, e1),
                        equiJoinClause(d2, e2)),
                ImmutableList.of(d1, d2, e1, e2),
                Optional.empty());
        JoinNode joinNode = p.join(
                INNER,
                p.join(
                        INNER,
                        join1,
                        valuesC,
                        ImmutableList.of(equiJoinClause(a1, c1)),
                        ImmutableList.of(a1, b1, c1),
                        Optional.empty()),
                join2,
                ImmutableList.of(equiJoinClause(b1, e1)),
                ImmutableList.of(a1, b1, c1, d1, d2, e1, e2),
                Optional.empty());
        MultiJoinNode expected = MultiJoinNode.builder()
                .setSources(join1, join2, valuesC)
                .setFilter(and(createEqualsExpression(a1, c1), createEqualsExpression(b1, e1)))
                .setOutputVariables(a1, b1, c1, d1, d2, e1, e2)
                .build();
        assertEquals(toMultiJoinNode(joinNode, noLookup(), 2, true, functionResolution, determinismEvaluator), expected);
    }

    /**
     * With handleComplexEquiJoins enabled, deterministic projections between joins are inlined,
     * so criteria over projected symbols are rewritten in terms of the leaf symbols. With it
     * disabled, the projection stops the flattening.
     */
    @Test
    public void testProjectNodesBetweenJoinNodesAreFlattenedForComplexEquiJoins()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression c1 = p.variable("C1");
        VariableReferenceExpression d1 = p.variable("D1");
        VariableReferenceExpression e1 = p.variable("E1");
        VariableReferenceExpression sum = p.variable("SUM");
        VariableReferenceExpression rename = p.variable("RENAME");
        VariableReferenceExpression renamePlusSum = p.variable("RENAME_PLUS_SUM");

        ValuesNode valuesA = p.values(a1);
        ValuesNode valuesB = p.values(b1);
        ValuesNode valuesC = p.values(c1);
        ValuesNode valuesD = p.values(d1);
        ValuesNode valuesE = p.values(e1);
        Assignments assignmentA1PlusB1 = Assignments.builder().put(sum, createAddExpression(a1, b1)).build();
        Assignments assignmentRenameC1 = Assignments.builder().put(rename, c1).build();
        Assignments assignmentRenamePlusSum = Assignments.builder().put(renamePlusSum, createAddExpression(rename, sum)).build();

        ProjectNode projectOverJoin3 = p.project(assignmentRenamePlusSum, p.join(
                INNER,
                // projectOverJoin1: SUM := A1 + B1 over (A join B)
                p.project(assignmentA1PlusB1, p.join(
                        INNER,
                        valuesA,
                        valuesB,
                        ImmutableList.of(equiJoinClause(a1, b1)),
                        ImmutableList.of(a1, b1),
                        Optional.empty())),
                // projectOverJoin2: RENAME := C1 over (C join D)
                p.project(assignmentRenameC1, p.join(
                        INNER,
                        valuesC,
                        valuesD,
                        ImmutableList.of(equiJoinClause(c1, d1)),
                        ImmutableList.of(c1),
                        Optional.empty())),
                ImmutableList.of(equiJoinClause(sum, rename)),
                ImmutableList.of(sum, rename),
                Optional.empty()));

        JoinNode topMostJoinNode = p.join(
                INNER,
                valuesE,
                projectOverJoin3,
                ImmutableList.of(equiJoinClause(e1, renamePlusSum)),
                ImmutableList.of(e1, renamePlusSum),
                Optional.empty());

        MultiJoinNode expected = MultiJoinNode.builder()
                .setSources(valuesA, valuesB, valuesC, valuesD, valuesE)
                .setFilter(and(createEqualsExpression(a1, b1),
                        createEqualsExpression(c1, d1),
                        createEqualsExpression(createAddExpression(a1, b1), c1),
                        createEqualsExpression(e1, createAddExpression(c1, createAddExpression(a1, b1)))))
                .setAssignments(Assignments.of(e1, e1, renamePlusSum, createAddExpression(c1, createAddExpression(a1, b1))))
                .setOutputVariables(e1, c1, a1, b1)
                .build();
        MultiJoinNode actual = toMultiJoinNode(topMostJoinNode, noLookup(), 5, /*handleComplexEquiJoins*/ true, functionResolution, determinismEvaluator);
        assertEquals(actual, expected);

        // Negative test - when handleComplexEquiJoins = false, we have a split join space; the ProjectNodes are not flattened
        expected = MultiJoinNode.builder()
                .setSources(valuesE, projectOverJoin3)
                .setFilter(createEqualsExpression(e1, renamePlusSum))
                .setAssignments(Assignments.of())
                .setOutputVariables(e1, renamePlusSum)
                .build();

        assertEquals(toMultiJoinNode(topMostJoinNode, noLookup(), 5, /*handleComplexEquiJoins*/ false, functionResolution, determinismEvaluator), expected);
    }

    /**
     * A projection whose assignment calls random() is kept as an opaque source even when
     * handleComplexEquiJoins is enabled.
     */
    @Test
    public void testProjectNodesWithNonDeterministicAssignmentsAreNotFlattenedForComplexEquiJoins()
    {
        PlanBuilder p = planBuilder();
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        VariableReferenceExpression c1 = p.variable("C1");
        VariableReferenceExpression randomPlusSum = p.variable("RANDOM_PLUS_SUM");
        Assignments nonDeterministicAssignment = Assignments.builder().put(randomPlusSum, createAddExpression(createRandomExpression(), createAddExpression(a1, b1))).build();

        ValuesNode valuesA = p.values(a1);
        ValuesNode valuesB = p.values(b1);
        ValuesNode valuesC = p.values(c1);

        ProjectNode projectWithNonDeterministicAssignment = p.project(nonDeterministicAssignment, p.join(
                INNER,
                valuesA,
                valuesB,
                ImmutableList.of(equiJoinClause(a1, b1)),
                ImmutableList.of(a1, b1),
                Optional.empty()));

        JoinNode joinNodeToFlatten = p.join(
                INNER,
                projectWithNonDeterministicAssignment,
                valuesC,
                ImmutableList.of(equiJoinClause(randomPlusSum, c1)),
                ImmutableList.of(),
                Optional.empty());

        MultiJoinNode expected = MultiJoinNode.builder()
                .setSources(projectWithNonDeterministicAssignment, valuesC)
                .setFilter(createEqualsExpression(randomPlusSum, c1))
                .setAssignments(Assignments.of())
                .setOutputVariables()
                .build();

        assertEquals(toMultiJoinNode(joinNodeToFlatten, noLookup(), 5, /*handleComplexEquiJoins*/ true, functionResolution, determinismEvaluator), expected);
    }

    /**
     * Builds a zero-argument call to {@code random()} returning DOUBLE.
     */
    private CallExpression createRandomExpression()
    {
        return call("random", functionAndTypeResolver.lookupFunction("random", fromTypes()), DOUBLE);
    }

    /**
     * Builds a BOOLEAN comparison {@code left <operator> right}, resolving the comparison
     * function from the operands' types.
     */
    private RowExpression createComparisonExpression(OperatorType operator, RowExpression left, RowExpression right)
    {
        return call(
                operator.name(),
                functionResolution.comparisonFunction(operator, left.getType(), right.getType()),
                BOOLEAN,
                ImmutableList.of(left, right));
    }

    /**
     * Builds {@code left = right}.
     */
    private RowExpression createEqualsExpression(RowExpression left, RowExpression right)
    {
        return createComparisonExpression(OperatorType.EQUAL, left, right);
    }

    /**
     * Builds {@code left + right}. The result type is fixed to BIGINT, which matches the
     * BIGINT symbols created by {@code PlanBuilder.variable(String)} in these tests.
     */
    private RowExpression createAddExpression(RowExpression left, RowExpression right)
    {
        return call(
                OperatorType.ADD.name(),
                functionResolution.arithmeticFunction(OperatorType.ADD, left.getType(), right.getType()),
                BIGINT,
                ImmutableList.of(left, right));
    }

    private EquiJoinClause equiJoinClause(VariableReferenceExpression variable1, VariableReferenceExpression variable2)
    {
        return new EquiJoinClause(variable1, variable2);
    }

    private PlanBuilder planBuilder()
    {
        return new PlanBuilder(TEST_SESSION, new PlanNodeIdAllocator(), queryRunner.getMetadata());
    }
}