TestJoinEnumerator.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.CachingCostProvider;
import com.facebook.presto.cost.CachingStatsProvider;
import com.facebook.presto.cost.CostComparator;
import com.facebook.presto.cost.CostProvider;
import com.facebook.presto.cost.PlanCostEstimate;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.LogicalPropertiesProvider;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.relation.DeterminismEvaluator;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.TestingRowExpressionTranslator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.assertions.RowExpressionVerifier;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.Lookup;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerationResult;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator;
import com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.JoinCondition;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.planner.plan.MultiJoinNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.facebook.presto.sql.tree.SymbolReference;
import com.facebook.presto.testing.LocalQueryRunner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.airlift.testing.Closeables.closeAllRuntimeException;
import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression;
import static com.facebook.presto.sql.planner.iterative.Lookup.noLookup;
import static com.facebook.presto.sql.planner.iterative.rule.ReorderJoins.JoinEnumerator.generatePartitions;
import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.expression;
import static com.facebook.presto.sql.planner.optimizations.JoinNodeUtils.toRowExpression;
import static com.facebook.presto.sql.relational.Expressions.variable;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;

/**
 * Tests for {@link JoinEnumerator}: partition generation, refusal to build a join
 * for a partitioning that would be a cross join, and the split of join predicates
 * into equi-join clauses and residual join filters.
 */
public class TestJoinEnumerator
{
    // Shared fixtures, created once per class in setUp() and released in tearDown().
    private LocalQueryRunner queryRunner;
    private Metadata metadata;
    private DeterminismEvaluator determinismEvaluator;
    private FunctionResolution functionResolution;
    private TestingRowExpressionTranslator rowExpressionTranslator;
    private Session session;

    @BeforeClass
    public void setUp()
    {
        session = testSessionBuilder().build();
        queryRunner = new LocalQueryRunner(session);
        metadata = queryRunner.getMetadata();
        determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata);
        functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
        rowExpressionTranslator = new TestingRowExpressionTranslator(metadata);
    }

    @AfterClass(alwaysRun = true)
    public void tearDown()
    {
        try {
            closeAllRuntimeException(queryRunner);
        }
        finally {
            // Drop every fixture (not just the runner) so nothing derived from the
            // closed query runner stays reachable through this test instance,
            // even when closing fails.
            queryRunner = null;
            metadata = null;
            determinismEvaluator = null;
            functionResolution = null;
            rowExpressionTranslator = null;
            session = null;
        }
    }

    /**
     * Checks {@code generatePartitions(n)} for n = 4 and n = 3. The expected sets listed
     * below are exactly the proper subsets of {0, ..., n-1} that contain index 0.
     */
    @Test
    public void testGeneratePartitions()
    {
        assertEquals(generatePartitions(4),
                ImmutableSet.of(
                        ImmutableSet.of(0),
                        ImmutableSet.of(0, 1),
                        ImmutableSet.of(0, 2),
                        ImmutableSet.of(0, 3),
                        ImmutableSet.of(0, 1, 2),
                        ImmutableSet.of(0, 1, 3),
                        ImmutableSet.of(0, 2, 3)));

        assertEquals(generatePartitions(3),
                ImmutableSet.of(
                        ImmutableSet.of(0),
                        ImmutableSet.of(0, 1),
                        ImmutableSet.of(0, 2)));
    }

    /**
     * Two sources with a TRUE filter share no join predicate, so splitting them into
     * separate partitions must yield no plan node and an infinite cost estimate.
     */
    @Test
    public void testDoesNotCreateJoinWhenPartitionedOnCrossJoin()
    {
        PlanNodeIdAllocator idAllocator = new PlanNodeIdAllocator();
        PlanBuilder p = new PlanBuilder(TEST_SESSION, idAllocator, queryRunner.getMetadata());
        VariableReferenceExpression a1 = p.variable("A1");
        VariableReferenceExpression b1 = p.variable("B1");
        MultiJoinNode multiJoinNode = new MultiJoinNode(
                new LinkedHashSet<>(ImmutableList.of(p.values(a1), p.values(b1))),
                TRUE_CONSTANT,
                ImmutableList.of(a1, b1),
                Assignments.of(),
                false,
                Optional.empty());
        JoinEnumerator joinEnumerator = createJoinEnumerator(multiJoinNode.getFilter());

        JoinEnumerationResult actual = joinEnumerator.createJoinAccordingToPartitioning(
                multiJoinNode.getSources(),
                multiJoinNode.getOutputVariables(),
                ImmutableSet.of(0));

        assertFalse(actual.getPlanNode().isPresent());
        assertEquals(actual.getCost(), PlanCostEstimate.infinite());
    }

    /**
     * Table-style test of {@code extractJoinConditions}: each case supplies predicates plus
     * the left/right variable sets and states which part is expected to come back as an
     * equi-join clause and which part as a join filter ({@code null} meaning "none").
     */
    @Test
    public void testJoinClauseAndFilterInference()
    {
        ImmutableMap.Builder<String, Type> builder = ImmutableMap.builder();
        builder.put("a", BIGINT);
        builder.put("b", BIGINT);
        builder.put("c", BIGINT);
        builder.put("d", BIGINT);
        Map<String, Type> variableMap = builder.build();
        VariableReferenceExpression a = variable("a", variableMap.get("a"));
        VariableReferenceExpression b = variable("b", variableMap.get("b"));
        VariableReferenceExpression c = variable("c", variableMap.get("c"));
        VariableReferenceExpression d = variable("d", variableMap.get("d"));

        // Expected expressions are written with upper-case aliases that map onto the lower-case variables
        SymbolAliases.Builder newAliases = SymbolAliases.builder();
        newAliases.put("A", new SymbolReference("a"));
        newAliases.put("B", new SymbolReference("b"));
        newAliases.put("C", new SymbolReference("c"));
        newAliases.put("D", new SymbolReference("d"));
        SymbolAliases symbolAliases = newAliases.build();

        // Simple join predicates on variable references
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B", null);
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b", "c = d"), ImmutableSet.of(a, c), ImmutableSet.of(b, d), "A = B AND C = D", null);
        // Complex join predicate - each operand of the equality must reference variables from a single join side
        // for the predicate to be an equi-join clause
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C", null);
        // Left and right side designation can be switched
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C", null);
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(a), ImmutableSet.of(b, c), "A = B + C + 1", null);
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = b + c + 1"), ImmutableSet.of(b, c), ImmutableSet.of(a), "A = B + C + 1", null);
        // If one operand of a predicate mixes variables from the left & right sides, the predicate is treated as a filter
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = c"), ImmutableSet.of(a), ImmutableSet.of(b, c), null, "A + B = C");
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a + b = 1"), ImmutableSet.of(a), ImmutableSet.of(b), null, "A + B = 1");
        // Test with multiple equi-join conditions and filters
        assertJoinCondition(symbolAliases, toRowExpressionList(variableMap, "a = ABS(b)", "a = ceil(b-c)", "b = c + 10"),
                ImmutableSet.of(a), ImmutableSet.of(b, c), "A = abs(B) AND A = ceil(B-C)", "B = C + 10");
    }

    /**
     * Translates each SQL predicate string into a {@link RowExpression}, resolving variable
     * names against {@code variableTypeMap}. Order of the predicates is preserved.
     */
    private List<RowExpression> toRowExpressionList(Map<String, Type> variableTypeMap, String... predicates)
    {
        return Arrays.stream(predicates)
                .map(predicate -> rowExpressionTranslator.translate(predicate, variableTypeMap))
                .collect(Collectors.toList());
    }

    /**
     * Runs {@code extractJoinConditions} on {@code joinPredicates} and verifies both halves of the result.
     *
     * @param symbolAliases aliases used to resolve names in the expected expressions
     * @param joinPredicates predicates handed to the enumerator
     * @param leftVariables variables considered to belong to the left join side
     * @param rightVariables variables considered to belong to the right join side
     * @param expectedEquiJoinClause expected conjunction of equi-join clauses, or {@code null} if none is expected
     * @param expectedJoinFilter expected conjunction of join filters, or {@code null} if none is expected
     */
    private void assertJoinCondition(SymbolAliases symbolAliases, List<RowExpression> joinPredicates, Set<VariableReferenceExpression> leftVariables,
            Set<VariableReferenceExpression> rightVariables, String expectedEquiJoinClause, String expectedJoinFilter)
    {
        RowExpressionVerifier verifier = new RowExpressionVerifier(symbolAliases, metadata, session);
        JoinEnumerator joinEnumerator = createJoinEnumerator(TRUE_CONSTANT);

        JoinCondition joinConditions = joinEnumerator.extractJoinConditions(joinPredicates,
                leftVariables, rightVariables, new VariableAllocator());

        Optional<RowExpression> equiJoinExpressionInlined = joinConditions.getJoinClauses().stream()
                .map(criteria -> toRowExpression(criteria, functionResolution))
                // We may have made left or right assignments to build the equi join clause
                // We inline these assignments for building the equi join clause to verify
                .map(clause -> replaceExpression(clause, joinConditions.getNewLeftAssignments()))
                .map(clause -> replaceExpression(clause, joinConditions.getNewRightAssignments()))
                .reduce(LogicalRowExpressions::and);
        assertMatchesExpected(verifier, "equi-join clause", expectedEquiJoinClause, equiJoinExpressionInlined);

        Optional<RowExpression> joinFilter = joinConditions.getJoinFilters().stream()
                .reduce(LogicalRowExpressions::and);
        assertMatchesExpected(verifier, "join filter", expectedJoinFilter, joinFilter);
    }

    /**
     * Asserts that {@code actual} is absent exactly when {@code expected} is {@code null}, and
     * otherwise that the verifier accepts {@code actual} as matching the parsed {@code expected}.
     * {@code description} only labels the failure message.
     */
    private static void assertMatchesExpected(RowExpressionVerifier verifier, String description, String expected, Optional<RowExpression> actual)
    {
        if (!actual.isPresent()) {
            assertNull(expected, "Expected " + description + " [" + expected + "] but none was extracted");
            return;
        }
        assertNotNull(expected, "Expected no " + description + " but extracted [" + actual.get() + "]");
        assertTrue(
                verifier.process(expression(expected), actual.get()),
                "Extracted " + description + " [" + actual.get() + "] does not match expected [" + expected + "]");
    }

    /**
     * Builds a {@link JoinEnumerator} over the shared fixtures with a fresh rule context and a
     * {@code CostComparator(1, 1, 1)}; only the filter differs between the tests.
     */
    private JoinEnumerator createJoinEnumerator(RowExpression filter)
    {
        return new JoinEnumerator(
                new CostComparator(1, 1, 1),
                filter,
                createContext(),
                determinismEvaluator,
                functionResolution,
                metadata);
    }

    /**
     * Creates a minimal {@link Rule.Context} backed by the query runner: fresh id and variable
     * allocators, caching stats/cost providers without a memo, no lookup, a no-op warning
     * collector, no logical properties provider and a timeout check that does nothing.
     */
    private Rule.Context createContext()
    {
        PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator();
        VariableAllocator variableAllocator = new VariableAllocator();
        CachingStatsProvider statsProvider = new CachingStatsProvider(
                queryRunner.getStatsCalculator(),
                Optional.empty(),
                noLookup(),
                queryRunner.getDefaultSession(),
                TypeProvider.viewOf(variableAllocator.getVariables()));
        CachingCostProvider costProvider = new CachingCostProvider(
                queryRunner.getCostCalculator(),
                statsProvider,
                Optional.empty(),
                queryRunner.getDefaultSession());

        return new Rule.Context()
        {
            @Override
            public Lookup getLookup()
            {
                return noLookup();
            }

            @Override
            public PlanNodeIdAllocator getIdAllocator()
            {
                return planNodeIdAllocator;
            }

            @Override
            public VariableAllocator getVariableAllocator()
            {
                return variableAllocator;
            }

            @Override
            public Session getSession()
            {
                return queryRunner.getDefaultSession();
            }

            @Override
            public StatsProvider getStatsProvider()
            {
                return statsProvider;
            }

            @Override
            public CostProvider getCostProvider()
            {
                return costProvider;
            }

            @Override
            public void checkTimeoutNotExhausted() {}

            @Override
            public WarningCollector getWarningCollector()
            {
                return WarningCollector.NOOP;
            }

            @Override
            public Optional<LogicalPropertiesProvider> getLogicalPropertiesProvider()
            {
                return Optional.empty();
            }
        };
    }
}