ProtocolTestBase.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.thrift.protocol;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.thrift.Fixtures;
import org.apache.thrift.TBase;
import org.apache.thrift.TConfiguration;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.TSerializer;
import org.apache.thrift.server.ServerTestBase;
import org.apache.thrift.transport.TMemoryBuffer;
import org.apache.thrift.transport.TTransportException;
import org.junit.jupiter.api.Test;
import thrift.test.CompactProtoTestStruct;
import thrift.test.HolyMoley;
import thrift.test.Nesting;
import thrift.test.OneOfEach;
import thrift.test.Srv;
import thrift.test.ThriftTest;

/**
 * Shared round-trip tests for {@link TProtocol} implementations.
 *
 * <p>Subclasses supply the protocol under test through {@link #getFactory()} and state through
 * {@link #canBeUsedNaked()} whether primitive read/write calls may be issued outside of any struct
 * framing. Each test writes into an in-memory {@link TMemoryBuffer} and reads the data back through
 * the same protocol instance, asserting that the decoded value equals the encoded one.
 */
public abstract class ProtocolTestBase {

  /** Does it make sense to call methods like writeI32 directly on your protocol? */
  protected abstract boolean canBeUsedNaked();

  /** The protocol factory for the protocol being tested. */
  protected abstract TProtocolFactory getFactory();

  @Test
  public void testDouble() throws Exception {
    if (canBeUsedNaked()) {
      TMemoryBuffer buf = new TMemoryBuffer(1000);
      TProtocol proto = getFactory().getProtocol(buf);
      proto.writeDouble(123.456);
      assertEquals(123.456, proto.readDouble());
    }

    internalTestStructField(
        new StructFieldTestCase(TType.DOUBLE, (short) 15) {
          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(123.456, proto.readDouble());
          }

          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeDouble(123.456);
          }
        });
  }

  @Test
  public void testSerialization() throws Exception {
    internalTestSerialization(OneOfEach.class, Fixtures.getOneOfEach());
    internalTestSerialization(Nesting.class, Fixtures.getNesting());
    internalTestSerialization(HolyMoley.class, Fixtures.getHolyMoley());
    internalTestSerialization(CompactProtoTestStruct.class, Fixtures.getCompactProtoTestStruct());
  }

  @Test
  public void testBinary() throws Exception {
    for (byte[] b :
        Arrays.asList(
            new byte[0],
            new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
            new byte[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14},
            new byte[] {0x5D},
            new byte[] {(byte) 0xD5, (byte) 0x5D},
            new byte[] {(byte) 0xFF, (byte) 0xD5, (byte) 0x5D},
            new byte[128])) {
      if (canBeUsedNaked()) {
        internalTestNakedBinary(b);
      }
      internalTestBinaryField(b);
    }

    if (canBeUsedNaked()) {
      byte[] data = {1, 2, 3, 4, 5, 6};

      TMemoryBuffer buf = new TMemoryBuffer(0);
      TProtocol proto = getFactory().getProtocol(buf);
      ByteBuffer bb = ByteBuffer.wrap(data);
      // Advance past the first byte so the slice covers only the last five bytes of the backing
      // array; the protocol must write exactly the slice's contents, not the whole array.
      bb.get();
      proto.writeBinary(bb.slice());
      assertEquals(ByteBuffer.wrap(data, 1, 5), proto.readBinary());
    }
  }

  @Test
  public void testString() throws Exception {
    for (String s :
        Arrays.asList("", "short", "borderlinetiny", "a bit longer than the smallest possible")) {
      if (canBeUsedNaked()) {
        internalTestNakedString(s);
      }
      internalTestStringField(s);
    }
  }

  @Test
  public void testUuid() throws Exception {
    UUID uuid = UUID.fromString("00112233-4455-6677-8899-aabbccddeeff");
    if (canBeUsedNaked()) {
      internalTestNakedUuid(uuid);
    }
    internalTestUuidField(uuid);
  }

  @Test
  public void testLong() throws Exception {
    internalTestI64(0);
    internalTestI64(Long.MAX_VALUE);
    internalTestI64(Long.MIN_VALUE);
    // Walk every bit position that yields a positive power of two (1L << 63 is Long.MIN_VALUE,
    // covered above), together with its negation.
    for (int i = 0; i < 63; i++) {
      internalTestI64(1L << i);
      internalTestI64(-(1L << i));
    }
  }

  @Test
  public void testInt() throws Exception {
    for (int i :
        Arrays.asList(
            0, 1, 7, 150, 15000, 31337, 0xffff, 0xffffff, Integer.MAX_VALUE, -1, -7, -150,
            -15000, -0xffff, -0xffffff, Integer.MIN_VALUE)) {
      if (canBeUsedNaked()) {
        internalTestNakedI32(i);
      }
      internalTestI32Field(i);
    }
  }

  @Test
  public void testShort() throws Exception {
    // Short.MIN_VALUE is cast to int so every element boxes to Integer; a boxed Short would make
    // the list's element type a Number intersection and break the int loop variable.
    for (int s :
        Arrays.asList(
            0, 1, 7, 150, 15000, 0x7fff, -1, -7, -150, -15000, -0x7fff, (int) Short.MIN_VALUE)) {
      if (canBeUsedNaked()) {
        internalTestNakedI16((short) s);
      }
      internalTestI16Field((short) s);
    }
  }

  @Test
  public void testByte() throws Exception {
    // Every byte value, including Byte.MIN_VALUE, is round-tripped both naked and as a field.
    for (int i = Byte.MIN_VALUE; i <= Byte.MAX_VALUE; i++) {
      byte b = (byte) i;
      if (canBeUsedNaked()) {
        internalTestNakedByte(b);
      }
      internalTestByteField(b);
    }
  }

  /** Writes {@code b} with no surrounding framing and asserts it reads back unchanged. */
  private void internalTestNakedByte(byte b) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(1000);
    TProtocol proto = getFactory().getProtocol(buf);
    proto.writeByte(b);
    assertEquals(b, proto.readByte());
  }

  /** Round-trips {@code b} as a BYTE field of a single-field struct. */
  private void internalTestByteField(final byte b) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.BYTE, (short) 15) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeByte(b);
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(b, proto.readByte());
          }
        });
  }

  /** Writes {@code n} with no surrounding framing and asserts it reads back unchanged. */
  private void internalTestNakedI16(short n) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);
    proto.writeI16(n);
    assertEquals(n, proto.readI16());
  }

  /** Round-trips {@code n} as an I16 field of a single-field struct. */
  private void internalTestI16Field(final short n) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.I16, (short) 15) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeI16(n);
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(n, proto.readI16());
          }
        });
  }

  /** Writes {@code uuid} with no surrounding framing and asserts it reads back unchanged. */
  private void internalTestNakedUuid(UUID uuid) throws TException {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol protocol = getFactory().getProtocol(buf);
    protocol.writeUuid(uuid);
    assertEquals(uuid, protocol.readUuid());
  }

  /** Round-trips {@code uuid} as a UUID field of a single-field struct. */
  private void internalTestUuidField(UUID uuid) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.UUID, (short) 17) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeUuid(uuid);
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(uuid, proto.readUuid());
          }
        });
  }

  /** Writes {@code n} with no surrounding framing and asserts it reads back unchanged. */
  private void internalTestNakedI32(int n) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);
    proto.writeI32(n);
    assertEquals(n, proto.readI32());
  }

  /** Round-trips {@code n} as an I32 field of a single-field struct. */
  private void internalTestI32Field(final int n) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.I32, (short) 15) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeI32(n);
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(n, proto.readI32());
          }
        });
  }

  /** Round-trips {@code n} naked (when the protocol allows it) and as a struct field. */
  private void internalTestI64(long n) throws Exception {
    if (canBeUsedNaked()) {
      internalTestNakedI64(n);
    }
    internalTestI64Field(n);
  }

  /** Writes {@code n} with no surrounding framing and asserts it reads back unchanged. */
  private void internalTestNakedI64(long n) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);
    proto.writeI64(n);
    assertEquals(n, proto.readI64());
  }

  /** Round-trips {@code n} as an I64 field of a single-field struct. */
  private void internalTestI64Field(final long n) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.I64, (short) 15) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeI64(n);
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(n, proto.readI64());
          }
        });
  }

  /** Writes {@code str} with no surrounding framing and asserts it reads back unchanged. */
  private void internalTestNakedString(String str) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);
    proto.writeString(str);
    assertEquals(str, proto.readString());
  }

  /** Round-trips {@code str} as a STRING field of a single-field struct. */
  private void internalTestStringField(final String str) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.STRING, (short) 15) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeString(str);
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(str, proto.readString());
          }
        });
  }

  /** Writes {@code data} as binary with no surrounding framing and asserts it reads back. */
  private void internalTestNakedBinary(byte[] data) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);
    proto.writeBinary(ByteBuffer.wrap(data));
    assertEquals(ByteBuffer.wrap(data), proto.readBinary());
  }

  /** Round-trips {@code data} as a binary (STRING-typed) field of a single-field struct. */
  private void internalTestBinaryField(final byte[] data) throws Exception {
    internalTestStructField(
        new StructFieldTestCase(TType.STRING, (short) 15) {
          @Override
          public void writeMethod(TProtocol proto) throws TException {
            proto.writeBinary(ByteBuffer.wrap(data));
          }

          @Override
          public void readMethod(TProtocol proto) throws TException {
            assertEquals(ByteBuffer.wrap(data), proto.readBinary());
          }
        });
  }

  /**
   * Serializes {@code expected} with the protocol under test, deserializes it into a fresh
   * instance of {@code klass}, and asserts the two are equal. The encoded size is printed for
   * informational purposes.
   *
   * @param klass concrete generated struct class; must have an accessible no-arg constructor
   * @param expected the value to round-trip
   */
  private <T extends TBase<?, ?>> void internalTestSerialization(Class<T> klass, T expected)
      throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);

    expected.write(proto);
    System.out.println("Size in " + proto.getClass().getSimpleName() + ": " + buf.length());

    T actual = klass.getDeclaredConstructor().newInstance();
    actual.read(proto);
    assertEquals(expected, actual);
  }

  @Test
  public void testMessage() throws Exception {
    List<TMessage> msgs =
        Arrays.asList(
            new TMessage("short message name", TMessageType.CALL, 0),
            new TMessage("1", TMessageType.REPLY, 12345),
            new TMessage(
                "loooooooooooooooooooooooooooooooooong", TMessageType.EXCEPTION, 1 << 16),
            new TMessage("Janky", TMessageType.CALL, 0));

    for (TMessage msg : msgs) {
      TMemoryBuffer buf = new TMemoryBuffer(0);
      TProtocol proto = getFactory().getProtocol(buf);

      proto.writeMessageBegin(msg);
      proto.writeMessageEnd();

      TMessage output = proto.readMessageBegin();
      // Consume whatever writeMessageEnd emitted so the message framing is checked symmetrically.
      proto.readMessageEnd();

      assertEquals(msg, output);
    }
  }

  @Test
  public void testServerRequest() throws Exception {
    Srv.Iface handler =
        new Srv.Iface() {
          @Override
          public int Janky(int i32arg) throws TException {
            return i32arg * 2;
          }

          @Override
          public int primitiveMethod() throws TException {
            return 0;
          }

          @Override
          public CompactProtoTestStruct structMethod() throws TException {
            return null;
          }

          @Override
          public void voidMethod() throws TException {}

          @Override
          public void methodWithDefaultArgs(int something) throws TException {}

          @Override
          public void onewayMethod() throws TException {}

          @Override
          public boolean declaredExceptionMethod(boolean shouldThrow) throws TException {
            return shouldThrow;
          }
        };

    Srv.Processor testProcessor = new Srv.Processor(handler);

    TMemoryBuffer clientOutTrans = new TMemoryBuffer(0);
    TProtocol clientOutProto = getFactory().getProtocol(clientOutTrans);
    TMemoryBuffer clientInTrans = new TMemoryBuffer(0);
    TProtocol clientInProto = getFactory().getProtocol(clientInTrans);

    Srv.Client testClient = new Srv.Client(clientInProto, clientOutProto);

    // The client's outbound buffer is the processor's input and vice versa, so a full
    // request/response exchange happens in memory without a server.
    testClient.send_Janky(1);
    testProcessor.process(clientOutProto, clientInProto);
    assertEquals(2, testClient.recv_Janky());
  }

  @Test
  public void testTDeserializer() throws TException {
    TSerializer ser = new TSerializer(getFactory());
    byte[] bytes = ser.serialize(Fixtures.getCompactProtoTestStruct());

    TDeserializer deser = new TDeserializer(getFactory());
    CompactProtoTestStruct cpts = new CompactProtoTestStruct();
    deser.deserialize(cpts, bytes);

    assertEquals(Fixtures.getCompactProtoTestStruct(), cpts);
  }

  //
  // Helper methods
  //

  /**
   * Writes a struct containing the single field described by {@code testCase}, then reads it back
   * through the same protocol instance and checks the field header, payload, and terminator.
   *
   * <p>The write side emits a field stop before the struct end, and the read side calls {@code
   * readFieldEnd} and expects a STOP field. The full field/struct framing is therefore exercised,
   * instead of relying on {@code readStructEnd} to tolerate unread field-end data.
   */
  private void internalTestStructField(StructFieldTestCase testCase) throws Exception {
    TMemoryBuffer buf = new TMemoryBuffer(0);
    TProtocol proto = getFactory().getProtocol(buf);

    TField field = new TField("test_field", testCase.type_, testCase.id_);
    proto.writeStructBegin(new TStruct("test_struct"));
    proto.writeFieldBegin(field);
    testCase.writeMethod(proto);
    proto.writeFieldEnd();
    proto.writeFieldStop();
    proto.writeStructEnd();

    proto.readStructBegin();
    TField readField = proto.readFieldBegin();
    assertEquals(testCase.id_, readField.id);
    assertEquals(testCase.type_, readField.type);
    testCase.readMethod(proto);
    proto.readFieldEnd();
    assertEquals(TType.STOP, proto.readFieldBegin().type);
    proto.readStructEnd();
  }

  /** A single struct field: its wire type and id plus matching write/read-and-assert callbacks. */
  private abstract static class StructFieldTestCase {
    final byte type_;
    final short id_;

    public StructFieldTestCase(byte type, short id) {
      type_ = type;
      id_ = id;
    }

    /** Writes the field's value; called between writeFieldBegin and writeFieldEnd. */
    public abstract void writeMethod(TProtocol proto) throws TException;

    /** Reads the field's value and asserts it matches what {@link #writeMethod} wrote. */
    public abstract void readMethod(TProtocol proto) throws TException;
  }

  private static final int NUM_TRIALS = 5;
  private static final int NUM_REPS = 10000;
  private static final long NANOS_PER_MILLI = 1_000_000L;

  /**
   * Prints serialization and deserialization timings for {@code HolyMoley} with the protocol
   * under test, over {@value #NUM_TRIALS} trials of {@value #NUM_REPS} repetitions each.
   */
  protected void benchmark() throws Exception {
    for (int trial = 0; trial < NUM_TRIALS; trial++) {
      TSerializer ser = new TSerializer(getFactory());
      byte[] serialized = null;
      // nanoTime is monotonic and meant for elapsed-time measurement; the wall clock can jump.
      long serStart = System.nanoTime();
      for (int rep = 0; rep < NUM_REPS; rep++) {
        serialized = ser.serialize(Fixtures.getHolyMoley());
      }
      printTiming("Ser:\t", System.nanoTime() - serStart, "serialization");

      HolyMoley cpts = new HolyMoley();
      TDeserializer deser = new TDeserializer(getFactory());
      long deserStart = System.nanoTime();
      for (int rep = 0; rep < NUM_REPS; rep++) {
        deser.deserialize(cpts, serialized);
      }
      printTiming("Des:\t", System.nanoTime() - deserStart, "deserialization");
    }
  }

  /** Prints total elapsed milliseconds and the average milliseconds per repetition. */
  private static void printTiming(String label, long elapsedNanos, String operation) {
    long elapsedMillis = elapsedNanos / NANOS_PER_MILLI;
    double millisPerOp = (double) elapsedNanos / NANOS_PER_MILLI / NUM_REPS;
    System.out.println(label + elapsedMillis + "ms\t" + millisPerOp + "ms per " + operation);
  }

  /**
   * Handler used by the max-message-size tests. Each overridden method appends to or extends its
   * argument before returning it.
   */
  private final ServerTestBase.TestHandler testHandler =
      new ServerTestBase.TestHandler() {
        @Override
        public String testString(String thing) {
          thing = thing + " Apache Thrift Java " + thing;
          return thing;
        }

        @Override
        public List<Integer> testList(List<Integer> thing) {
          thing.addAll(thing);
          thing.addAll(thing);
          return thing;
        }

        @Override
        public Set<Integer> testSet(Set<Integer> thing) {
          thing.addAll(thing.stream().map(x -> x + 100).collect(Collectors.toSet()));
          return thing;
        }

        @Override
        public Map<String, String> testStringMap(Map<String, String> thing) {
          thing.put("a", "123");
          thing.put(" x y ", " with spaces ");
          thing.put("same", "same");
          thing.put("0", "numeric key");
          thing.put("1", "");
          thing.put("ok", "2355555");
          thing.put("end", "0");
          return thing;
        }
      };

  /**
   * Creates a protocol over a fresh in-memory transport whose configuration caps the message size
   * at {@code maxSize}.
   */
  private TProtocol initConfig(int maxSize) throws TException {
    TConfiguration config = TConfiguration.custom().setMaxMessageSize(maxSize).build();
    TMemoryBuffer bufferTrans = new TMemoryBuffer(config, 0);
    return getFactory().getProtocol(bufferTrans);
  }

  @Test
  public void testReadCheckMaxMessageRequestForString() throws TException {
    TProtocol clientOutProto = initConfig(15);
    TProtocol clientInProto = initConfig(15);
    ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
    ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
    // NOTE(review): unlike the list/map/set variants below, this test also passes when no
    // exception is thrown, presumably because not every protocol enforces the size limit on
    // string reads -- confirm before tightening it to assertThrows.
    try {
      testClient.send_testString("test");
      testProcessor.process(clientOutProto, clientInProto);
      String result = testClient.recv_testString();
      System.out.println("----result: " + result);
    } catch (TTransportException e) {
      assertEquals(TTransportException.MESSAGE_SIZE_LIMIT, e.getType());
    }
  }

  @Test
  public void testReadCheckMaxMessageRequestForList() throws TException {
    TProtocol clientOutProto = initConfig(15);
    TProtocol clientInProto = initConfig(15);
    ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
    ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
    TTransportException e =
        assertThrows(
            TTransportException.class,
            () -> {
              testClient.send_testList(Arrays.asList(1, 23242346, 888888, 90));
              testProcessor.process(clientOutProto, clientInProto);
              testClient.recv_testList();
            },
            "Limitations not achieved as expected");
    assertEquals(TTransportException.MESSAGE_SIZE_LIMIT, e.getType());
  }

  @Test
  public void testReadCheckMaxMessageRequestForMap() throws TException {
    TProtocol clientOutProto = initConfig(13);
    TProtocol clientInProto = initConfig(13);
    ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
    ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
    Map<String, String> thing = new HashMap<>();
    thing.put("key", "Thrift");

    TTransportException e =
        assertThrows(
            TTransportException.class,
            () -> {
              testClient.send_testStringMap(thing);
              testProcessor.process(clientOutProto, clientInProto);
              testClient.recv_testStringMap();
            },
            "Limitations not achieved as expected");

    assertEquals(TTransportException.MESSAGE_SIZE_LIMIT, e.getType());
  }

  @Test
  public void testReadCheckMaxMessageRequestForSet() throws TException {
    TProtocol clientOutProto = initConfig(10);
    TProtocol clientInProto = initConfig(10);
    ThriftTest.Client testClient = new ThriftTest.Client(clientInProto, clientOutProto);
    ThriftTest.Processor testProcessor = new ThriftTest.Processor(testHandler);
    TTransportException e =
        assertThrows(
            TTransportException.class,
            () -> {
              testClient.send_testSet(
                  Stream.of(234, 0, 987087, 45, 88888888, 9).collect(Collectors.toSet()));
              testProcessor.process(clientOutProto, clientInProto);
              testClient.recv_testSet();
            },
            "Limitations not achieved as expected");
    assertEquals(TTransportException.MESSAGE_SIZE_LIMIT, e.getType());
  }
}