// TestPartialThriftDeserializer.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.thrift.partial;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.thrift.TDeserializer;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
import org.junit.jupiter.api.Test;

/**
 * Tests partial deserialization through {@link TDeserializer}.
 *
 * <p>Each test serializes a {@code TestStruct} produced by {@code PartialThriftTestData} with both
 * the binary and the compact protocol, partially deserializes it with a field list, and then
 * checks the fields of the result.
 */
public class TestPartialThriftDeserializer {

  private final ThriftSerDe serde = new ThriftSerDe();
  private final TBinaryProtocol.Factory binaryProtocolFactory = new TBinaryProtocol.Factory();
  private final TCompactProtocol.Factory compactProtocolFactory = new TCompactProtocol.Factory();

  private final PartialThriftTestData testData = new PartialThriftTestData();

  /**
   * Explicit constructor whose {@code throws} clause permits the instance field initializers above
   * to throw {@link TException} (presumably at least one of them does; not visible from here).
   */
  public TestPartialThriftDeserializer() throws TException {}

  /** Checks that the partial-deserialization constructors reject {@code null} arguments. */
  @Test
  public void testArgChecks() throws TException {
    // Should not throw.
    List<String> fieldNames = Collections.singletonList("i32Field");
    new TDeserializer(TestStruct.class, fieldNames, binaryProtocolFactory);

    // Verify it throws correctly.
    // NOTE: the third argument of assertThrows is the message reported when the assertion itself
    // fails; it is not compared against the message of the thrown exception.
    assertThrows(
        IllegalArgumentException.class,
        () -> new TDeserializer(null, fieldNames, binaryProtocolFactory),
        "'thriftClass' must not be null");

    assertThrows(
        IllegalArgumentException.class,
        () -> new TDeserializer(TestStruct.class, null, binaryProtocolFactory),
        "'fieldNames' must not be null");

    assertThrows(
        IllegalArgumentException.class,
        () -> new TDeserializer(TestStruct.class, fieldNames, null, binaryProtocolFactory),
        "'processor' must not be null");
  }

  /**
   * This test does not use partial deserialization. It is used to establish correctness of full
   * serialization used in the other tests.
   */
  @Test
  public void testRoundTripFull() throws TException {
    TestStruct ts1 = testData.createTestStruct(1, 2);

    byte[] bytesBinary = serde.serializeBinary(ts1);
    byte[] bytesCompact = serde.serializeCompact(ts1);

    TestStruct ts2 = serde.deserializeBinary(bytesBinary, TestStruct.class);
    assertEquals(ts1, ts2);

    ts2 = serde.deserializeCompact(bytesCompact, TestStruct.class);
    assertEquals(ts1, ts2);
  }

  /**
   * Requests only {@code i32Field} and checks that it is the only one of the two inspected
   * primitive fields that gets populated, for both protocols.
   */
  @Test
  public void testPartialSimpleField() throws TException, IOException {
    TestStruct ts1 = testData.createTestStruct(1, 1);
    assertTrue(ts1.isSetI16Field());
    assertTrue(ts1.isSetI32Field());

    byte[] bytesBinary = serde.serializeBinary(ts1);
    byte[] bytesCompact = serde.serializeCompact(ts1);

    List<String> fieldNames = Arrays.asList("i32Field");

    TDeserializer partialBinaryDeserializer =
        new TDeserializer(TestStruct.class, fieldNames, binaryProtocolFactory);
    TDeserializer partialCompactDeserializer =
        new TDeserializer(TestStruct.class, fieldNames, compactProtocolFactory);

    PartialThriftComparer comparer =
        new PartialThriftComparer(partialBinaryDeserializer.getMetadata());

    // A separate builder per comparison keeps a failure message for one protocol free of any
    // text the comparer may have appended while checking the other protocol.
    StringBuilder binaryDiff = new StringBuilder();
    TestStruct ts2 = (TestStruct) partialBinaryDeserializer.partialDeserializeObject(bytesBinary);
    validatePartialSimpleField(ts1, ts2);
    assertTrue(comparer.areEqual(ts1, ts2, binaryDiff), binaryDiff::toString);

    StringBuilder compactDiff = new StringBuilder();
    ts2 = (TestStruct) partialCompactDeserializer.partialDeserializeObject(bytesCompact);
    validatePartialSimpleField(ts1, ts2);
    assertTrue(comparer.areEqual(ts1, ts2, compactDiff), compactDiff::toString);
  }

  /**
   * Asserts that {@code ts2} carries the same {@code i32Field} as {@code ts1} and that its
   * {@code i16Field} is unset.
   */
  private void validatePartialSimpleField(TestStruct ts1, TestStruct ts2) {
    assertTrue(ts2.isSetI32Field(), ts2.toString());
    assertEquals(ts1.getI32Field(), ts2.getI32Field());
    assertFalse(ts2.isSetI16Field());
  }

  /**
   * Requests primitive, list, set, map and struct fields by name and validates every one of them
   * after partial deserialization with both protocols.
   */
  @Test
  public void testPartialComplex() throws TException {
    int id = 1;
    int numItems = 10;
    TestStruct ts1 = testData.createTestStruct(id, numItems);

    byte[] bytesBinary = serde.serializeBinary(ts1);
    byte[] bytesCompact = serde.serializeCompact(ts1);

    List<String> fieldNames =
        Arrays.asList(
            "byteField",
            "i16Field",
            "i32Field",
            "i64Field",
            "doubleField",
            "stringField",
            "enumField",
            "binaryField",

            // List fields
            "byteList",
            "i16List",
            "i32List",
            "i64List",
            "doubleList",
            "stringList",
            "enumList",
            "listList",
            "setList",
            "mapList",
            "structList",
            "binaryList",

            // Set fields
            "byteSet",
            "i16Set",
            "i32Set",
            "i64Set",
            "doubleSet",
            "stringSet",
            "enumSet",
            "listSet",
            "setSet",
            "mapSet",
            "structSet",
            "binarySet",

            // Map fields
            "byteMap",
            "i16Map",
            "i32Map",
            "i64Map",
            "doubleMap",
            "stringMap",
            "enumMap",
            "listMap",
            "setMap",
            "mapMap",
            "structMap",
            "binaryMap",

            // Struct field
            "structField");
    TDeserializer partialBinaryDeserializer =
        new TDeserializer(TestStruct.class, fieldNames, binaryProtocolFactory);
    TDeserializer partialCompactDeserializer =
        new TDeserializer(TestStruct.class, fieldNames, compactProtocolFactory);
    PartialThriftComparer comparer =
        new PartialThriftComparer(partialBinaryDeserializer.getMetadata());

    // As in testPartialSimpleField, each comparison gets its own message buffer.
    StringBuilder binaryDiff = new StringBuilder();
    TestStruct ts2 = (TestStruct) partialBinaryDeserializer.partialDeserializeObject(bytesBinary);
    validatePartialComplex(ts1, ts2, id, numItems);
    assertTrue(comparer.areEqual(ts1, ts2, binaryDiff), binaryDiff::toString);

    StringBuilder compactDiff = new StringBuilder();
    ts2 = (TestStruct) partialCompactDeserializer.partialDeserializeObject(bytesCompact);
    validatePartialComplex(ts1, ts2, id, numItems);
    assertTrue(comparer.areEqual(ts1, ts2, compactDiff), compactDiff::toString);
  }

  /**
   * Validates every field of {@code ts2}: primitives are compared against {@code ts1}, while
   * containers are checked against the contents expected for {@code numItems} generated items.
   *
   * @param ts1 the fully populated source struct
   * @param ts2 the partially deserialized struct under test
   * @param id the id the source struct was created with
   * @param numItems the number of items each container was created with
   */
  private void validatePartialComplex(TestStruct ts1, TestStruct ts2, int id, int numItems) {

    // Validate primitive fields.
    assertTrue(ts2.isSetByteField(), ts2.toString());
    assertEquals(ts1.getByteField(), ts2.getByteField());

    assertTrue(ts2.isSetI16Field());
    assertEquals(ts1.getI16Field(), ts2.getI16Field());

    assertTrue(ts2.isSetI32Field());
    assertEquals(ts1.getI32Field(), ts2.getI32Field());

    assertTrue(ts2.isSetI64Field());
    assertEquals(ts1.getI64Field(), ts2.getI64Field());

    assertTrue(ts2.isSetDoubleField());
    assertEquals(ts1.getDoubleField(), ts2.getDoubleField(), 0.0001);

    assertTrue(ts2.isSetStringField());
    assertEquals(ts1.getStringField(), ts2.getStringField());

    assertTrue(ts2.isSetEnumField());
    assertEquals(ts1.getEnumField(), ts2.getEnumField());

    assertTrue(ts2.isSetBinaryField());
    assertArrayEquals(ts1.getBinaryField(), ts2.getBinaryField());

    // Validate list fields.
    validateList(ts2.getByteList(), id, numItems);
    validateList(ts2.getI16List(), id, numItems);
    validateList(ts2.getI32List(), id, numItems);
    validateList(ts2.getI64List(), id, numItems);
    validateList(ts2.getDoubleList(), id, numItems);
    validateStringList(ts2.getStringList(), id, numItems);
    validateEnumList(ts2.getEnumList(), id, numItems);

    validateListOfList(ts2.getListList(), id, numItems);
    validateListOfSet(ts2.getSetList(), id, numItems);
    validateListOfMap(ts2.getMapList(), id, numItems);
    validateListOfStruct(ts2.getStructList(), id, numItems);
    validateListOfBinary(ts2.getBinaryList(), id, numItems);

    // Validate set fields.
    validateSet(ts2.getByteSet(), Byte.class, numItems);
    validateSet(ts2.getI16Set(), Short.class, numItems);
    validateSet(ts2.getI32Set(), Integer.class, numItems);
    validateSet(ts2.getI64Set(), Long.class, numItems);
    validateSet(ts2.getDoubleSet(), Double.class, numItems);
    validateStringSet(ts2.getStringSet(), id, numItems);
    validateEnumSet(ts2.getEnumSet(), id, numItems);

    validateSetOfList(ts2.getListSet(), id, numItems);
    validateSetOfSet(ts2.getSetSet(), id, numItems);
    validateSetOfMap(ts2.getMapSet(), id, numItems);
    validateSetOfStruct(ts2.getStructSet(), id, numItems);
    validateSetOfBinary(ts2.getBinarySet(), id, numItems);

    // Validate map fields.
    validateMap(ts2.getByteMap(), Byte.class, numItems);
    validateMap(ts2.getI16Map(), Short.class, numItems);
    validateMap(ts2.getI32Map(), Integer.class, numItems);
    validateMap(ts2.getI64Map(), Long.class, numItems);
    validateMap(ts2.getDoubleMap(), Double.class, numItems);
    validateStringMap(ts2.getStringMap(), id, numItems);
    validateEnumMap(ts2.getEnumMap(), id, numItems);

    validateMapOfList(ts2.getListMap(), id, numItems);
    validateMapOfSet(ts2.getSetMap(), id, numItems);
    validateMapOfMap(ts2.getMapMap(), id, numItems);
    validateMapOfStruct(ts2.getStructMap(), id, numItems);
    validateMapOfBinary(ts2.getBinaryMap(), id, numItems);

    // Validate struct field.
    assertEquals(testData.createSmallStruct(id), ts2.getStructField());
  }

  /**
   * Asserts that {@code collection} is non-null and holds exactly {@code numItems} elements.
   *
   * <p>Despite the name this is an exact-size check, not merely a non-empty check.
   */
  private void validateNotNullAndNotEmpty(Collection<?> collection, int numItems) {
    assertNotNull(collection);
    assertEquals(numItems, collection.size());
  }

  // ----------------------------------------------------------------------
  // Expected-value builders shared by the helpers below.

  /** Returns the list {@code [0, 1, ..., numItems - 1]}. */
  private static List<Integer> expectedIntList(int numItems) {
    List<Integer> expected = new ArrayList<>(numItems);
    for (int i = 0; i < numItems; i++) {
      expected.add(i);
    }
    return expected;
  }

  /** Returns the set {@code {0, 1, ..., numItems - 1}}. */
  private static Set<Integer> expectedIntSet(int numItems) {
    return new HashSet<>(expectedIntList(numItems));
  }

  /**
   * Converts {@code value} to the boxed numeric type denoted by {@code type}.
   *
   * @throws IllegalArgumentException if {@code type} is not one of {@code Byte}, {@code Short},
   *     {@code Integer}, {@code Long} or {@code Double}; failing loudly prevents a validation from
   *     silently checking nothing for an unsupported element type
   */
  private static <V extends Number> V toBoxed(Class<V> type, int value) {
    final Number boxed;
    if (type == Byte.class) {
      boxed = Byte.valueOf((byte) value);
    } else if (type == Short.class) {
      boxed = Short.valueOf((short) value);
    } else if (type == Integer.class) {
      boxed = Integer.valueOf(value);
    } else if (type == Long.class) {
      boxed = Long.valueOf(value);
    } else if (type == Double.class) {
      boxed = Double.valueOf(value);
    } else {
      throw new IllegalArgumentException("Unsupported numeric class: " + type);
    }
    return type.cast(boxed);
  }

  // ----------------------------------------------------------------------
  // List validation helpers.

  /** Asserts that element {@code i} of the numeric list has the integral value {@code i}. */
  private <V extends Number> void validateList(List<V> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);

    for (int i = 0; i < numItems; i++) {
      assertEquals(i, list.get(i).longValue());
    }
  }

  /** Asserts that element {@code i} of the string list parses to the integer {@code i}. */
  private void validateStringList(List<String> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);
    for (int i = 0; i < numItems; i++) {
      assertEquals(Integer.valueOf(i), Integer.valueOf(list.get(i)));
    }
  }

  /** Asserts that every element of the enum list is {@code TstEnum.E_ONE}. */
  private void validateEnumList(List<TstEnum> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);
    for (int i = 0; i < numItems; i++) {
      assertEquals(TstEnum.E_ONE, list.get(i));
    }
  }

  /** Asserts that every inner list passes {@link #validateList}. */
  private <V extends Number> void validateListOfList(List<List<V>> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);

    for (int i = 0; i < numItems; i++) {
      validateList(list.get(i), id, numItems);
    }
  }

  /** Asserts that every inner set contains the boxed integers {@code 0..numItems-1}. */
  private <V extends Number> void validateListOfSet(List<Set<V>> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);

    for (int i = 0; i < numItems; i++) {
      Set<V> set = list.get(i);
      for (int j = 0; j < numItems; j++) {
        assertTrue(set.contains(j));
      }
    }
  }

  /**
   * Asserts that every inner map associates the decimal string of {@code j} with the boxed integer
   * {@code j}, for {@code j} in {@code 0..numItems-1}.
   */
  private <V extends Number> void validateListOfMap(
      List<Map<String, V>> list, int id, int numItems) {

    validateNotNullAndNotEmpty(list, numItems);

    for (int i = 0; i < numItems; i++) {
      Map<String, V> map = list.get(i);
      for (int j = 0; j < numItems; j++) {
        String key = Integer.toString(j);
        assertTrue(map.containsKey(key));
        assertEquals(j, map.get(key));
      }
    }
  }

  /** Asserts that element {@code i} equals {@code testData.createSmallStruct(i)}. */
  private void validateListOfStruct(List<SmallStruct> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);

    // One comparison per element; a nested loop here would only repeat the same assertion.
    for (int i = 0; i < numItems; i++) {
      assertEquals(testData.createSmallStruct(i), list.get(i));
    }
  }

  /** Asserts that every element compares equal to a buffer wrapping {@code testData.BYTES}. */
  private void validateListOfBinary(List<ByteBuffer> list, int id, int numItems) {
    validateNotNullAndNotEmpty(list, numItems);

    ByteBuffer expected = ByteBuffer.wrap(testData.BYTES);
    for (int i = 0; i < numItems; i++) {
      assertEquals(0, expected.compareTo(list.get(i)));
    }
  }

  // ----------------------------------------------------------------------
  // Set validation helpers.

  /**
   * Asserts that the set contains {@code 0..numItems-1} boxed as {@code clasz}.
   *
   * @throws IllegalArgumentException if {@code clasz} is not a supported numeric class and there
   *     is at least one item to check
   */
  private <V extends Number> void validateSet(Set<V> set, Class<V> clasz, int numItems) {
    validateNotNullAndNotEmpty(set, numItems);

    for (int i = 0; i < numItems; i++) {
      assertTrue(set.contains(toBoxed(clasz, i)));
    }
  }

  /** Asserts that the set contains the decimal strings of {@code 0..numItems-1}. */
  private void validateStringSet(Set<String> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, numItems);

    for (int i = 0; i < numItems; i++) {
      assertTrue(set.contains(Integer.toString(i)));
    }
  }

  /** Asserts that the set holds exactly one element, {@code TstEnum.E_ONE}. */
  private void validateEnumSet(Set<TstEnum> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, 1);

    assertTrue(set.contains(TstEnum.E_ONE));
  }

  /** Asserts that the set holds exactly one element, the list {@code [0..numItems-1]}. */
  private void validateSetOfList(Set<List<Integer>> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, 1);

    assertTrue(set.contains(expectedIntList(numItems)));
  }

  /** Asserts that the set holds exactly one element, the set {@code {0..numItems-1}}. */
  private void validateSetOfSet(Set<Set<Integer>> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, 1);

    assertTrue(set.contains(expectedIntSet(numItems)));
  }

  /**
   * Asserts that the set holds exactly one element, the map from the decimal string of {@code i}
   * to {@code i} for {@code i} in {@code 0..numItems-1}.
   */
  private void validateSetOfMap(Set<Map<String, Integer>> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, 1);

    Map<String, Integer> expected = new HashMap<>();
    for (int i = 0; i < numItems; i++) {
      expected.put(Integer.toString(i), i);
    }

    assertTrue(set.contains(expected));
  }

  /** Asserts that the set contains {@code testData.createSmallStruct(i)} for every index. */
  private void validateSetOfStruct(Set<SmallStruct> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, numItems);

    for (int i = 0; i < numItems; i++) {
      assertTrue(set.contains(testData.createSmallStruct(i)));
    }
  }

  /**
   * Asserts that the set holds exactly one element and that it compares equal to a buffer wrapping
   * {@code testData.BYTES}.
   */
  private void validateSetOfBinary(Set<ByteBuffer> set, int id, int numItems) {
    validateNotNullAndNotEmpty(set, 1);

    ByteBuffer expected = ByteBuffer.wrap(testData.BYTES);
    for (ByteBuffer actual : set) {
      assertEquals(0, expected.compareTo(actual));
    }
  }

  // ----------------------------------------------------------------------
  // Map validation helpers.

  /**
   * Asserts that {@code map} is non-null and holds exactly {@code numItems} entries.
   *
   * <p>Despite the name this is an exact-size check, not merely a non-empty check.
   */
  void validateNotNullAndNotEmpty(Map<?, ?> map, int numItems) {
    assertNotNull(map);
    assertEquals(numItems, map.size());
  }

  /**
   * Asserts that the map associates each of {@code 0..numItems-1}, boxed as {@code clasz}, with
   * itself.
   *
   * @throws IllegalArgumentException if {@code clasz} is not a supported numeric class and there
   *     is at least one item to check
   */
  private <V extends Number> void validateMap(Map<V, V> map, Class<V> clasz, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    for (int i = 0; i < numItems; i++) {
      V expected = toBoxed(clasz, i);
      assertTrue(map.containsKey(expected));
      assertEquals(expected, map.get(expected));
    }
  }

  /** Asserts that the map associates the decimal string of each index with itself. */
  private void validateStringMap(Map<String, String> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    for (int i = 0; i < numItems; i++) {
      String key = Integer.toString(i);
      assertTrue(map.containsKey(key));
      assertEquals(key, map.get(key));
    }
  }

  /** Asserts that the map holds exactly one entry, {@code E_ONE -> E_ONE}. */
  private void validateEnumMap(Map<TstEnum, TstEnum> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, 1);

    assertTrue(map.containsKey(TstEnum.E_ONE));
    assertEquals(TstEnum.E_ONE, map.get(TstEnum.E_ONE));
  }

  /** Asserts that every key {@code 0..numItems-1} maps to the list {@code [0..numItems-1]}. */
  private void validateMapOfList(Map<Integer, List<Integer>> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    List<Integer> expected = expectedIntList(numItems);
    for (int i = 0; i < numItems; i++) {
      assertTrue(map.containsKey(i));
      assertEquals(expected, map.get(i));
    }
  }

  /** Asserts that every key {@code 0..numItems-1} maps to the set {@code {0..numItems-1}}. */
  private void validateMapOfSet(Map<Integer, Set<Integer>> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    Set<Integer> expected = expectedIntSet(numItems);
    for (int i = 0; i < numItems; i++) {
      assertTrue(map.containsKey(i));
      assertEquals(expected, map.get(i));
    }
  }

  /**
   * Asserts that every key {@code 0..numItems-1} maps to the identity map over the same range.
   */
  private void validateMapOfMap(Map<Integer, Map<Integer, Integer>> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    Map<Integer, Integer> expected = new HashMap<>();
    for (int i = 0; i < numItems; i++) {
      expected.put(i, i);
    }

    for (int i = 0; i < numItems; i++) {
      assertTrue(map.containsKey(i));
      assertEquals(expected, map.get(i));
    }
  }

  /** Asserts that {@code testData.createSmallStruct(i)} maps to itself for every index. */
  private void validateMapOfStruct(Map<SmallStruct, SmallStruct> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    for (int i = 0; i < numItems; i++) {
      SmallStruct expected = testData.createSmallStruct(i);
      assertTrue(map.containsKey(expected));
      assertEquals(expected, map.get(expected));
    }
  }

  /**
   * Asserts that every key {@code 0..numItems-1} maps to a buffer comparing equal to one wrapping
   * {@code testData.BYTES}.
   */
  private void validateMapOfBinary(Map<Integer, ByteBuffer> map, int id, int numItems) {
    validateNotNullAndNotEmpty(map, numItems);

    ByteBuffer expected = ByteBuffer.wrap(testData.BYTES);
    for (int i = 0; i < numItems; i++) {
      assertTrue(map.containsKey(i));
      assertEquals(0, expected.compareTo(map.get(i)));
    }
  }
}