TestBackPressure.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.google.common.collect.ImmutableList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import org.apache.arrow.flight.perf.PerformanceTestServer;
import org.apache.arrow.flight.perf.TestPerf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

public class TestBackPressure {

  /** Number of rows placed in every record batch sent by the test producers. */
  private static final int BATCH_SIZE = 4095;

  /** Make sure that failing to consume one stream doesn't block other streams. */
  @Disabled
  @Test
  public void ensureIndependentSteams() throws Exception {
    ensureIndependentSteams((b) -> (location -> new PerformanceTestServer(b, location)));
  }

  /**
   * Make sure that failing to consume one stream doesn't block other streams when the server uses
   * a callback-based backpressure strategy.
   */
  @Disabled
  @Test
  public void ensureIndependentSteamsWithCallbacks() throws Exception {
    ensureIndependentSteams(
        (b) ->
            (location ->
                new PerformanceTestServer(
                    b, location, new BackpressureStrategy.CallbackBackpressureStrategy(), true)));
  }

  /** Test to make sure stream doesn't go faster than the consumer is consuming. */
  @Disabled
  @Test
  public void ensureWaitUntilProceed() throws Exception {
    ensureWaitUntilProceed(new PollingBackpressureStrategy(), false);
  }

  /**
   * Test to make sure stream doesn't go faster than the consumer is consuming using a
   * callback-based backpressure strategy.
   */
  @Disabled
  @Test
  public void ensureWaitUntilProceedWithCallbacks() throws Exception {
    ensureWaitUntilProceed(new RecordingCallbackBackpressureStrategy(), true);
  }

  /**
   * Make sure that failing to consume one stream doesn't block other streams.
   *
   * <p>Opens two streams from the same server, partially drains the first, then drains large
   * portions of the second (and finally both) to show the stalled first stream does not hold the
   * second back.
   *
   * @param serverConstructor builds the server under test from an allocator and a location
   */
  private static void ensureIndependentSteams(
      Function<BufferAllocator, Function<Location, PerformanceTestServer>> serverConstructor)
      throws Exception {
    try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
        final PerformanceTestServer server =
            serverConstructor.apply(a).apply(forGrpcInsecure(LOCALHOST, 0)).start();
        final FlightClient client = FlightClient.builder(a, server.getLocation()).build()) {
      try (FlightStream fs1 =
          client.getStream(
              client
                  .getInfo(TestPerf.getPerfFlightDescriptor(110L * BATCH_SIZE, BATCH_SIZE, 1))
                  .getEndpoints()
                  .get(0)
                  .getTicket())) {
        consume(fs1, 10);

        // stop consuming fs1 but make sure we can consume a large amount of fs2.
        try (FlightStream fs2 =
            client.getStream(
                client
                    .getInfo(TestPerf.getPerfFlightDescriptor(200L * BATCH_SIZE, BATCH_SIZE, 1))
                    .getEndpoints()
                    .get(0)
                    .getTicket())) {
          consume(fs2, 100);

          consume(fs1, 100);
          consume(fs2, 100);

          consume(fs1);
          consume(fs2);
        }
      }
    }
  }

  /**
   * Make sure that a stream doesn't go faster than the consumer is consuming.
   *
   * <p>The client sleeps for {@code wait} ms before reading. The producer should spend roughly
   * that long blocked in the backpressure strategy, which the strategy records.
   *
   * @param bpStrategy strategy used by the producer; also reports the time spent waiting
   * @param isNonBlocking if true, the producer runs on its own executor thread instead of the
   *     calling gRPC thread
   */
  private static void ensureWaitUntilProceed(
      SleepTimeRecordingBackpressureStrategy bpStrategy, boolean isNonBlocking) throws Exception {
    // request some values.
    final long wait = 3000;
    final long epsilon = 1000;

    try (BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE)) {

      final FlightProducer producer =
          new NoOpFlightProducer() {

            @Override
            public void getStream(
                CallContext context, Ticket ticket, ServerStreamListener listener) {
              bpStrategy.register(listener);
              final Runnable loadData = () -> streamBatches(bpStrategy, listener, allocator);

              if (!isNonBlocking) {
                loadData.run();
              } else {
                final ExecutorService service = Executors.newSingleThreadExecutor();
                Future<?> unused = service.submit(loadData);
                service.shutdown();
              }
            }
          };

      try (BufferAllocator serverAllocator =
              allocator.newChildAllocator("server", 0, Long.MAX_VALUE);
          FlightServer server =
              FlightServer.builder(serverAllocator, forGrpcInsecure(LOCALHOST, 0), producer)
                  .build()
                  .start();
          BufferAllocator clientAllocator =
              allocator.newChildAllocator("client", 0, Long.MAX_VALUE);
          FlightClient client =
              FlightClient.builder(clientAllocator, server.getLocation()).build();
          FlightStream stream = client.getStream(new Ticket(new byte[1]))) {
        VectorSchemaRoot root = stream.getRoot();
        root.clear();
        Thread.sleep(wait);
        while (stream.next()) {
          root.clear();
        }
        final long expected = wait - epsilon;
        final long slept = bpStrategy.getSleepTime();
        assertTrue(
            slept > expected,
            String.format(
                "Expected a sleep of at least %dms but only slept for %d", expected, slept));
      }
    }
  }

  /**
   * Sends batches of {@link #BATCH_SIZE} rows to {@code listener} and completes the stream after
   * 101 batches. It waits on {@code bpStrategy} before each batch, and before completing.
   *
   * <p>If the client cancels, the method returns without completing the stream. If the wait ends
   * in any other non-ready state, the failure is reported to the client via
   * {@code listener.error(...)}.
   */
  private static void streamBatches(
      BackpressureStrategy bpStrategy,
      FlightProducer.ServerStreamListener listener,
      BufferAllocator allocator) {
    final Schema pojoSchema =
        new Schema(ImmutableList.of(Field.nullable("a", MinorType.BIGINT.getType())));
    try (VectorSchemaRoot root = VectorSchemaRoot.create(pojoSchema, allocator)) {
      listener.start(root);
      int batches = 0;
      while (true) {
        // A timeout of 0 asks the strategy to wait until the listener changes state.
        final BackpressureStrategy.WaitResult result = bpStrategy.waitForListener(0);
        if (result == BackpressureStrategy.WaitResult.CANCELLED) {
          // The client went away; writing more batches would be pointless.
          return;
        }
        if (result != BackpressureStrategy.WaitResult.READY) {
          listener.error(
              CallStatus.INTERNAL
                  .withDescription("Backpressure wait ended with " + result)
                  .toRuntimeException());
          return;
        }
        if (batches > 100) {
          root.clear();
          listener.completed();
          return;
        }

        root.allocateNew();
        root.setRowCount(BATCH_SIZE);
        listener.putNext();
        batches++;
      }
    }
  }

  /** Reads and discards every remaining batch of {@code stream}. */
  private static void consume(FlightStream stream) {
    VectorSchemaRoot root = stream.getRoot();
    while (stream.next()) {
      root.clear();
    }
  }

  /** Reads and discards at most {@code batches} batches of {@code stream}. */
  private static void consume(FlightStream stream, int batches) {
    VectorSchemaRoot root = stream.getRoot();
    for (int remaining = batches; remaining > 0 && stream.next(); remaining--) {
      root.clear();
    }
  }

  private interface SleepTimeRecordingBackpressureStrategy extends BackpressureStrategy {
    /**
     * Returns the total time spent waiting on the listener to be ready.
     *
     * @return the total time, in milliseconds, spent waiting on the listener to be ready.
     */
    long getSleepTime();
  }

  /**
   * Implementation of a backpressure strategy that polls on isReady and records the amount of
   * time spent waiting.
   */
  private static class PollingBackpressureStrategy
      implements SleepTimeRecordingBackpressureStrategy {
    /** Accumulated in nanoseconds so sub-millisecond waits are not lost to truncation. */
    private final AtomicLong sleepNanos = new AtomicLong(0);

    private FlightProducer.ServerStreamListener listener;

    @Override
    public long getSleepTime() {
      return TimeUnit.NANOSECONDS.toMillis(sleepNanos.get());
    }

    @Override
    public void register(FlightProducer.ServerStreamListener listener) {
      this.listener = listener;
    }

    /**
     * Polls the listener every millisecond until it is ready.
     *
     * @param timeout maximum wait in milliseconds; a value {@code <= 0} waits indefinitely
     * @return READY when the listener is ready, CANCELLED if the client cancelled, TIMEOUT if
     *     {@code timeout} elapsed, or OTHER if the waiting thread was interrupted
     */
    @Override
    public WaitResult waitForListener(long timeout) {
      final long start = System.nanoTime();
      try {
        while (!listener.isReady()) {
          if (listener.isCancelled()) {
            return WaitResult.CANCELLED;
          }
          if (timeout > 0
              && TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start) >= timeout) {
            return WaitResult.TIMEOUT;
          }
          try {
            Thread.sleep(1);
          } catch (InterruptedException e) {
            // Preserve the interrupt for callers instead of silently discarding it.
            Thread.currentThread().interrupt();
            return WaitResult.OTHER;
          }
        }
        return WaitResult.READY;
      } finally {
        sleepNanos.addAndGet(System.nanoTime() - start);
      }
    }
  }

  /**
   * Implementation of a backpressure strategy that uses callbacks to detect changes in client
   * readiness state and records spent time waiting.
   */
  private static class RecordingCallbackBackpressureStrategy
      extends BackpressureStrategy.CallbackBackpressureStrategy
      implements SleepTimeRecordingBackpressureStrategy {
    /** Accumulated in nanoseconds using a monotonic clock. */
    private final AtomicLong sleepNanos = new AtomicLong(0);

    @Override
    public long getSleepTime() {
      return TimeUnit.NANOSECONDS.toMillis(sleepNanos.get());
    }

    @Override
    public WaitResult waitForListener(long timeout) {
      final long start = System.nanoTime();
      try {
        return super.waitForListener(timeout);
      } finally {
        sleepNanos.addAndGet(System.nanoTime() - start);
      }
    }
  }
}