TestDetermineRemotePartitionedExchangeEncoding.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.ExchangeEncoding;
import com.facebook.presto.spi.plan.Partitioning;
import com.facebook.presto.spi.plan.PartitioningHandle;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.assertions.MatchResult;
import com.facebook.presto.sql.planner.assertions.Matcher;
import com.facebook.presto.sql.planner.assertions.PlanMatchPattern;
import com.facebook.presto.sql.planner.assertions.SymbolAliases;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleAssert;
import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Optional;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.NATIVE_MIN_COLUMNAR_ENCODING_CHANNELS_TO_PREFER_ROW_WISE_ENCODING;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DecimalType.createDecimalType;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.spi.plan.ExchangeEncoding.ROW_WISE;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION;
import static com.facebook.presto.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH;
import static com.facebook.presto.sql.planner.assertions.MatchResult.match;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineRemotePartitionedExchangeEncoding.estimateNumberOfColumnarChannels;
import static com.facebook.presto.sql.planner.iterative.rule.DetermineRemotePartitionedExchangeEncoding.estimateNumberOfOutputColumnarChannels;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.REMOTE_STREAMING;
import static com.facebook.presto.sql.planner.plan.ExchangeNode.partitionedExchange;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
import static org.testng.Assert.assertEquals;

/**
 * Tests for the {@link DetermineRemotePartitionedExchangeEncoding} rule and for its
 * columnar channel estimation helpers.
 * <p>
 * The rule is exercised in three configurations, selected by the two boolean constructor
 * arguments passed in {@link #createAssert(boolean, boolean)}: Presto on Spark, native
 * execution, and plain Presto (both flags off).
 */
public class TestDetermineRemotePartitionedExchangeEncoding
{
    // Value used for NATIVE_MIN_COLUMNAR_ENCODING_CHANNELS_TO_PREFER_ROW_WISE_ENCODING in every assertion
    private static final int MIN_COLUMNAR_STREAMS = 100;

    private RuleTester tester;

    @BeforeClass
    public void setUp()
    {
        tester = new RuleTester();
    }

    /**
     * Closes the shared {@link RuleTester}.
     * <p>
     * Because of {@code alwaysRun = true} this method is invoked even when {@link #setUp()}
     * failed, in which case {@code tester} was never assigned. The null check keeps the
     * original set-up failure from being accompanied by a misleading NullPointerException.
     */
    @AfterClass(alwaysRun = true)
    public void tearDown()
    {
        RuleTester current = tester;
        // clear the field first so that a failing close() does not leave a stale reference behind
        tester = null;
        if (current != null) {
            current.close();
        }
    }

    @Test
    public void testPrestoOnSpark()
    {
        // special exchanges are always columnar
        assertForPrestoOnSpark()
                .on(p -> createExchange(FIXED_ARBITRARY_DISTRIBUTION, MIN_COLUMNAR_STREAMS))
                .doesNotFire();
        // do not fire twice
        assertForPrestoOnSpark()
                .on(p -> createExchange(FIXED_HASH_DISTRIBUTION, MIN_COLUMNAR_STREAMS).withRowWiseEncoding())
                .doesNotFire();
        // hash based exchanges are always row wise in Presto on Spark,
        // even when the number of columnar streams is below the threshold
        assertForPrestoOnSpark()
                .on(p -> createExchange(FIXED_HASH_DISTRIBUTION, MIN_COLUMNAR_STREAMS - 1))
                .matches(exchangeEncoding(ROW_WISE));
    }

    @Test
    public void testPresto()
    {
        // exchanges are always columnar in Presto
        assertForPresto()
                .on(p -> createExchange(FIXED_ARBITRARY_DISTRIBUTION, MIN_COLUMNAR_STREAMS))
                .doesNotFire();
        assertForPresto()
                .on(p -> createExchange(FIXED_HASH_DISTRIBUTION, MIN_COLUMNAR_STREAMS - 1))
                .doesNotFire();
        assertForPresto()
                .on(p -> createExchange(FIXED_HASH_DISTRIBUTION, MIN_COLUMNAR_STREAMS + 1))
                .doesNotFire();
    }

    @Test
    public void testNative()
    {
        // special exchanges are always columnar
        assertForNative()
                .on(p -> createExchange(FIXED_ARBITRARY_DISTRIBUTION, MIN_COLUMNAR_STREAMS))
                .doesNotFire();
        // hash based exchange with the total number of output columnar streams lower than threshold is columnar
        assertForNative()
                .on(p -> createExchange(FIXED_HASH_DISTRIBUTION, MIN_COLUMNAR_STREAMS - 1))
                .doesNotFire();
        // otherwise row wise
        assertForNative()
                .on(p -> createExchange(FIXED_HASH_DISTRIBUTION, MIN_COLUMNAR_STREAMS))
                .matches(exchangeEncoding(ROW_WISE));
    }

    /** Rule configured with {@code nativeExecution = false, prestoSparkExecutionEnvironment = true}. */
    private RuleAssert assertForPrestoOnSpark()
    {
        return createAssert(false, true);
    }

    /** Rule configured with {@code nativeExecution = true, prestoSparkExecutionEnvironment = false}. */
    private RuleAssert assertForNative()
    {
        return createAssert(true, false);
    }

    /** Rule configured with both flags off. */
    private RuleAssert assertForPresto()
    {
        return createAssert(false, false);
    }

    /**
     * Creates a rule assertion for the given configuration with the row-wise threshold
     * session property set to {@link #MIN_COLUMNAR_STREAMS}.
     */
    private RuleAssert createAssert(boolean nativeExecution, boolean prestoSparkExecutionEnvironment)
    {
        return tester.assertThat(new DetermineRemotePartitionedExchangeEncoding(nativeExecution, prestoSparkExecutionEnvironment))
                .setSystemProperty(NATIVE_MIN_COLUMNAR_ENCODING_CHANNELS_TO_PREFER_ROW_WISE_ENCODING, MIN_COLUMNAR_STREAMS + "");
    }

    /**
     * Builds an exchange whose input and output consist of
     * {@code numberOfOutputColumnarStreams / 2} BIGINT columns.
     * <p>
     * The division is an integer division, so an odd request is rounded down to the nearest
     * even number of streams (e.g. 99 yields 98 and 101 yields 100). The assertion below pins
     * the estimate actually produced for the node, i.e. two channels per BIGINT column.
     *
     * @param handle partitioning handle of the exchange
     * @param numberOfOutputColumnarStreams requested number of output columnar streams
     * @return the exchange node
     */
    private static ExchangeNode createExchange(PartitioningHandle handle, int numberOfOutputColumnarStreams)
    {
        int numberOfBigintColumns = numberOfOutputColumnarStreams / 2;
        List<Type> types = IntStream.range(0, numberOfBigintColumns)
                .mapToObj(i -> BIGINT)
                .collect(toImmutableList());
        ExchangeNode exchangeNode = createExchangeNode(handle, types, types);
        assertEquals(estimateNumberOfOutputColumnarChannels(exchangeNode), numberOfBigintColumns * 2);
        return exchangeNode;
    }

    /** Pattern matching an {@link ExchangeNode} over a {@link ValuesNode} that carries the given encoding. */
    private static PlanMatchPattern exchangeEncoding(ExchangeEncoding encoding)
    {
        return node(ExchangeNode.class, node(ValuesNode.class)).with(new ExchangeEncodingMatcher(encoding));
    }

    @Test
    public void testEstimateNumberOfOutputColumnarChannels()
    {
        assertEquals(estimateNumberOfOutputColumnarChannels(createExchangeNode(FIXED_HASH_DISTRIBUTION, ImmutableList.of(BIGINT), ImmutableList.of(BIGINT))), 2);
        // the extra VARCHAR input column is not part of the output layout and does not change the expected estimate
        assertEquals(estimateNumberOfOutputColumnarChannels(createExchangeNode(FIXED_HASH_DISTRIBUTION, ImmutableList.of(BIGINT, VARCHAR), ImmutableList.of(BIGINT))), 2);
    }

    /**
     * Creates a {@code REMOTE_STREAMING} partitioned exchange on top of a row-less
     * {@link ValuesNode}.
     *
     * @param handle partitioning handle; the partitioning is created with no arguments
     * @param inputTypes types of the variables output by the {@link ValuesNode} source
     * @param outputTypes types of the variables in the exchange's output layout
     */
    private static ExchangeNode createExchangeNode(PartitioningHandle handle, List<Type> inputTypes, List<Type> outputTypes)
    {
        return partitionedExchange(
                new PlanNodeId("exchange"),
                REMOTE_STREAMING,
                new ValuesNode(
                        Optional.empty(),
                        new PlanNodeId("values"),
                        createExpressions(inputTypes),
                        ImmutableList.of(),
                        Optional.empty()),
                new PartitioningScheme(
                        Partitioning.create(handle, ImmutableList.of()),
                        createExpressions(outputTypes)));
    }

    /** Creates one variable per type, named {@code exp_0}, {@code exp_1}, ... in list order. */
    private static List<VariableReferenceExpression> createExpressions(List<Type> types)
    {
        ImmutableList.Builder<VariableReferenceExpression> result = ImmutableList.builder();
        for (int i = 0; i < types.size(); i++) {
            result.add(new VariableReferenceExpression(Optional.empty(), "exp_" + i, types.get(i)));
        }
        return result.build();
    }

    @Test
    public void testEstimateNumberOfColumnarChannels()
    {
        assertEquals(estimateNumberOfColumnarChannels(BIGINT), 2);
        assertEquals(estimateNumberOfColumnarChannels(REAL), 2);
        assertEquals(estimateNumberOfColumnarChannels(VARCHAR), 3);
        assertEquals(estimateNumberOfColumnarChannels(createVarcharType(10)), 3);
        assertEquals(estimateNumberOfColumnarChannels(createDecimalType(3, 2)), 2);
        assertEquals(estimateNumberOfColumnarChannels(createDecimalType(30, 2)), 2);
        assertEquals(estimateNumberOfColumnarChannels(new ArrayType(BIGINT)), 4);
        assertEquals(estimateNumberOfColumnarChannels(new ArrayType(VARCHAR)), 5);
        assertEquals(estimateNumberOfColumnarChannels(RowType.anonymous(ImmutableList.of(BIGINT, VARCHAR))), 7);
    }

    /**
     * Matches an {@link ExchangeNode} whose partitioning scheme reports the expected
     * {@link ExchangeEncoding}.
     */
    private static class ExchangeEncodingMatcher
            implements Matcher
    {
        private final ExchangeEncoding encoding;

        private ExchangeEncodingMatcher(ExchangeEncoding encoding)
        {
            this.encoding = requireNonNull(encoding, "encoding is null");
        }

        @Override
        public boolean shapeMatches(PlanNode node)
        {
            return node instanceof ExchangeNode;
        }

        @Override
        public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
        {
            ExchangeNode exchangeNode = (ExchangeNode) node;
            return exchangeNode.getPartitioningScheme().getEncoding() == encoding ? match() : NO_MATCH;
        }

        // Without this override the matcher renders as Class@hash, hiding the expected encoding.
        // NOTE(review): presumably printed as part of the expected pattern when a plan assertion
        // fails - confirm against PlanMatchPattern.
        @Override
        public String toString()
        {
            return "ExchangeEncodingMatcher{encoding=" + encoding + "}";
        }
    }
}