TestPrestoSparkSourceDistributionSplitAssigner.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.spark.planner;

import com.facebook.presto.execution.Lifespan;
import com.facebook.presto.execution.ScheduledSplit;
import com.facebook.presto.metadata.Split;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSplit;
import com.facebook.presto.spi.HostAddress;
import com.facebook.presto.spi.NodeProvider;
import com.facebook.presto.spi.connector.ConnectorPartitionHandle;
import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.schedule.NodeSelectionStrategy;
import com.facebook.presto.split.SplitSource;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.SetMultimap;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.units.DataSize;
import org.testng.annotations.Test;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.concurrent.ThreadLocalRandom;

import static com.facebook.presto.spi.connector.NotPartitionedPartitionHandle.NOT_PARTITIONED;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.primitives.Ints.min;
import static com.google.common.util.concurrent.Futures.immediateFuture;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.util.Objects.requireNonNull;
import static org.assertj.core.api.Assertions.assertThat;
import static org.testng.Assert.assertEquals;

public class TestPrestoSparkSourceDistributionSplitAssigner
{
    /**
     * Checks the split-to-partition assignment produced when auto tune is enabled.
     * <p>
     * Every case uses a 10 byte limit for the total split size per Spark partition and varies
     * the min/max partition counts and the split sizes. Each expected map is keyed by
     * partition id and lists the sizes of the splits expected to land in that partition.
     */
    @Test
    public void testSplitAssignmentWithAutoTuneEnabled()
    {
        DataSize maxSizePerPartition = new DataSize(10, BYTE);

        // no splits at all
        assertSplitAssignmentWithAutoTuneEnabled(maxSizePerPartition, 2, 4, ImmutableList.of(), ImmutableMap.of());

        // a single split
        assertSplitAssignmentWithAutoTuneEnabled(
                maxSizePerPartition, 2, 4,
                ImmutableList.of(1L),
                ImmutableMap.of(0, ImmutableList.of(1L)));

        // two tiny splits with a minimum of two partitions
        assertSplitAssignmentWithAutoTuneEnabled(
                maxSizePerPartition, 2, 4,
                ImmutableList.of(1L, 1L),
                ImmutableMap.of(
                        0, ImmutableList.of(1L),
                        1, ImmutableList.of(1L)));

        // five splits, but at most four partitions
        assertSplitAssignmentWithAutoTuneEnabled(
                maxSizePerPartition, 2, 4,
                ImmutableList.of(10L, 11L, 12L, 13L, 9L),
                ImmutableMap.of(
                        0, ImmutableList.of(13L),
                        1, ImmutableList.of(12L),
                        2, ImmutableList.of(11L),
                        3, ImmutableList.of(10L, 9L)));

        // three small splits with a minimum of one partition
        assertSplitAssignmentWithAutoTuneEnabled(
                maxSizePerPartition, 1, 4,
                ImmutableList.of(3L, 4L, 5L),
                ImmutableMap.of(
                        0, ImmutableList.of(5L, 4L),
                        1, ImmutableList.of(3L)));

        // eleven splits of increasing size, up to ten partitions allowed
        assertSplitAssignmentWithAutoTuneEnabled(
                maxSizePerPartition, 1, 10,
                ImmutableList.of(1L, 2L, 3L, 4L, 5L, 6L, 7L, 8L, 9L, 10L, 11L),
                ImmutableMap.<Integer, List<Long>>builder()
                        .put(0, ImmutableList.of(11L))
                        .put(1, ImmutableList.of(10L))
                        .put(2, ImmutableList.of(9L))
                        .put(3, ImmutableList.of(8L, 1L))
                        .put(4, ImmutableList.of(7L, 2L))
                        .put(5, ImmutableList.of(6L, 3L))
                        .put(6, ImmutableList.of(5L, 4L))
                        .build());

        // six splits of increasing size, up to ten partitions allowed
        assertSplitAssignmentWithAutoTuneEnabled(
                maxSizePerPartition, 1, 10,
                ImmutableList.of(1L, 2L, 3L, 4L, 5L, 6L),
                ImmutableMap.<Integer, List<Long>>builder()
                        .put(0, ImmutableList.of(6L, 3L))
                        .put(1, ImmutableList.of(5L, 4L))
                        .put(2, ImmutableList.of(2L, 1L))
                        .build());
    }

    /**
     * Delegates to {@code assertSplitAssignment} with auto tune switched on.
     *
     * @param maxSplitsDataSizePerSparkPartition limit forwarded unchanged to the assigner under test
     * @param minSparkInputPartitionCountForAutoTune lower partition bound forwarded unchanged
     * @param maxSparkInputPartitionCountForAutoTune upper partition bound forwarded unchanged
     * @param splitSizes sizes (in bytes) of the splits to feed into the assigner
     * @param expectedAssignment expected split sizes keyed by partition id
     */
    private static void assertSplitAssignmentWithAutoTuneEnabled(
            DataSize maxSplitsDataSizePerSparkPartition,
            int minSparkInputPartitionCountForAutoTune,
            int maxSparkInputPartitionCountForAutoTune,
            List<Long> splitSizes,
            Map<Integer, List<Long>> expectedAssignment)
    {
        boolean autoTuneEnabled = true;
        // doesn't matter with auto tune enabled
        int initialPartitionCount = 1;

        assertSplitAssignment(
                autoTuneEnabled,
                maxSplitsDataSizePerSparkPartition,
                initialPartitionCount,
                minSparkInputPartitionCountForAutoTune,
                maxSparkInputPartitionCountForAutoTune,
                splitSizes,
                expectedAssignment);
    }

    /**
     * Checks the split-to-partition assignment produced when auto tune is disabled.
     * <p>
     * Each case passes the initial partition count, the split sizes, and the expected
     * split sizes per partition id.
     */
    @Test
    public void testSplitAssignmentWithAutoTuneDisabled()
    {
        // one partition: no splits, one split, two splits
        assertSplitAssignmentWithAutoTuneDisabled(1, ImmutableList.of(), ImmutableMap.of());
        assertSplitAssignmentWithAutoTuneDisabled(1, ImmutableList.of(1L), ImmutableMap.of(0, ImmutableList.of(1L)));
        assertSplitAssignmentWithAutoTuneDisabled(1, ImmutableList.of(1L, 1L), ImmutableMap.of(0, ImmutableList.of(1L, 1L)));

        // two partitions with a growing number of splits
        assertSplitAssignmentWithAutoTuneDisabled(
                2,
                ImmutableList.of(1L, 1L),
                ImmutableMap.of(
                        0, ImmutableList.of(1L),
                        1, ImmutableList.of(1L)));
        assertSplitAssignmentWithAutoTuneDisabled(
                2,
                ImmutableList.of(1L, 1L, 2L),
                ImmutableMap.of(
                        0, ImmutableList.of(2L),
                        1, ImmutableList.of(1L, 1L)));
        assertSplitAssignmentWithAutoTuneDisabled(
                2,
                ImmutableList.of(2L, 1L, 1L, 1L),
                ImmutableMap.of(
                        0, ImmutableList.of(2L, 1L),
                        1, ImmutableList.of(1L, 1L)));
        assertSplitAssignmentWithAutoTuneDisabled(
                2,
                ImmutableList.of(2L, 1L, 1L, 1L, 3L),
                ImmutableMap.of(
                        0, ImmutableList.of(3L, 1L),
                        1, ImmutableList.of(2L, 1L, 1L)));

        // three partitions, six splits of increasing size
        assertSplitAssignmentWithAutoTuneDisabled(
                3,
                ImmutableList.of(1L, 2L, 3L, 4L, 5L, 6L),
                ImmutableMap.of(
                        0, ImmutableList.of(6L, 1L),
                        1, ImmutableList.of(5L, 2L),
                        2, ImmutableList.of(4L, 3L)));
        assertSplitAssignmentWithAutoTuneDisabled(
                3,
                ImmutableList.of(5L, 6L, 7L, 8L, 9L, 10L),
                ImmutableMap.of(
                        0, ImmutableList.of(10L, 5L),
                        1, ImmutableList.of(9L, 6L),
                        2, ImmutableList.of(8L, 7L)));
    }

    /**
     * Delegates to {@code assertSplitAssignment} with auto tune switched off.
     *
     * @param initialPartitionCount partition count forwarded unchanged to the assigner under test
     * @param splitSizes sizes (in bytes) of the splits to feed into the assigner
     * @param expectedAssignment expected split sizes keyed by partition id
     */
    private static void assertSplitAssignmentWithAutoTuneDisabled(
            int initialPartitionCount,
            List<Long> splitSizes,
            Map<Integer, List<Long>> expectedAssignment)
    {
        boolean autoTuneEnabled = false;
        // the three values below don't matter with auto tune disabled
        DataSize maxSplitsDataSizePerSparkPartition = new DataSize(1, BYTE);
        int minSparkInputPartitionCountForAutoTune = 1;
        int maxSparkInputPartitionCountForAutoTune = 2;

        assertSplitAssignment(
                autoTuneEnabled,
                maxSplitsDataSizePerSparkPartition,
                initialPartitionCount,
                minSparkInputPartitionCountForAutoTune,
                maxSparkInputPartitionCountForAutoTune,
                splitSizes,
                expectedAssignment);
    }

    /**
     * Asserts that the assigner under test produces {@code expectedAssignment} for the given
     * split sizes, both when all splits are requested in one batch and when they are requested
     * in batches of 1, 2, 4, ... splits (for every power of two smaller than the split count).
     *
     * @param autoTuneEnabled forwarded to the assigner constructor
     * @param maxSplitsDataSizePerSparkPartition converted to bytes and forwarded to the assigner constructor
     * @param initialPartitionCount forwarded to the assigner constructor
     * @param minSparkInputPartitionCountForAutoTune forwarded to the assigner constructor
     * @param maxSparkInputPartitionCountForAutoTune forwarded to the assigner constructor
     * @param splitSizes sizes (in bytes) of the mock splits served by the split source
     * @param expectedAssignment expected split sizes keyed by partition id; order within a partition is ignored
     */
    private static void assertSplitAssignment(
            boolean autoTuneEnabled,
            DataSize maxSplitsDataSizePerSparkPartition,
            int initialPartitionCount,
            int minSparkInputPartitionCountForAutoTune,
            int maxSparkInputPartitionCountForAutoTune,
            List<Long> splitSizes,
            Map<Integer, List<Long>> expectedAssignment)
    {
        long maxSplitsDataSizeInBytes = maxSplitsDataSizePerSparkPartition.toBytes();

        // assign splits in one shot
        PrestoSparkSplitAssigner oneShotAssigner = newAssigner(
                splitSizes,
                Integer.MAX_VALUE,
                maxSplitsDataSizeInBytes,
                initialPartitionCount,
                autoTuneEnabled,
                minSparkInputPartitionCountForAutoTune,
                maxSparkInputPartitionCountForAutoTune);
        Optional<SetMultimap<Integer, ScheduledSplit>> oneShotAssignment = oneShotAssigner.getNextBatch();
        if (splitSizes.isEmpty()) {
            assertThat(oneShotAssignment).isNotPresent();
        }
        else {
            assertThat(oneShotAssignment).isPresent();
            assertAssignedSplits(oneShotAssignment.get(), expectedAssignment);
        }

        // sort splits (largest first) to make the iterative assignment match the assignment done in one shot;
        // the sorted order does not depend on the batch size, so it is computed once up front
        List<Long> sortedSplits = new ArrayList<>(splitSizes);
        sortedSplits.sort(Comparator.reverseOrder());

        // assign splits iteratively
        for (int splitBatchSize = 1; splitBatchSize < splitSizes.size(); splitBatchSize *= 2) {
            PrestoSparkSplitAssigner assigner = newAssigner(
                    sortedSplits,
                    splitBatchSize,
                    maxSplitsDataSizeInBytes,
                    initialPartitionCount,
                    autoTuneEnabled,
                    minSparkInputPartitionCountForAutoTune,
                    maxSparkInputPartitionCountForAutoTune);

            assertAssignedSplits(drainAssignments(assigner), expectedAssignment);
        }
    }

    /**
     * Builds the assigner under test on top of a fresh mock split source serving {@code splitSizes}.
     */
    private static PrestoSparkSplitAssigner newAssigner(
            List<Long> splitSizes,
            int splitBatchSize,
            long maxSplitsDataSizeInBytes,
            int initialPartitionCount,
            boolean autoTuneEnabled,
            int minSparkInputPartitionCountForAutoTune,
            int maxSparkInputPartitionCountForAutoTune)
    {
        return new PrestoSparkSourceDistributionSplitAssigner(
                new PlanNodeId("test"),
                createSplitSource(splitSizes),
                splitBatchSize,
                maxSplitsDataSizeInBytes,
                initialPartitionCount,
                autoTuneEnabled,
                minSparkInputPartitionCountForAutoTune,
                maxSparkInputPartitionCountForAutoTune,
                // NOTE(review): meaning of this trailing argument is not visible in this file - confirm against the constructor
                0);
    }

    /**
     * Calls {@code getNextBatch()} until it returns an empty result and merges every batch into one multimap.
     */
    private static SetMultimap<Integer, ScheduledSplit> drainAssignments(PrestoSparkSplitAssigner assigner)
    {
        HashMultimap<Integer, ScheduledSplit> merged = HashMultimap.create();
        while (true) {
            Optional<SetMultimap<Integer, ScheduledSplit>> batch = assigner.getNextBatch();
            if (!batch.isPresent()) {
                return merged;
            }
            merged.putAll(batch.get());
        }
    }

    /**
     * Feeds 1000 randomly sized splits through the assigner (auto tune enabled, batches of 333 splits)
     * and verifies that the total size of the assigned splits equals the total size of the input splits.
     * The scenario is repeated three times with fresh random sizes.
     */
    @Test
    public void testAssignSplitsToPartitionWithRandomSplitSizes()
    {
        DataSize maxSplitDataSizePerPartition = new DataSize(2048, BYTE);
        int initialPartitionCount = 3;
        int minSparkInputPartitionCountForAutoTune = 2;
        int maxSparkInputPartitionCountForAutoTune = 5;
        int maxSplitSizeInBytes = 2048;
        for (int i = 0; i < 3; ++i) {
            List<Long> splitSizes = new ArrayList<>(1000);
            for (int j = 0; j < 1000; j++) {
                // the bound is 20% above the per-partition limit, so some splits alone exceed that limit
                splitSizes.add(ThreadLocalRandom.current().nextLong((long) (maxSplitSizeInBytes * 1.2)));
            }

            PrestoSparkSplitAssigner assigner = new PrestoSparkSourceDistributionSplitAssigner(
                    new PlanNodeId("test"),
                    createSplitSource(splitSizes),
                    333,
                    maxSplitDataSizePerPartition.toBytes(),
                    initialPartitionCount,
                    true,
                    minSparkInputPartitionCountForAutoTune,
                    maxSparkInputPartitionCountForAutoTune,
                    0);

            HashMultimap<Integer, ScheduledSplit> actualAssignment = HashMultimap.create();

            // keep pulling batches until the assigner reports there is nothing left
            while (true) {
                Optional<SetMultimap<Integer, ScheduledSplit>> assignment = assigner.getNextBatch();
                if (!assignment.isPresent()) {
                    break;
                }
                actualAssignment.putAll(assignment.get());
            }

            long expectedSizeInBytes = splitSizes.stream()
                    .mapToLong(Long::longValue)
                    .sum();
            long actualTotalSizeInBytes = actualAssignment.values().stream()
                    .mapToLong(split -> split.getSplit().getConnectorSplit().getSplitSizeInBytes().orElseThrow(() -> new IllegalArgumentException("split size is expected to be present")))
                    .sum();

            // check if all splits got assigned; TestNG's assertEquals takes (actual, expected)
            assertEquals(actualTotalSizeInBytes, expectedSizeInBytes);
        }
    }

    /**
     * Asserts that {@code actual} covers exactly the partitions in {@code expected} and that each
     * partition holds splits with exactly the expected sizes, ignoring order within a partition.
     */
    private static void assertAssignedSplits(SetMultimap<Integer, ScheduledSplit> actual, Map<Integer, List<Long>> expected)
    {
        Map<Integer, List<Long>> actualSizes = getAssignedSplitSizes(actual);

        // the partition ids must match before the per-partition lookups below are meaningful
        assertThat(actualSizes.keySet()).isEqualTo(expected.keySet());

        actualSizes.forEach((partition, sizes) -> {
            Long[] expectedSizes = expected.get(partition).toArray(new Long[0]);
            assertThat(sizes).containsExactlyInAnyOrder(expectedSizes);
        });
    }

    /**
     * Converts an assignment of scheduled splits into a map from partition id to the sizes
     * (in bytes) of the splits assigned to that partition.
     *
     * @throws IllegalArgumentException if any split does not report its size
     */
    private static Map<Integer, List<Long>> getAssignedSplitSizes(SetMultimap<Integer, ScheduledSplit> assignedSplits)
    {
        ImmutableMap.Builder<Integer, List<Long>> sizesByPartition = ImmutableMap.builder();
        for (Integer partition : assignedSplits.keySet()) {
            ImmutableList.Builder<Long> sizes = ImmutableList.builder();
            for (ScheduledSplit scheduledSplit : assignedSplits.get(partition)) {
                OptionalLong size = scheduledSplit.getSplit().getConnectorSplit().getSplitSizeInBytes();
                sizes.add(size.orElseThrow(() -> new IllegalArgumentException("split size is expected to be present")));
            }
            sizesByPartition.put(partition, sizes.build());
        }
        return sizesByPartition.build();
    }

    /**
     * Creates a mock split source serving one mock split per entry of {@code splitSizes},
     * in the given order, each reporting the corresponding size in bytes.
     */
    private static SplitSource createSplitSource(List<Long> splitSizes)
    {
        ImmutableList.Builder<Split> splits = ImmutableList.builder();
        for (Long size : splitSizes) {
            ConnectorSplit connectorSplit = new MockSplit(size);
            splits.add(new Split(new ConnectorId("test"), TestingTransactionHandle.create(), connectorSplit));
        }
        return new MockSplitSource(splits.build());
    }

    /**
     * Minimal {@link ConnectorSplit} that only reports a fixed size in bytes;
     * every other accessor throws {@link UnsupportedOperationException}.
     */
    private static class MockSplit
            implements ConnectorSplit
    {
        private final long sizeInBytes;

        public MockSplit(long splitSizeInBytes)
        {
            sizeInBytes = splitSizeInBytes;
        }

        @Override
        public NodeSelectionStrategy getNodeSelectionStrategy()
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public List<HostAddress> getPreferredNodes(NodeProvider nodeProvider)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public Object getInfo()
        {
            throw new UnsupportedOperationException();
        }

        /**
         * @return the size passed to the constructor, always present
         */
        @Override
        public OptionalLong getSplitSizeInBytes()
        {
            return OptionalLong.of(sizeInBytes);
        }
    }

    /**
     * In-memory {@link SplitSource} that hands out a fixed list of splits in order.
     * Only {@code getNextBatch}, {@code close} and {@code isFinished} are supported.
     */
    private static class MockSplitSource
            implements SplitSource
    {
        private final List<Split> splits;

        // index of the next split to hand out
        private int position;
        private boolean closed;

        private MockSplitSource(List<Split> splits)
        {
            requireNonNull(splits, "splits is null");
            this.splits = ImmutableList.copyOf(splits);
        }

        @Override
        public ConnectorId getConnectorId()
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public ConnectorTransactionHandle getTransactionHandle()
        {
            throw new UnsupportedOperationException();
        }

        /**
         * Returns the next {@code maxSize} splits (or fewer if not that many remain) as an already
         * completed future. The batch is flagged as the last one when no splits remain afterwards.
         *
         * @throws IllegalStateException if the source is closed or already finished
         * @throws IllegalArgumentException if the partition handle is not {@code NOT_PARTITIONED}
         * or the lifespan is not task wide
         */
        @Override
        public ListenableFuture<SplitBatch> getNextBatch(ConnectorPartitionHandle partitionHandle, Lifespan lifespan, int maxSize)
        {
            checkState(!closed, "split source is closed");
            checkState(!isFinished(), "split source is finished");
            checkArgument(partitionHandle.equals(NOT_PARTITIONED), "unexpected partition handle: %s", partitionHandle);
            checkArgument(lifespan.equals(Lifespan.taskWide()), "unexpected lifespan: %s", lifespan);

            int start = position;
            int end = start + Math.min(splits.size() - start, maxSize);
            List<Split> nextSplits = ImmutableList.copyOf(splits.subList(start, end));
            // advance the cursor first so that isFinished() below reflects the state after this batch
            position = end;

            return immediateFuture(new SplitBatch(nextSplits, isFinished()));
        }

        @Override
        public void rewind(ConnectorPartitionHandle partitionHandle)
        {
            throw new UnsupportedOperationException();
        }

        @Override
        public void close()
        {
            closed = true;
        }

        @Override
        public boolean isFinished()
        {
            return splits.size() <= position;
        }
    }
}