TestPerf.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.flight.perf;

import static org.apache.arrow.flight.FlightTestUtil.LOCALHOST;
import static org.apache.arrow.flight.Location.forGrpcInsecure;

import com.google.common.base.MoreObjects;
import com.google.common.base.Stopwatch;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.ListeningExecutorService;
import com.google.common.util.concurrent.MoreExecutors;
import com.google.protobuf.ByteString;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.perf.impl.PerfOuterClass.Perf;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

/**
 * Throughput benchmark for Arrow Flight streams served by {@link PerformanceTestServer}.
 *
 * <p>Each run starts a fresh allocator/server/client, requests a multi-endpoint stream and drains
 * every endpoint concurrently. It then prints per-run statistics and, at the end, the mean and
 * standard deviation of the observed throughput. The test is {@link Disabled} by default; it can
 * also be launched via {@link #main(String[])}.
 */
@Disabled
public class TestPerf {

  /**
   * When {@code true}, consumers read every value of column {@code "a"} and accumulate its sum.
   * When {@code false}, only row/batch counts are tallied.
   */
  public static final boolean VALIDATE = false;

  /**
   * Builds a command descriptor carrying a serialized {@code Perf} request for a schema of four
   * nullable BIGINT columns named {@code a}, {@code b}, {@code c} and {@code d}.
   *
   * @param recordCount value passed to {@code setRecordsPerStream}
   * @param recordsPerBatch value passed to {@code setRecordsPerBatch}
   * @param streamCount value passed to {@code setStreamCount}
   * @return a command descriptor whose payload is the serialized {@code Perf} message
   */
  public static FlightDescriptor getPerfFlightDescriptor(
      long recordCount, int recordsPerBatch, int streamCount) {
    final Schema pojoSchema =
        new Schema(
            ImmutableList.of(
                Field.nullable("a", MinorType.BIGINT.getType()),
                Field.nullable("b", MinorType.BIGINT.getType()),
                Field.nullable("c", MinorType.BIGINT.getType()),
                Field.nullable("d", MinorType.BIGINT.getType())));

    final ByteString serializedSchema = ByteString.copyFrom(pojoSchema.serializeAsMessage());

    return FlightDescriptor.command(
        Perf.newBuilder()
            .setRecordsPerStream(recordCount)
            .setRecordsPerBatch(recordsPerBatch)
            .setSchema(serializedSchema)
            .setStreamCount(streamCount)
            .build()
            .toByteArray());
  }

  /** Runs {@link #throughput()} outside of a test runner. */
  public static void main(String[] args) throws Exception {
    new TestPerf().throughput();
  }

  /**
   * Executes ten benchmark runs and prints per-run and summary statistics.
   *
   * @throws Exception if any run fails (server start-up, stream consumption, etc.)
   */
  @Test
  public void throughput() throws Exception {
    final int numRuns = 10;
    final double[] throughputs = new double[numRuns];
    final ListeningExecutorService pool =
        MoreExecutors.listeningDecorator(Executors.newFixedThreadPool(4));
    try {
      for (int i = 0; i < numRuns; i++) {
        throughputs[i] = runOnce(pool);
      }
    } finally {
      // Always release the worker threads, even when a run fails. Executors' default threads are
      // non-daemon, so leaking them would keep the JVM alive when launched via main().
      pool.shutdown();
    }

    printSummary(throughputs);
  }

  /**
   * Performs a single benchmark run with its own allocator, server and client.
   *
   * @param pool executor on which endpoint consumers and the result aggregation run
   * @return the observed throughput of this run in MiB/s
   * @throws Exception if the server, the client or any consumer fails
   */
  private static double runOnce(ListeningExecutorService pool) throws Exception {
    try (final BufferAllocator a = new RootAllocator(Long.MAX_VALUE);
        final PerformanceTestServer server =
            new PerformanceTestServer(a, forGrpcInsecure(LOCALHOST, 0)).start();
        final FlightClient client = FlightClient.builder(a, server.getLocation()).build()) {
      final FlightInfo info = client.getInfo(getPerfFlightDescriptor(50_000_000L, 4095, 2));

      // One consumer per endpoint, all draining concurrently.
      final List<ListenableFuture<Result>> results =
          info.getEndpoints().stream()
              .map(endpoint -> new Consumer(client, endpoint.getTicket()))
              .map(consumer -> pool.submit(consumer))
              .collect(Collectors.toList());

      final Result r =
          Futures.whenAllSucceed(results)
              .call(
                  () -> {
                    final Result total = new Result();
                    for (ListenableFuture<Result> f : results) {
                      total.add(f.get());
                    }
                    return total;
                  },
                  pool)
              .get();

      final double seconds = r.nanos * 1.0d / 1000 / 1000 / 1000;
      final double throughput = (r.bytes * 1.0d / 1024 / 1024) / seconds;
      System.out.printf(
          "Transferred %d records totaling %s bytes at %f MiB/s. %f record/s. %f batch/s.%n",
          r.rows,
          r.bytes,
          throughput,
          (r.rows * 1.0d) / seconds,
          (r.batches * 1.0d) / seconds);
      return throughput;
    }
  }

  /**
   * Prints the mean and population standard deviation of the given per-run throughputs.
   *
   * @param throughputs per-run throughput values in MiB/s; must be non-empty
   */
  private static void printSummary(double[] throughputs) {
    System.out.println("Summary: ");
    final double average = Arrays.stream(throughputs).sum() / throughputs.length;
    final double sqrSum =
        Arrays.stream(throughputs).map(val -> val - average).map(val -> val * val).sum();
    final double stddev = Math.sqrt(sqrSum / throughputs.length);
    System.out.printf(
        "Average throughput: %f MiB/s, standard deviation: %f MiB/s%n", average, stddev);
  }

  /** Drains a single endpoint's stream and tallies what was received. */
  private static final class Consumer implements Callable<Result> {

    private final FlightClient client;
    private final Ticket ticket;

    public Consumer(FlightClient client, Ticket ticket) {
      this.client = client;
      this.ticket = ticket;
    }

    /**
     * Reads every batch of the ticket's stream.
     *
     * @return counts of rows, batches and estimated bytes, plus the elapsed wall time in
     *     nanoseconds
     * @throws Exception if the stream cannot be opened or read
     */
    @Override
    public Result call() throws Exception {
      final Result r = new Result();
      final Stopwatch watch = Stopwatch.createStarted();
      try (final FlightStream stream = client.getStream(ticket)) {
        final VectorSchemaRoot root = stream.getRoot();
        try {
          final BigIntVector a = (BigIntVector) root.getVector("a");
          while (stream.next()) {
            final int rows = root.getRowCount();
            if (VALIDATE) {
              // Accumulate in a local and write back once per batch to keep the hot loop tight.
              long aSum = r.aSum;
              for (int i = 0; i < rows; i++) {
                aSum += a.get(i);
              }
              r.aSum = aSum;
            }
            // Byte estimate: 32 bytes per row. This presumably matches the four 8-byte BIGINT
            // columns of getPerfFlightDescriptor's schema; validity buffers are not counted.
            r.bytes += rows * 32L;
            r.rows += rows;
            r.batches++;
          }

          r.nanos = watch.elapsed(TimeUnit.NANOSECONDS);
          return r;
        } finally {
          root.clear();
        }
      }
    }
  }

  /** Mutable accumulator of per-consumer (or aggregated) transfer statistics. */
  private static final class Result {
    private long rows;
    private long aSum;
    private long bytes;
    private long nanos;
    private long batches;

    /**
     * Merges another result into this one. Counts are summed. The elapsed time takes the maximum,
     * because consumers run concurrently and the slowest one bounds the run's wall time.
     */
    public void add(Result r) {
      rows += r.rows;
      aSum += r.aSum;
      bytes += r.bytes;
      batches += r.batches;
      nanos = Math.max(nanos, r.nanos);
    }

    @Override
    public String toString() {
      return MoreObjects.toStringHelper(this)
          .add("rows", rows)
          .add("aSum", aSum)
          .add("batches", batches)
          .add("bytes", bytes)
          .add("nanos", nanos)
          .toString();
    }
  }
}