TestTAsyncClientManager.java
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.thrift.async;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.server.ServerTestBase;
import org.apache.thrift.server.THsHaServer;
import org.apache.thrift.server.THsHaServer.Args;
import org.apache.thrift.server.TServer;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingSocket;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransportException;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import thrift.test.CompactProtoTestStruct;
import thrift.test.ExceptionWithAMap;
import thrift.test.Srv;
import thrift.test.Srv.Iface;
public class TestTAsyncClientManager {
protected TServer server_;
protected Thread serverThread_;
protected TAsyncClientManager clientManager_;
/**
 * Starts a {@link THsHaServer} backed by {@link SrvHandler} on {@code ServerTestBase.PORT} in a
 * background thread and creates the client manager shared by the tests.
 *
 * @throws Exception if the server socket or the client manager cannot be created.
 */
@BeforeEach
public void setUp() throws Exception {
  TNonblockingServerSocket serverSocket =
      new TNonblockingServerSocket(
          new TNonblockingServerSocket.NonblockingAbstractServerSocketArgs()
              .port(ServerTestBase.PORT));
  server_ =
      new THsHaServer(new Args(serverSocket).processor(new Srv.Processor(new SrvHandler())));

  // serve() is run on its own thread so that setUp can return while the server is running.
  serverThread_ = new Thread(() -> server_.serve());
  serverThread_.start();

  clientManager_ = new TAsyncClientManager();
  // Fixed pause before the test body runs; presumably this gives the server thread time to start
  // listening before the first client connects.
  Thread.sleep(500);
}
/**
 * Stops the server and the client manager created in {@link #setUp()}, then blocks until the
 * server thread has exited.
 *
 * @throws Exception if the wait for the server thread is interrupted.
 */
@AfterEach
public void tearDown() throws Exception {
server_.stop();
clientManager_.stop();
// Wait for serve() to return so the server thread does not outlive this test.
serverThread_.join();
}
/** Runs the standard {@code Janky} round trip on a fresh client with no timeout configured. */
@Test
public void testBasicCall() throws Exception {
  basicCall(getClient());
}
/** Runs the standard {@code Janky} round trip on a client whose timeout is set to 5000. */
@Test
public void testBasicCallWithTimeout() throws Exception {
  Srv.AsyncClient timedClient = getClient();
  timedClient.setTimeout(5000);
  basicCall(timedClient);
}
/**
 * Template for tests that make one asynchronous call which is expected to end in
 * {@code onError}. Subclasses supply the call and any extra validation; {@link #runTest()} wires
 * up the callback, waits for it and performs the checks common to every error scenario.
 *
 * @param <C> the async client type the call is made against.
 * @param <R> the response type of the call under test.
 */
private abstract static class ErrorCallTest<C extends TAsyncClient, R> {
  /**
   * Executes the erroring call, waits up to two seconds for a callback and verifies that the
   * client recorded the same exception that was passed to {@code onError}.
   *
   * @throws Exception if the call could not be set up or the wait was interrupted.
   */
  final void runTest() throws Exception {
    final CountDownLatch latch = new CountDownLatch(1);
    final AtomicReference<Exception> error = new AtomicReference<>();
    C client =
        executeErroringCall(
            new AsyncMethodCallback<R>() {
              @Override
              public void onComplete(R response) {
                // Not expected; release the waiter so the assertions below report the problem.
                latch.countDown();
              }

              @Override
              public void onError(Exception exception) {
                error.set(exception);
                latch.countDown();
              }
            });
    // Fail with an explicit message when no callback arrives, instead of letting the later
    // assertions fail for a less obvious reason.
    assertTrue(latch.await(2, TimeUnit.SECONDS), "callback was not invoked within 2 seconds");
    assertTrue(client.hasError());
    Exception exception = error.get();
    assertNotNull(exception);
    assertSame(exception, client.getError());
    validateError(client, exception);
  }

  /**
   * Executes a call that is expected to raise an exception.
   *
   * @param callback The testing callback that should be installed.
   * @return The client the call was made against.
   * @throws Exception if there was a problem setting up the client or making the call.
   */
  abstract C executeErroringCall(AsyncMethodCallback<R> callback) throws Exception;

  /**
   * Further validates the properties of the error raised in the remote call and the state of the
   * client after that call.
   *
   * @param client The client returned from {@link #executeErroringCall(AsyncMethodCallback)}.
   * @param error The exception raised by the remote call.
   */
  abstract void validateError(C client, Exception error);
}
/**
 * Calls {@code declaredExceptionMethod(false)}, for which {@link SrvHandler} throws a plain
 * {@link TException}, and checks that the client reports a non-timeout {@link TException}.
 */
@Test
public void testUnexpectedRemoteExceptionCall() throws Exception {
  ErrorCallTest<Srv.AsyncClient, Boolean> scenario =
      new ErrorCallTest<Srv.AsyncClient, Boolean>() {
        @Override
        Srv.AsyncClient executeErroringCall(AsyncMethodCallback<Boolean> callback)
            throws Exception {
          Srv.AsyncClient asyncClient = getClient();
          asyncClient.declaredExceptionMethod(false, callback);
          return asyncClient;
        }

        @Override
        void validateError(Srv.AsyncClient client, Exception error) {
          assertFalse(client.hasTimeout());
          assertTrue(error instanceof TException);
        }
      };
  scenario.runTest();
}
/**
 * Calls {@code declaredExceptionMethod(true)}, for which {@link SrvHandler} throws an
 * {@link ExceptionWithAMap}, and checks that the exact exception type and its fields reach the
 * callback.
 */
@Test
public void testDeclaredRemoteExceptionCall() throws Exception {
  ErrorCallTest<Srv.AsyncClient, Boolean> scenario =
      new ErrorCallTest<Srv.AsyncClient, Boolean>() {
        @Override
        Srv.AsyncClient executeErroringCall(AsyncMethodCallback<Boolean> callback)
            throws Exception {
          Srv.AsyncClient asyncClient = getClient();
          asyncClient.declaredExceptionMethod(true, callback);
          return asyncClient;
        }

        @Override
        void validateError(Srv.AsyncClient client, Exception error) {
          assertFalse(client.hasTimeout());
          // Exact class match, not instanceof: the declared exception type must be preserved.
          assertEquals(ExceptionWithAMap.class, error.getClass());
          ExceptionWithAMap declared = (ExceptionWithAMap) error;
          assertEquals("blah", declared.getBlah());
          assertEquals(new HashMap<String, String>(), declared.getMap_field());
        }
      };
  scenario.runTest();
}
/**
 * Calls {@code primitiveMethod}, for which {@link SrvHandler} sleeps for 1000 ms, on a client
 * whose timeout is set to 100, and checks that the call ends in a {@link TimeoutException}.
 */
@Test
public void testTimeoutCall() throws Exception {
  ErrorCallTest<Srv.AsyncClient, Integer> scenario =
      new ErrorCallTest<Srv.AsyncClient, Integer>() {
        @Override
        Srv.AsyncClient executeErroringCall(AsyncMethodCallback<Integer> callback)
            throws Exception {
          Srv.AsyncClient asyncClient = getClient();
          asyncClient.setTimeout(100);
          asyncClient.primitiveMethod(callback);
          return asyncClient;
        }

        @Override
        void validateError(Srv.AsyncClient client, Exception error) {
          assertTrue(client.hasTimeout());
          assertTrue(error instanceof TimeoutException);
        }
      };
  scenario.runTest();
}
/**
 * Calls {@code voidMethod} and verifies that {@code onComplete} is invoked within one second.
 *
 * @throws Exception if the call cannot be made or the wait is interrupted.
 */
@Test
public void testVoidCall() throws Exception {
  final CountDownLatch latch = new CountDownLatch(1);
  final AtomicBoolean returned = new AtomicBoolean(false);
  Srv.AsyncClient client = getClient();
  client.voidMethod(
      new FailureLessCallback<Void>() {
        @Override
        public void onComplete(Void response) {
          returned.set(true);
          latch.countDown();
        }
      });
  // Check the await result so a missing callback is reported as such.
  assertTrue(
      latch.await(1, TimeUnit.SECONDS), "voidMethod callback was not invoked within 1 second");
  assertTrue(returned.get());
}
/**
 * Calls {@code onewayMethod} and verifies that {@code onComplete} is invoked within one second.
 *
 * @throws Exception if the call cannot be made or the wait is interrupted.
 */
@Test
public void testOnewayCall() throws Exception {
  final CountDownLatch latch = new CountDownLatch(1);
  final AtomicBoolean returned = new AtomicBoolean(false);
  Srv.AsyncClient client = getClient();
  client.onewayMethod(
      new FailureLessCallback<Void>() {
        @Override
        public void onComplete(Void response) {
          returned.set(true);
          latch.countDown();
        }
      });
  // Check the await result so a missing callback is reported as such.
  assertTrue(
      latch.await(1, TimeUnit.SECONDS), "onewayMethod callback was not invoked within 1 second");
  assertTrue(returned.get());
}
/**
 * Runs 50 threads that each make 100 sequential {@code Janky} calls on their own client and
 * checks that every call succeeded.
 */
@Test
public void testParallelCalls() throws Exception {
  // make multiple calls with deserialization in the selector thread (repro of a reported issue)
  final int threadCount = 50;
  final int callsPerThread = 100;

  List<JankyRunnable> workers = new ArrayList<JankyRunnable>();
  List<Thread> workerThreads = new ArrayList<Thread>();
  for (int started = 0; started < threadCount; started++) {
    JankyRunnable worker = new JankyRunnable(callsPerThread);
    Thread workerThread = new Thread(worker);
    workerThread.start();
    workerThreads.add(workerThread);
    workers.add(worker);
  }

  // Wait for every worker before reading its success counter.
  for (Thread workerThread : workerThreads) {
    workerThread.join();
  }

  int totalSuccesses = workers.stream().mapToInt(JankyRunnable::getNumSuccesses).sum();
  assertEquals(threadCount * callsPerThread, totalSuccesses);
}
/**
 * Creates a new async client that uses the binary protocol, the shared client manager and a
 * fresh transport from {@link #getClientTransport()}.
 *
 * @return a new client; each call returns a distinct instance with its own transport.
 * @throws IOException if the transport cannot be created.
 * @throws TTransportException if the transport cannot be created.
 */
private Srv.AsyncClient getClient() throws IOException, TTransportException {
  TBinaryProtocol.Factory protocolFactory = new TBinaryProtocol.Factory();
  TNonblockingTransport transport = getClientTransport();
  return new Srv.AsyncClient(protocolFactory, clientManager_, transport);
}
/**
 * Creates the transport used by new clients: a non-blocking socket to
 * {@code ServerTestBase.HOST}:{@code ServerTestBase.PORT}. The method is {@code protected},
 * presumably so that subclasses can substitute a different transport.
 *
 * @return a new non-blocking transport to the test server.
 * @throws TTransportException if the socket cannot be created.
 * @throws IOException if the socket cannot be created.
 */
protected TNonblockingTransport getClientTransport() throws TTransportException, IOException {
  TNonblockingSocket socket = new TNonblockingSocket(ServerTestBase.HOST, ServerTestBase.PORT);
  return socket;
}
/**
 * Makes one {@code Janky(1)} call on the given client and verifies that it completes with the
 * value 3.
 *
 * <p>The callbacks only record the outcome and release the latch; every assertion runs on the
 * calling thread. An assertion that throws inside a callback is raised on whichever thread
 * invokes the callback, so it would not fail the test directly, and a throw ahead of
 * {@code countDown()} would leave this method blocked for the full wait.
 *
 * @param client the client to make the call on.
 * @throws Exception if the call cannot be made or the wait is interrupted.
 */
private void basicCall(Srv.AsyncClient client) throws Exception {
  final CountDownLatch latch = new CountDownLatch(1);
  final AtomicReference<Integer> result = new AtomicReference<>();
  final AtomicReference<Exception> error = new AtomicReference<>();
  client.Janky(
      1,
      new AsyncMethodCallback<Integer>() {
        @Override
        public void onComplete(Integer response) {
          result.set(response);
          latch.countDown();
        }

        @Override
        public void onError(Exception exception) {
          error.set(exception);
          latch.countDown();
        }
      });

  assertTrue(latch.await(100, TimeUnit.SECONDS), "callback was not invoked within 100 seconds");

  Exception failure = error.get();
  if (failure != null) {
    StringWriter sink = new StringWriter();
    failure.printStackTrace(new PrintWriter(sink, true));
    Assertions.fail("unexpected onError with exception " + sink.toString());
  }

  Integer response = result.get();
  assertNotNull(response, "onComplete was not given a response");
  assertEquals(3, response.intValue());
}
/**
 * Server-side implementation of the {@code Srv} test service. Each method has a fixed, simple
 * behavior that the client tests in this file rely on.
 */
public static class SrvHandler implements Iface {
  /**
   * Standard call used for round-trip testing.
   *
   * @param arg must be 1; any other value fails an assertion.
   * @return always 3.
   */
  @Override
  public int Janky(int arg) throws TException {
    assertEquals(1, arg);
    return 3;
  }

  /**
   * Used for timeout testing: sleeps for 1 second before returning.
   *
   * @return always 0.
   */
  @Override
  public int primitiveMethod() throws TException {
    try {
      Thread.sleep(1000);
    } catch (InterruptedException e) {
      e.printStackTrace();
      // Restore the interrupt status so the thread running this handler can still observe it.
      Thread.currentThread().interrupt();
    }
    return 0;
  }

  /** No-op. */
  @Override
  public void methodWithDefaultArgs(int something) throws TException {}

  /**
   * Stub implementation.
   *
   * @return always {@code null}.
   */
  @Override
  public CompactProtoTestStruct structMethod() throws TException {
    return null;
  }

  /** No-op; returns normally. */
  @Override
  public void voidMethod() throws TException {}

  /** No-op; returns normally. */
  @Override
  public void onewayMethod() throws TException {}

  /**
   * Always throws; never returns a value.
   *
   * @param shouldThrowDeclared {@code true} to throw an {@link ExceptionWithAMap} built with
   *     "blah" and an empty map, {@code false} to throw a plain {@link TException}.
   * @throws TException always, of the kind selected by {@code shouldThrowDeclared}.
   */
  @Override
  public boolean declaredExceptionMethod(boolean shouldThrowDeclared) throws TException {
    if (shouldThrowDeclared) {
      throw new ExceptionWithAMap("blah", new HashMap<String, String>());
    } else {
      throw new TException("Unexpected!");
    }
  }
}
/**
 * Callback base class for calls that are expected to succeed: any {@code onError} is turned into
 * a test failure. Subclasses only implement {@code onComplete}.
 *
 * @param <T> the response type of the call.
 */
private abstract static class FailureLessCallback<T> implements AsyncMethodCallback<T> {
  @Override
  public void onError(final Exception exception) {
    // Explicitly the enclosing test class's helper, which raises a JUnit failure.
    TestTAsyncClientManager.fail(exception);
  }
}
/**
 * Fails the current test with a message that embeds the full stack trace of the given exception.
 *
 * @param exception the unexpected exception to report.
 */
private static void fail(Exception exception) {
  StringWriter trace = new StringWriter();
  PrintWriter traceWriter = new PrintWriter(trace, true);
  exception.printStackTrace(traceWriter);
  Assertions.fail("unexpected error " + trace);
}
/**
 * Worker that makes a fixed number of sequential {@code Janky} calls on its own client and
 * counts the ones that complete successfully. It stops early once the client reports an error.
 */
private class JankyRunnable implements Runnable {
  private final int numCalls_;
  private int numSuccesses_ = 0;
  private final Srv.AsyncClient client_;

  /**
   * @param numCalls how many calls {@link #run()} should attempt.
   * @throws Exception if the client cannot be created.
   */
  public JankyRunnable(int numCalls) throws Exception {
    numCalls_ = numCalls;
    client_ = getClient();
    client_.setTimeout(20000);
  }

  /** @return the number of calls that completed successfully so far. */
  public int getNumSuccesses() {
    return numSuccesses_;
  }

  @Override
  public void run() {
    int iteration = 0;
    while (iteration < numCalls_ && !client_.hasError()) {
      try {
        callOnce(iteration);
        this.numSuccesses_++;
      } catch (Exception e) {
        fail(e);
      }
      iteration++;
    }
  }

  /**
   * Makes a single call and blocks until its callback fires or 30 seconds pass.
   *
   * @param iteration index of the call, used only in failure messages.
   * @throws Exception if the call cannot be made or the wait is interrupted.
   */
  private void callOnce(final int iteration) throws Exception {
    final CountDownLatch done = new CountDownLatch(1);
    final AtomicBoolean completed = new AtomicBoolean(false);
    AsyncMethodCallback<Integer> callback =
        new AsyncMethodCallback<Integer>() {
          @Override
          public void onComplete(Integer result) {
            assertEquals(3, result.intValue());
            completed.set(true);
            done.countDown();
          }

          @Override
          public void onError(Exception exception) {
            try {
              StringWriter sink = new StringWriter();
              exception.printStackTrace(new PrintWriter(sink, true));
              Assertions.fail(
                  "unexpected onError on iteration " + iteration + ": " + sink.toString());
            } finally {
              // Always release the waiting worker, even though the failure above throws.
              done.countDown();
            }
          }
        };
    client_.Janky(1, callback);

    boolean calledBack = done.await(30, TimeUnit.SECONDS);
    assertTrue(calledBack, "wasn't called back in time on iteration " + iteration);
    assertTrue(completed.get(), "onComplete not called on iteration " + iteration);
  }
}
}