JdbcToArrowTestHelper.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.adapter.jdbc;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.math.BigDecimal;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.AbstractMap;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.arrow.vector.BaseValueVector;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.NullVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeStampVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.JsonStringArrayList;
import org.apache.arrow.vector.util.JsonStringHashMap;
import org.apache.arrow.vector.util.ObjectMapperFactory;
import org.apache.arrow.vector.util.Text;

/**
 * Helper for the JDBC-to-Arrow adapter tests.
 *
 * <p>The {@code assert*} methods check the contents of an Arrow vector against an array of
 * expected values: the vector must report exactly {@code rowCount} values, and a {@code null}
 * entry in the expected array means the corresponding vector slot must be null.
 *
 * <p>The {@code get*} methods build those expected arrays from fixture strings of the form {@code
 * KEY=v1,v2,...} (see {@link #getValues(String[], String)}).
 */
public class JdbcToArrowTestHelper {

  /**
   * Asserts that {@code intVector} holds {@code rowCount} values equal to {@code values}; a {@code
   * null} expected element requires a null slot.
   */
  public static void assertIntVectorValues(IntVector intVector, int rowCount, Integer[] values) {
    assertEquals(rowCount, intVector.getValueCount());

    for (int j = 0; j < intVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(intVector.isNull(j));
      } else {
        assertEquals(values[j].intValue(), intVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code bitVector} holds {@code rowCount} booleans equal to {@code values}, where a
   * stored bit of 1 means {@code true}; a {@code null} expected element requires a null slot.
   */
  public static void assertBooleanVectorValues(
      BitVector bitVector, int rowCount, Boolean[] values) {
    assertEquals(rowCount, bitVector.getValueCount());

    for (int j = 0; j < bitVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(bitVector.isNull(j));
      } else {
        assertEquals(values[j].booleanValue(), bitVector.get(j) == 1);
      }
    }
  }

  /**
   * Asserts that {@code bitVector} holds {@code rowCount} raw bit values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertBitVectorValues(BitVector bitVector, int rowCount, Integer[] values) {
    assertEquals(rowCount, bitVector.getValueCount());

    for (int j = 0; j < bitVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(bitVector.isNull(j));
      } else {
        assertEquals(values[j].intValue(), bitVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code tinyIntVector} holds {@code rowCount} values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertTinyIntVectorValues(
      TinyIntVector tinyIntVector, int rowCount, Integer[] values) {
    assertEquals(rowCount, tinyIntVector.getValueCount());

    for (int j = 0; j < tinyIntVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(tinyIntVector.isNull(j));
      } else {
        assertEquals(values[j].intValue(), tinyIntVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code smallIntVector} holds {@code rowCount} values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertSmallIntVectorValues(
      SmallIntVector smallIntVector, int rowCount, Integer[] values) {
    assertEquals(rowCount, smallIntVector.getValueCount());

    for (int j = 0; j < smallIntVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(smallIntVector.isNull(j));
      } else {
        assertEquals(values[j].intValue(), smallIntVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code bigIntVector} holds {@code rowCount} values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertBigIntVectorValues(
      BigIntVector bigIntVector, int rowCount, Long[] values) {
    assertEquals(rowCount, bigIntVector.getValueCount());

    for (int j = 0; j < bigIntVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(bigIntVector.isNull(j));
      } else {
        assertEquals(values[j].longValue(), bigIntVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code decimalVector} holds {@code rowCount} values numerically equal to {@code
   * values}; a {@code null} expected element requires a null slot.
   *
   * <p>Values are compared with {@link BigDecimal#compareTo}, so {@code 1.0} and {@code 1.00} are
   * considered equal but no precision is lost (unlike comparing {@code doubleValue()}s).
   */
  public static void assertDecimalVectorValues(
      DecimalVector decimalVector, int rowCount, BigDecimal[] values) {
    assertEquals(rowCount, decimalVector.getValueCount());

    for (int j = 0; j < decimalVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(decimalVector.isNull(j));
      } else {
        BigDecimal actual = decimalVector.getObject(j);
        assertTrue(
            actual != null && values[j].compareTo(actual) == 0,
            "row " + j + ": expected " + values[j] + " but was " + actual);
      }
    }
  }

  /**
   * Asserts that {@code float8Vector} holds {@code rowCount} values within 0.01 of {@code values};
   * a {@code null} expected element requires a null slot.
   */
  public static void assertFloat8VectorValues(
      Float8Vector float8Vector, int rowCount, Double[] values) {
    assertEquals(rowCount, float8Vector.getValueCount());

    for (int j = 0; j < float8Vector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(float8Vector.isNull(j));
      } else {
        assertEquals(values[j], float8Vector.get(j), 0.01);
      }
    }
  }

  /**
   * Asserts that {@code float4Vector} holds {@code rowCount} values within 0.01 of {@code values};
   * a {@code null} expected element requires a null slot.
   */
  public static void assertFloat4VectorValues(
      Float4Vector float4Vector, int rowCount, Float[] values) {
    assertEquals(rowCount, float4Vector.getValueCount());

    for (int j = 0; j < float4Vector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(float4Vector.isNull(j));
      } else {
        assertEquals(values[j], float4Vector.get(j), 0.01);
      }
    }
  }

  /**
   * Asserts that {@code timeMilliVector} holds {@code rowCount} values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertTimeVectorValues(
      TimeMilliVector timeMilliVector, int rowCount, Long[] values) {
    assertEquals(rowCount, timeMilliVector.getValueCount());

    for (int j = 0; j < timeMilliVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(timeMilliVector.isNull(j));
      } else {
        assertEquals(values[j].longValue(), timeMilliVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code dateDayVector} holds {@code rowCount} values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertDateVectorValues(
      DateDayVector dateDayVector, int rowCount, Integer[] values) {
    assertEquals(rowCount, dateDayVector.getValueCount());

    for (int j = 0; j < dateDayVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(dateDayVector.isNull(j));
      } else {
        assertEquals(values[j].intValue(), dateDayVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code timeStampVector} holds {@code rowCount} values equal to {@code values}; a
   * {@code null} expected element requires a null slot.
   */
  public static void assertTimeStampVectorValues(
      TimeStampVector timeStampVector, int rowCount, Long[] values) {
    assertEquals(rowCount, timeStampVector.getValueCount());

    for (int j = 0; j < timeStampVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(timeStampVector.isNull(j));
      } else {
        assertEquals(values[j].longValue(), timeStampVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code varBinaryVector} holds {@code rowCount} byte arrays equal to {@code
   * values}; a {@code null} expected element requires a null slot.
   */
  public static void assertVarBinaryVectorValues(
      VarBinaryVector varBinaryVector, int rowCount, byte[][] values) {
    assertEquals(rowCount, varBinaryVector.getValueCount());

    for (int j = 0; j < varBinaryVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(varBinaryVector.isNull(j));
      } else {
        assertArrayEquals(values[j], varBinaryVector.get(j));
      }
    }
  }

  /**
   * Asserts that {@code varCharVector} holds {@code rowCount} encoded strings whose bytes equal
   * {@code values}; a {@code null} expected element requires a null slot.
   */
  public static void assertVarcharVectorValues(
      VarCharVector varCharVector, int rowCount, byte[][] values) {
    assertEquals(rowCount, varCharVector.getValueCount());

    for (int j = 0; j < varCharVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(varCharVector.isNull(j));
      } else {
        assertArrayEquals(values[j], varCharVector.get(j));
      }
    }
  }

  /** Asserts only that {@code vector} reports {@code rowCount} values. */
  public static void assertNullVectorValues(NullVector vector, int rowCount) {
    assertEquals(rowCount, vector.getValueCount());
  }

  /**
   * Asserts that {@code listVector} holds {@code rowCount} lists equal to {@code values}; a {@code
   * null} expected row requires a null slot, and {@code null} list elements are allowed.
   */
  public static void assertListVectorValues(
      ListVector listVector, int rowCount, Integer[][] values) {
    assertEquals(rowCount, listVector.getValueCount());

    for (int j = 0; j < listVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(listVector.isNull(j));
      } else {
        List<Integer> list = (List<Integer>) listVector.getObject(j);
        assertEquals(Arrays.asList(values[j]), list);
      }
    }
  }

  /**
   * Asserts that {@code mapVector} holds {@code rowCount} maps equal to {@code values}, comparing
   * keys and values by their {@code toString()} form; a {@code null} expected map requires a null
   * slot. A non-null but empty slot is compared as an empty map.
   */
  public static void assertMapVectorValues(
      MapVector mapVector, int rowCount, Map<String, String>[] values) {
    assertEquals(rowCount, mapVector.getValueCount());

    for (int j = 0; j < mapVector.getValueCount(); j++) {
      if (values[j] == null) {
        assertTrue(mapVector.isNull(j));
      } else {
        JsonStringArrayList<JsonStringHashMap<String, Text>> actualSource =
            (JsonStringArrayList<JsonStringHashMap<String, Text>>) mapVector.getObject(j);
        Map<String, String> actualMap = null;
        if (actualSource != null) {
          // Each entry is a {"key": ..., "value": ...} struct. Collect through HashMap::put rather
          // than Collectors.toMap, because toMap rejects null values.
          actualMap =
              actualSource.stream()
                  .map(
                      entry ->
                          new AbstractMap.SimpleEntry<>(
                              entry.get("key").toString(),
                              entry.get("value") != null ? entry.get("value").toString() : null))
                  .collect(
                      HashMap::new,
                      (collector, val) -> collector.put(val.getKey(), val.getValue()),
                      HashMap::putAll);
        }
        assertEquals(values[j], actualMap);
      }
    }
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into maps.
   *
   * <p>Each element is a JSON object in which {@code '|'} stands in for {@code ','} (since commas
   * separate the rows). A blank element yields a {@code null} map, as does the JSON literal {@code
   * null}.
   *
   * @return one map (or {@code null}) per fixture element
   * @throws RuntimeException if an element is not a valid JSON string-to-string object
   */
  public static Map<String, String>[] getMapValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    Map<String, String>[] maps = new Map[dataArr.length];
    ObjectMapper objectMapper = ObjectMapperFactory.newObjectMapper();
    TypeReference<Map<String, String>> typeReference = new TypeReference<Map<String, String>>() {};
    for (int idx = 0; idx < dataArr.length; idx++) {
      String jsonString = dataArr[idx].replace("|", ",").trim();
      if (!jsonString.isEmpty()) {
        try {
          maps[idx] = objectMapper.readValue(jsonString, typeReference);
        } catch (JsonProcessingException e) {
          throw new RuntimeException(
              "Invalid map literal for " + dataType + " at index " + idx + ": " + jsonString, e);
        }
      }
    }
    return maps;
  }

  /** Asserts that {@code vector} reports {@code rowCount} values and that every slot is null. */
  public static void assertNullValues(BaseValueVector vector, int rowCount) {
    assertEquals(rowCount, vector.getValueCount());

    for (int j = 0; j < vector.getValueCount(); j++) {
      assertTrue(vector.isNull(j));
    }
  }

  /**
   * Asserts that every field in the schema of the given root has a non-null, empty metadata map.
   *
   * @param schema the vector schema root whose schema fields are inspected
   */
  public static void assertFieldMetadataIsEmpty(VectorSchemaRoot schema) {
    assertNotNull(schema);
    assertNotNull(schema.getSchema());
    assertNotNull(schema.getSchema().getFields());

    for (Field field : schema.getSchema().getFields()) {
      assertNotNull(field.getMetadata());
      assertEquals(0, field.getMetadata().size());
    }
  }

  /**
   * Asserts that each field of {@code schema} carries exactly five metadata entries (catalog name,
   * schema name, table name, column label and type name) equal to the corresponding {@code rsmd}
   * column values. Fields are matched to columns by position.
   *
   * @throws SQLException if reading {@code rsmd} fails
   */
  public static void assertFieldMetadataMatchesResultSetMetadata(
      ResultSetMetaData rsmd, Schema schema) throws SQLException {
    assertNotNull(schema);
    assertNotNull(schema.getFields());
    assertNotNull(rsmd);

    List<Field> fields = schema.getFields();

    assertEquals(rsmd.getColumnCount(), fields.size());

    // Vector columns are created in the same order as ResultSet columns.
    for (int i = 1; i <= rsmd.getColumnCount(); ++i) {
      Map<String, String> metadata = fields.get(i - 1).getMetadata();

      assertNotNull(metadata);
      assertEquals(5, metadata.size());

      assertEquals(rsmd.getCatalogName(i), metadata.get(Constants.SQL_CATALOG_NAME_KEY));
      assertEquals(rsmd.getSchemaName(i), metadata.get(Constants.SQL_SCHEMA_NAME_KEY));
      assertEquals(rsmd.getTableName(i), metadata.get(Constants.SQL_TABLE_NAME_KEY));
      assertEquals(rsmd.getColumnLabel(i), metadata.get(Constants.SQL_COLUMN_NAME_KEY));
      assertEquals(rsmd.getColumnTypeName(i), metadata.get(Constants.SQL_TYPE_KEY));
    }
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into integers; the token {@code
   * null} yields a {@code null} element. Surrounding whitespace is ignored.
   */
  public static Integer[] getIntValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    Integer[] valueArr = new Integer[dataArr.length];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : Integer.parseInt(token);
    }
    return valueArr;
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into booleans: {@code 1} is {@code
   * true}, any other token except {@code null} is {@code false}, and {@code null} yields a {@code
   * null} element.
   */
  public static Boolean[] getBooleanValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    Boolean[] valueArr = new Boolean[dataArr.length];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : token.equals("1");
    }
    return valueArr;
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into decimals; the token {@code
   * null} yields a {@code null} element. Surrounding whitespace is ignored.
   */
  public static BigDecimal[] getDecimalValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    BigDecimal[] valueArr = new BigDecimal[dataArr.length];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : new BigDecimal(token);
    }
    return valueArr;
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into doubles; the token {@code
   * null} yields a {@code null} element.
   */
  public static Double[] getDoubleValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    Double[] valueArr = new Double[dataArr.length];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : Double.parseDouble(token);
    }
    return valueArr;
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into floats; the token {@code
   * null} yields a {@code null} element.
   */
  public static Float[] getFloatValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    Float[] valueArr = new Float[dataArr.length];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : Float.parseFloat(token);
    }
    return valueArr;
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into longs; the token {@code null}
   * yields a {@code null} element. Surrounding whitespace is ignored.
   */
  public static Long[] getLongValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    Long[] valueArr = new Long[dataArr.length];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : Long.parseLong(token);
    }
    return valueArr;
  }

  /**
   * Encodes each trimmed token of the {@code dataType} fixture entry of {@code values} as UTF-8;
   * the token {@code null} yields a {@code null} element.
   */
  public static byte[][] getCharArray(String[] values, String dataType) {
    return getCharArrayWithCharSet(values, dataType, StandardCharsets.UTF_8);
  }

  /**
   * Encodes each trimmed token of the {@code dataType} fixture entry of {@code values} with {@code
   * charSet}; the token {@code null} yields a {@code null} element.
   */
  public static byte[][] getCharArrayWithCharSet(
      String[] values, String dataType, Charset charSet) {
    String[] dataArr = getValues(values, dataType);
    byte[][] valueArr = new byte[dataArr.length][];
    for (int i = 0; i < dataArr.length; i++) {
      String token = dataArr[i].trim();
      valueArr[i] = "null".equals(token) ? null : token.getBytes(charSet);
    }
    return valueArr;
  }

  /**
   * Returns the UTF-8 bytes of each trimmed token of the {@code dataType} fixture entry of {@code
   * values}; the token {@code null} yields a {@code null} element.
   */
  public static byte[][] getBinaryValues(String[] values, String dataType) {
    return getCharArrayWithCharSet(values, dataType, StandardCharsets.UTF_8);
  }

  /**
   * Finds the fixture entry whose key equals {@code dataType} and splits its value on commas.
   *
   * <p>Each element of {@code values} is expected to look like {@code KEY=v1,v2,...}. Only the
   * first {@code '='} separates key from value, so the value itself may contain {@code '='}. The
   * key must match {@code dataType} exactly (not merely start with it). Trailing empty tokens are
   * dropped, as with {@link String#split(String)}.
   *
   * @return the comma-separated tokens of the matching entry, or {@code [""]} if no entry matches
   */
  @SuppressWarnings("StringSplitter")
  public static String[] getValues(String[] values, String dataType) {
    String value = "";
    for (String val : values) {
      int separator = val.indexOf('=');
      // Compare the whole key so that e.g. "INT_FIELD1" does not pick up "INT_FIELD10=...".
      if (separator >= 0 && val.substring(0, separator).equals(dataType)) {
        value = val.substring(separator + 1);
        break;
      }
    }
    return value.split(",");
  }

  /**
   * Parses the {@code dataType} fixture entry of {@code values} into integer lists; see {@link
   * #getListValues(String[])} for the element syntax.
   */
  public static Integer[][] getListValues(String[] values, String dataType) {
    String[] dataArr = getValues(values, dataType);
    return getListValues(dataArr);
  }

  /**
   * Parses list literals such as {@code (1;2;null)} into integer arrays.
   *
   * <p>The token {@code null} yields a {@code null} list, {@code ()} yields an empty list, and
   * {@code null} inside the parentheses yields a {@code null} element. Whitespace around tokens is
   * ignored.
   */
  @SuppressWarnings("StringSplitter")
  public static Integer[][] getListValues(String[] dataArr) {
    Integer[][] valueArr = new Integer[dataArr.length][];
    for (int i = 0; i < dataArr.length; i++) {
      String data = dataArr[i].trim();
      if ("null".equals(data)) {
        valueArr[i] = null;
      } else if ("()".equals(data)) {
        valueArr[i] = new Integer[0];
      } else {
        String[] row = data.replace("(", "").replace(")", "").split(";");
        Integer[] arr = new Integer[row.length];
        for (int j = 0; j < arr.length; j++) {
          String element = row[j].trim();
          arr[j] = "null".equals(element) ? null : Integer.parseInt(element);
        }
        valueArr[i] = arr;
      }
    }
    return valueArr;
  }
}