TestAsyncIPC.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.ipc;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.CommonConfigurationKeys;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.ipc.RPC.RpcKind;
import org.apache.hadoop.ipc.TestIPC.CallInfo;
import org.apache.hadoop.ipc.TestIPC.TestServer;
import org.apache.hadoop.ipc.protobuf.RpcHeaderProtos.RpcResponseHeaderProto;
import org.apache.hadoop.net.NetUtils;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.concurrent.AsyncGetFuture;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

public class TestAsyncIPC {

  private static Configuration conf;
  private static final Logger LOG = LoggerFactory.getLogger(TestAsyncIPC.class);

  static <T extends Writable> AsyncGetFuture<T, IOException>
      getAsyncRpcResponseFuture() {
    return new AsyncGetFuture<>(Client.getAsyncRpcResponse());
  }

  @BeforeEach
  public void setupConf() {
    conf = new Configuration();
    conf.setInt(CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_KEY, 10000);
    Client.setPingInterval(conf, TestIPC.PING_INTERVAL);
    // Set asynchronous mode for main thread.
    Client.setAsynchronousMode(true);
  }

  static class AsyncCaller extends Thread {
    private Client client;
    private InetSocketAddress server;
    private int count;
    private boolean failed;
    Map<Integer, Future<LongWritable>> returnFutures =
        new HashMap<Integer, Future<LongWritable>>();
    Map<Integer, Long> expectedValues = new HashMap<Integer, Long>();

    AsyncCaller(Client client, InetSocketAddress server, int count,
        boolean checkAsyncCallEnabled) {
      this.client = client;
      // Disable checkAsyncCall.
      if (!checkAsyncCallEnabled) {
        this.client.setMaxAsyncCalls(-1);
      }
      this.server = server;
      this.count = count;
      // Set asynchronous mode, since AsyncCaller extends Thread.
      Client.setAsynchronousMode(true);
    }

    @Override
    public void run() {
      // In case Thread#Start is called, which will spawn new thread.
      Client.setAsynchronousMode(true);
      for (int i = 0; i < count; i++) {
        try {
          final long param = TestIPC.RANDOM.nextLong();
          TestIPC.call(client, param, server, conf);
          returnFutures.put(i, getAsyncRpcResponseFuture());
          expectedValues.put(i, param);
        } catch (Exception e) {
          failed = true;
          throw new RuntimeException(e);
        }
      }
    }

    void assertReturnValues() throws InterruptedException, ExecutionException {
      for (int i = 0; i < count; i++) {
        LongWritable value = returnFutures.get(i).get();
        assertEquals(expectedValues.get(i).longValue(), value.get(),
            "call" + i + " failed.");
      }
      assertFalse(failed);
    }

    void assertReturnValues(long timeout, TimeUnit unit)
        throws InterruptedException, ExecutionException {
      final boolean[] checked = new boolean[count];
      for(boolean done = false; !done;) {
        done = true;
        for (int i = 0; i < count; i++) {
          if (checked[i]) {
            continue;
          } else {
            done = false;
          }

          final LongWritable value;
          try {
            value = returnFutures.get(i).get(timeout, unit);
          } catch (TimeoutException e) {
            LOG.info("call" + i + " caught ", e);
            continue;
          }

          assertEquals(expectedValues.get(i).longValue(), value.get(),
              "call" + i + " failed.");
          checked[i] = true;
        }
      }
      assertFalse(failed);
    }
  }

  /**
   * For testing the asynchronous calls of the RPC client
   * implemented with CompletableFuture.
   */
  static class AsyncCompletableFutureCaller extends Thread {
    private final Client client;
    private final InetSocketAddress server;
    private final int count;
    private final List<CompletableFuture<Writable>> completableFutures;
    private final List<Long> expectedValues;

    AsyncCompletableFutureCaller(Client client, InetSocketAddress server, int count) {
      this.client = client;
      this.server = server;
      this.count = count;
      this.completableFutures = new ArrayList<>(count);
      this.expectedValues = new ArrayList<>(count);
      setName("Async CompletableFuture Caller");
    }

    @Override
    public void run() {
      // Set the RPC client to use asynchronous mode.
      Client.setAsynchronousMode(true);
      long startTime = Time.monotonicNow();
      try {
        for (int i = 0; i < count; i++) {
          final long param = TestIPC.RANDOM.nextLong();
          TestIPC.call(client, param, server, conf);
          expectedValues.add(param);
          completableFutures.add(Client.getResponseFuture());
        }
        // Since the run method is asynchronous,
        // it does not need to wait for a response after sending a request,
        // so the time taken by the run method is less than count * 100
        // (where 100 is the time taken by the server to process a request).
        long cost = Time.monotonicNow() - startTime;
        assertTrue(cost < count * 100L);
        LOG.info("[{}] run cost {}ms", Thread.currentThread().getName(), cost);
      } catch (Exception e) {
        fail();
      }
    }

    public void assertReturnValues()
        throws InterruptedException, ExecutionException {
      for (int i = 0; i < count; i++) {
        LongWritable value = (LongWritable) completableFutures.get(i).get();
        assertEquals(expectedValues.get(i).longValue(), value.get(),
            "call" + i + " failed.");
      }
    }
  }

  static class AsyncLimitlCaller extends Thread {
    private Client client;
    private InetSocketAddress server;
    private int count;
    private boolean failed;
    Map<Integer, Future<LongWritable>> returnFutures = new HashMap<Integer, Future<LongWritable>>();
    Map<Integer, Long> expectedValues = new HashMap<Integer, Long>();
    int start = 0, end = 0;

    int getStart() {
      return start;
    }

    int getEnd() {
      return end;
    }

    int getCount() {
      return count;
    }

    public AsyncLimitlCaller(Client client, InetSocketAddress server, int count) {
      this(0, client, server, count);
    }

    final int callerId;

    public AsyncLimitlCaller(int callerId, Client client, InetSocketAddress server,
        int count) {
      this.client = client;
      this.server = server;
      this.count = count;
      // Set asynchronous mode, since AsyncLimitlCaller extends Thread.
      Client.setAsynchronousMode(true);
      this.callerId = callerId;
    }

    @Override
    public void run() {
      // in case Thread#Start is called, which will spawn new thread
      Client.setAsynchronousMode(true);
      for (int i = 0; i < count; i++) {
        try {
          final long param = TestIPC.RANDOM.nextLong();
          runCall(i, param);
        } catch (Exception e) {
          LOG.error(String.format("Caller-%d Call-%d caught: %s", callerId, i,
              StringUtils.stringifyException(e)));
          failed = true;
        }
      }
    }

    private void runCall(final int idx, final long param)
        throws InterruptedException, ExecutionException, IOException {
      for (;;) {
        try {
          doCall(idx, param);
          return;
        } catch (AsyncCallLimitExceededException e) {
          /**
           * reached limit of async calls, fetch results of finished async calls
           * to let follow-on calls go
           */
          start = end;
          end = idx;
          waitForReturnValues(start, end);
        }
      }
    }

    private void doCall(final int idx, final long param) throws IOException {
      TestIPC.call(client, param, server, conf);
      returnFutures.put(idx, getAsyncRpcResponseFuture());
      expectedValues.put(idx, param);
    }

    private void waitForReturnValues(final int start, final int end)
        throws InterruptedException, ExecutionException {
      for (int i = start; i < end; i++) {
        LongWritable value = returnFutures.get(i).get();
        if (expectedValues.get(i) != value.get()) {
          LOG.error(String.format("Caller-%d Call-%d failed!", callerId, i));
          failed = true;
          break;
        }
      }
    }
  }

  @Test
  @Timeout(value = 60)
  public void testAsyncCallCheckDisabled() throws IOException, InterruptedException,
      ExecutionException {
    internalTestAsyncCall(3, true, 2, 5, 10, false);
  }

  @Test
  @Timeout(value = 60)
  public void testAsyncCall() throws IOException, InterruptedException,
      ExecutionException {
    internalTestAsyncCall(3, false, 2, 5, 100, true);
    internalTestAsyncCall(3, true, 2, 5, 10, true);
  }

  @Test
  @Timeout(value = 60)
  public void testAsyncCallLimit() throws IOException,
      InterruptedException, ExecutionException {
    internalTestAsyncCallLimit(100, false, 5, 10, 500);
  }

  public void internalTestAsyncCall(int handlerCount, boolean handlerSleep,
      int clientCount, int callerCount, int callCount,
      boolean checkAsyncCallEnabled) throws IOException,
      InterruptedException, ExecutionException {
    Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    Client[] clients = new Client[clientCount];
    for (int i = 0; i < clientCount; i++) {
      clients[i] = new Client(LongWritable.class, conf);
    }

    AsyncCaller[] callers = new AsyncCaller[callerCount];
    for (int i = 0; i < callerCount; i++) {
      callers[i] = new AsyncCaller(clients[i % clientCount], addr, callCount,
          checkAsyncCallEnabled);
      callers[i].start();
    }
    for (int i = 0; i < callerCount; i++) {
      if (!checkAsyncCallEnabled) {
        assertEquals(0, clients[i % clientCount].getAsyncCallCounter());
      }
      callers[i].join();
      callers[i].assertReturnValues();
    }
    for (int i = 0; i < clientCount; i++) {
      clients[i].stop();
    }
    server.stop();
  }

  @Test
  @Timeout(value = 60)
  public void testCallGetReturnRpcResponseMultipleTimes() throws IOException,
      InterruptedException, ExecutionException {
    int handlerCount = 10, callCount = 100;
    Server server = new TestIPC.TestServer(handlerCount, false, conf);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();
    final Client client = new Client(LongWritable.class, conf);

    int asyncCallCount = client.getAsyncCallCount();

    try {
      AsyncCaller caller = new AsyncCaller(client, addr, callCount, true);
      caller.run();
      caller.assertReturnValues();
      caller.assertReturnValues();
      caller.assertReturnValues();
      assertEquals(asyncCallCount, client.getAsyncCallCount());
    } finally {
      client.stop();
      server.stop();
    }
  }

  @Test
  @Timeout(value = 60)
  public void testFutureGetWithTimeout() throws IOException,
      InterruptedException, ExecutionException {
//    GenericTestUtils.setLogLevel(AsyncGetFuture.LOG, Level.ALL);
    final Server server = new TestIPC.TestServer(10, true, conf);
    final InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    final Client client = new Client(LongWritable.class, conf);

    try {
      final AsyncCaller caller = new AsyncCaller(client, addr, 10, true);
      caller.run();
      caller.assertReturnValues(10, TimeUnit.MILLISECONDS);
    } finally {
      client.stop();
      server.stop();
    }
  }


  public void internalTestAsyncCallLimit(int handlerCount, boolean handlerSleep,
      int clientCount, int callerCount, int callCount) throws IOException,
      InterruptedException, ExecutionException {
    Configuration conf = new Configuration();
    conf.setInt(CommonConfigurationKeys.IPC_CLIENT_ASYNC_CALLS_MAX_KEY, 100);
    Client.setPingInterval(conf, TestIPC.PING_INTERVAL);

    Server server = new TestIPC.TestServer(handlerCount, handlerSleep, conf);
    InetSocketAddress addr = NetUtils.getConnectAddress(server);
    server.start();

    Client[] clients = new Client[clientCount];
    for (int i = 0; i < clientCount; i++) {
      clients[i] = new Client(LongWritable.class, conf);
    }

    AsyncLimitlCaller[] callers = new AsyncLimitlCaller[callerCount];
    for (int i = 0; i < callerCount; i++) {
      callers[i] = new AsyncLimitlCaller(i, clients[i % clientCount], addr,
          callCount);
      callers[i].start();
    }
    for (int i = 0; i < callerCount; i++) {
      callers[i].join();
      callers[i].waitForReturnValues(callers[i].getStart(),
          callers[i].getCount());
      String msg = String.format("Expected not failed for caller-%d: %s.", i,
          callers[i]);
      assertFalse(callers[i].failed, msg);
    }
    for (int i = 0; i < clientCount; i++) {
      clients[i].stop();
    }
    server.stop();
  }

  /**
   * Test if (1) the rpc server uses the call id/retry provided by the rpc
   * client, and (2) the rpc client receives the same call id/retry from the rpc
   * server.
   *
   * @throws ExecutionException
   * @throws InterruptedException
   */
  @Test
  @Timeout(value = 60)
  public void testCallIdAndRetry() throws IOException, InterruptedException,
      ExecutionException {
    final Map<Integer, CallInfo> infoMap = new HashMap<Integer, CallInfo>();

    // Override client to store the call info and check response
    final Client client = new Client(LongWritable.class, conf) {
      @Override
      Call createCall(RpcKind rpcKind, Writable rpcRequest) {
        // Set different call id and retry count for the next call
        Client.setCallIdAndRetryCount(Client.nextCallId(),
            TestIPC.RANDOM.nextInt(255), null);

        final Call call = super.createCall(rpcKind, rpcRequest);

        CallInfo info = new CallInfo();
        info.id = call.id;
        info.retry = call.retry;
        infoMap.put(call.id, info);

        return call;
      }

      @Override
      void checkResponse(RpcResponseHeaderProto header) throws IOException {
        super.checkResponse(header);
        assertEquals(infoMap.get(header.getCallId()).retry,
            header.getRetryCount());
      }
    };

    // Attach a listener that tracks every call received by the server.
    final TestServer server = new TestIPC.TestServer(1, false, conf);
    server.callListener = new Runnable() {
      @Override
      public void run() {
        assertEquals(infoMap.get(Server.getCallId()).retry,
            Server.getCallRetryCount());
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      final AsyncCaller caller = new AsyncCaller(client, addr, 4, true);
      caller.run();
      caller.assertReturnValues();
    } finally {
      client.stop();
      server.stop();
    }
  }

  /**
   * Test if the rpc server gets the retry count from client.
   *
   * @throws ExecutionException
   * @throws InterruptedException
   */
  @Test
  @Timeout(value = 60)
  public void testCallRetryCount() throws IOException, InterruptedException,
      ExecutionException {
    final int retryCount = 255;
    // Override client to store the call id
    final Client client = new Client(LongWritable.class, conf);
    Client.setCallIdAndRetryCount(Client.nextCallId(), retryCount, null);

    // Attach a listener that tracks every call ID received by the server.
    final TestServer server = new TestIPC.TestServer(1, false, conf);
    server.callListener = new Runnable() {
      @Override
      public void run() {
        // we have not set the retry count for the client, thus on the server
        // side we should see retry count as 0
        assertEquals(retryCount, Server.getCallRetryCount());
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      final AsyncCaller caller = new AsyncCaller(client, addr, 10, true);
      caller.run();
      caller.assertReturnValues();
    } finally {
      client.stop();
      server.stop();
    }
  }

  /**
   * Test if the rpc server gets the default retry count (0) from client.
   *
   * @throws ExecutionException
   * @throws InterruptedException
   */
  @Test
  @Timeout(value = 60)
  public void testInitialCallRetryCount() throws IOException,
      InterruptedException, ExecutionException {
    // Override client to store the call id
    final Client client = new Client(LongWritable.class, conf);
    Client.setCallIdAndRetryCount(Client.nextCallId(), 0, null);

    // Attach a listener that tracks every call ID received by the server.
    final TestServer server = new TestIPC.TestServer(1, false, conf);
    server.callListener = new Runnable() {
      @Override
      public void run() {
        // we have not set the retry count for the client, thus on the server
        // side we should see retry count as 0
        assertEquals(0, Server.getCallRetryCount());
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      final AsyncCaller caller = new AsyncCaller(client, addr, 10, true);
      caller.run();
      caller.assertReturnValues();
    } finally {
      client.stop();
      server.stop();
    }
  }

  /**
   * Tests that client generates a unique sequential call ID for each RPC call,
   * even if multiple threads are using the same client.
   *
   * @throws InterruptedException
   * @throws ExecutionException
   */
  @Test
  @Timeout(value = 60)
  public void testUniqueSequentialCallIds() throws IOException,
      InterruptedException, ExecutionException {
    int serverThreads = 10, callerCount = 100, perCallerCallCount = 100;
    TestServer server = new TestIPC.TestServer(serverThreads, false, conf);

    // Attach a listener that tracks every call ID received by the server. This
    // list must be synchronized, because multiple server threads will add to
    // it.
    final List<Integer> callIds = Collections
        .synchronizedList(new ArrayList<Integer>());
    server.callListener = new Runnable() {
      @Override
      public void run() {
        callIds.add(Server.getCallId());
      }
    };

    Client client = new Client(LongWritable.class, conf);

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      AsyncCaller[] callers = new AsyncCaller[callerCount];
      for (int i = 0; i < callerCount; ++i) {
        callers[i] = new AsyncCaller(client, addr, perCallerCallCount, true);
        callers[i].start();
      }
      for (int i = 0; i < callerCount; ++i) {
        callers[i].join();
        callers[i].assertReturnValues();
      }
    } finally {
      client.stop();
      server.stop();
    }

    int expectedCallCount = callerCount * perCallerCallCount;
    assertEquals(expectedCallCount, callIds.size());

    // It is not guaranteed that the server executes requests in sequential
    // order
    // of client call ID, so we must sort the call IDs before checking that it
    // contains every expected value.
    Collections.sort(callIds);
    final int startID = callIds.get(0).intValue();
    for (int i = 0; i < expectedCallCount; ++i) {
      assertEquals(startID + i, callIds.get(i).intValue());
    }
  }

  @Test
  @Timeout(value = 60)
  public void testAsyncCallWithCompletableFuture() throws IOException,
      InterruptedException, ExecutionException {
    // Override client to store the call id
    final Client client = new Client(LongWritable.class, conf);

    // Construct an RPC server, which includes a handler thread.
    final TestServer server = new TestIPC.TestServer(1, false, conf);
    server.callListener = () -> {
      try {
        // The server requires at least 100 milliseconds to process a request.
        Thread.sleep(100);
      } catch (InterruptedException e) {
        throw new RuntimeException(e);
      }
    };

    try {
      InetSocketAddress addr = NetUtils.getConnectAddress(server);
      server.start();
      // Send 10 asynchronous requests.
      final AsyncCompletableFutureCaller caller =
          new AsyncCompletableFutureCaller(client, addr, 10);
      caller.start();
      caller.join();
      // Check if the values returned by the asynchronous call meet the expected values.
      caller.assertReturnValues();
    } finally {
      client.stop();
      server.stop();
    }
  }
}