JdbcToArrowArrayTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.adapter.jdbc.h2;

import static org.apache.arrow.adapter.jdbc.AbstractJdbcToArrowTest.sqlToArrow;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

import java.nio.charset.StandardCharsets;
import java.sql.Array;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.sql.Types;
import java.util.HashMap;
import java.util.Map;
import java.util.function.ObjIntConsumer;
import org.apache.arrow.adapter.jdbc.JdbcFieldInfo;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils;
import org.apache.arrow.memory.ArrowBuf;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

public class JdbcToArrowArrayTest {
  private Connection conn = null;

  private static final String CREATE_STATEMENT =
      "CREATE TABLE array_table (id INTEGER, int_array INTEGER ARRAY, float_array REAL ARRAY, "
          + "string_array VARCHAR ARRAY);";
  private static final String INSERT_STATEMENT =
      "INSERT INTO array_table (id, int_array, float_array, string_array) VALUES (?, ?, ?, ?);";
  private static final String QUERY =
      "SELECT int_array, float_array, string_array FROM array_table ORDER BY id;";
  private static final String DROP_STATEMENT = "DROP TABLE array_table;";

  private static Map<String, JdbcFieldInfo> arrayFieldMapping;

  private static final String INT_ARRAY_FIELD_NAME = "INT_ARRAY";
  private static final String FLOAT_ARRAY_FIELD_NAME = "FLOAT_ARRAY";
  private static final String STRING_ARRAY_FIELD_NAME = "STRING_ARRAY";

  @BeforeEach
  public void setUp() throws Exception {
    String url = "jdbc:h2:mem:JdbcToArrowTest";
    String driver = "org.h2.Driver";
    Class.forName(driver);
    conn = DriverManager.getConnection(url);
    try (Statement stmt = conn.createStatement()) {
      stmt.executeUpdate(CREATE_STATEMENT);
    }

    arrayFieldMapping = new HashMap<String, JdbcFieldInfo>();
    arrayFieldMapping.put(INT_ARRAY_FIELD_NAME, new JdbcFieldInfo(Types.INTEGER));
    arrayFieldMapping.put(FLOAT_ARRAY_FIELD_NAME, new JdbcFieldInfo(Types.REAL));
    arrayFieldMapping.put(STRING_ARRAY_FIELD_NAME, new JdbcFieldInfo(Types.VARCHAR));
  }

  // This test verifies reading an array field from an H2 database
  // works as expected.  If this test fails, something is either wrong
  // with the setup, or the H2 SQL behavior changed.
  @Test
  public void testReadH2Array() throws Exception {
    int rowCount = 4;

    Integer[][] intArrays = generateIntegerArrayField(rowCount);
    Float[][] floatArrays = generateFloatArrayField(rowCount);
    String[][] strArrays = generateStringArrayField(rowCount);

    insertRows(rowCount, intArrays, floatArrays, strArrays);

    try (ResultSet resultSet = conn.createStatement().executeQuery(QUERY)) {
      ResultSetMetaData rsmd = resultSet.getMetaData();
      assertEquals(3, rsmd.getColumnCount());

      for (int i = 1; i <= rsmd.getColumnCount(); ++i) {
        assertEquals(Types.ARRAY, rsmd.getColumnType(i));
      }

      int rowNum = 0;

      while (resultSet.next()) {
        Array intArray = resultSet.getArray(INT_ARRAY_FIELD_NAME);
        assertFalse(resultSet.wasNull());
        try (ResultSet rs = intArray.getResultSet()) {
          int arrayIndex = 0;
          while (rs.next()) {
            assertEquals(intArrays[rowNum][arrayIndex].intValue(), rs.getInt(2));
            ++arrayIndex;
          }
          assertEquals(intArrays[rowNum].length, arrayIndex);
        }

        Array floatArray = resultSet.getArray(FLOAT_ARRAY_FIELD_NAME);
        assertFalse(resultSet.wasNull());
        try (ResultSet rs = floatArray.getResultSet()) {
          int arrayIndex = 0;
          while (rs.next()) {
            assertEquals(floatArrays[rowNum][arrayIndex].floatValue(), rs.getFloat(2), 0.001);
            ++arrayIndex;
          }
          assertEquals(floatArrays[rowNum].length, arrayIndex);
        }

        Array strArray = resultSet.getArray(STRING_ARRAY_FIELD_NAME);
        assertFalse(resultSet.wasNull());
        try (ResultSet rs = strArray.getResultSet()) {
          int arrayIndex = 0;
          while (rs.next()) {
            assertEquals(strArrays[rowNum][arrayIndex], rs.getString(2));
            ++arrayIndex;
          }
          assertEquals(strArrays[rowNum].length, arrayIndex);
        }

        ++rowNum;
      }

      assertEquals(rowCount, rowNum);
    }
  }

  @Test
  public void testJdbcToArrow() throws Exception {
    int rowCount = 4;

    Integer[][] intArrays = generateIntegerArrayField(rowCount);
    Float[][] floatArrays = generateFloatArrayField(rowCount);
    String[][] strArrays = generateStringArrayField(rowCount);

    insertRows(rowCount, intArrays, floatArrays, strArrays);

    final JdbcToArrowConfigBuilder builder =
        new JdbcToArrowConfigBuilder(
            new RootAllocator(Integer.MAX_VALUE), JdbcToArrowUtils.getUtcCalendar(), false);
    builder.setArraySubTypeByColumnNameMap(arrayFieldMapping);

    final JdbcToArrowConfig config = builder.build();

    try (ResultSet resultSet = conn.createStatement().executeQuery(QUERY)) {
      final VectorSchemaRoot vector = sqlToArrow(resultSet, config);

      assertEquals(rowCount, vector.getRowCount());

      assertIntegerVectorEquals(
          (ListVector) vector.getVector(INT_ARRAY_FIELD_NAME), rowCount, intArrays);
      assertFloatVectorEquals(
          (ListVector) vector.getVector(FLOAT_ARRAY_FIELD_NAME), rowCount, floatArrays);
      assertStringVectorEquals(
          (ListVector) vector.getVector(STRING_ARRAY_FIELD_NAME), rowCount, strArrays);
    }
  }

  @Test
  public void testJdbcToArrowWithNulls() throws Exception {
    int rowCount = 4;

    Integer[][] intArrays = {
      null, {0}, {1}, {},
    };

    Float[][] floatArrays = {
      {2.0f}, null, {3.0f}, {},
    };

    String[][] stringArrays = {
      {"4"}, null, {"5"}, {},
    };

    insertRows(rowCount, intArrays, floatArrays, stringArrays);

    final JdbcToArrowConfigBuilder builder =
        new JdbcToArrowConfigBuilder(
            new RootAllocator(Integer.MAX_VALUE), JdbcToArrowUtils.getUtcCalendar(), false);
    builder.setArraySubTypeByColumnNameMap(arrayFieldMapping);

    final JdbcToArrowConfig config = builder.build();

    try (ResultSet resultSet = conn.createStatement().executeQuery(QUERY)) {
      final VectorSchemaRoot vector = sqlToArrow(resultSet, config);

      assertEquals(rowCount, vector.getRowCount());

      assertIntegerVectorEquals(
          (ListVector) vector.getVector(INT_ARRAY_FIELD_NAME), rowCount, intArrays);
      assertFloatVectorEquals(
          (ListVector) vector.getVector(FLOAT_ARRAY_FIELD_NAME), rowCount, floatArrays);
      assertStringVectorEquals(
          (ListVector) vector.getVector(STRING_ARRAY_FIELD_NAME), rowCount, stringArrays);
    }
  }

  private void assertIntegerVectorEquals(
      ListVector listVector, int rowCount, Integer[][] expectedValues) {
    IntVector vector = (IntVector) listVector.getDataVector();
    ArrowBuf offsetBuffer = listVector.getOffsetBuffer();

    int prevOffset = 0;
    for (int row = 0; row < rowCount; ++row) {
      int offset = offsetBuffer.getInt((row + 1) * ListVector.OFFSET_WIDTH);

      if (expectedValues[row] == null) {
        assertEquals(0, listVector.isSet(row));
        assertEquals(0, offset - prevOffset);
        continue;
      }

      assertEquals(1, listVector.isSet(row));
      assertEquals(expectedValues[row].length, offset - prevOffset);

      for (int i = prevOffset; i < offset; ++i) {
        assertEquals(expectedValues[row][i - prevOffset].intValue(), vector.get(i));
      }

      prevOffset = offset;
    }
  }

  private void assertFloatVectorEquals(
      ListVector listVector, int rowCount, Float[][] expectedValues) {
    Float4Vector vector = (Float4Vector) listVector.getDataVector();
    ArrowBuf offsetBuffer = listVector.getOffsetBuffer();

    int prevOffset = 0;
    for (int row = 0; row < rowCount; ++row) {
      int offset = offsetBuffer.getInt((row + 1) * ListVector.OFFSET_WIDTH);

      if (expectedValues[row] == null) {
        assertEquals(0, listVector.isSet(row));
        assertEquals(0, offset - prevOffset);
        continue;
      }

      assertEquals(1, listVector.isSet(row));
      assertEquals(expectedValues[row].length, offset - prevOffset);

      for (int i = prevOffset; i < offset; ++i) {
        assertEquals(expectedValues[row][i - prevOffset].floatValue(), vector.get(i), 0);
      }

      prevOffset = offset;
    }
  }

  private void assertStringVectorEquals(
      ListVector listVector, int rowCount, String[][] expectedValues) {
    VarCharVector vector = (VarCharVector) listVector.getDataVector();
    ArrowBuf offsetBuffer = listVector.getOffsetBuffer();

    int prevOffset = 0;
    for (int row = 0; row < rowCount; ++row) {
      int offset = offsetBuffer.getInt((row + 1) * ListVector.OFFSET_WIDTH);

      if (expectedValues[row] == null) {
        assertEquals(0, listVector.isSet(row));
        assertEquals(0, offset - prevOffset);
        continue;
      }

      assertEquals(1, listVector.isSet(row));
      assertEquals(expectedValues[row].length, offset - prevOffset);
      for (int i = prevOffset; i < offset; ++i) {
        assertArrayEquals(
            expectedValues[row][i - prevOffset].getBytes(StandardCharsets.UTF_8), vector.get(i));
      }

      prevOffset = offset;
    }
  }

  @AfterEach
  public void tearDown() throws SQLException {
    try (Statement stmt = conn.createStatement()) {
      stmt.executeUpdate(DROP_STATEMENT);
    } finally {
      if (conn != null) {
        conn.close();
        conn = null;
      }
    }
  }

  private Integer[][] generateIntegerArrayField(int numRows) {
    Integer[][] result = new Integer[numRows][];

    for (int i = 0; i < numRows; ++i) {
      int val = i * 4;
      result[i] = new Integer[] {val, val + 1, val + 2, val + 3};
    }

    return result;
  }

  private Float[][] generateFloatArrayField(int numRows) {
    Float[][] result = new Float[numRows][];

    for (int i = 0; i < numRows; ++i) {
      int val = i * 4;
      result[i] = new Float[] {(float) val, (float) val + 1, (float) val + 2, (float) val + 3};
    }

    return result;
  }

  private String[][] generateStringArrayField(int numRows) {
    String[][] result = new String[numRows][];

    for (int i = 0; i < numRows; ++i) {
      int val = i * 4;
      result[i] =
          new String[] {
            String.valueOf(val),
            String.valueOf(val + 1),
            String.valueOf(val + 2),
            String.valueOf(val + 3)
          };
    }

    return result;
  }

  private void insertRows(
      int numRows, Integer[][] integerArrays, Float[][] floatArrays, String[][] strArrays)
      throws SQLException {

    // Insert 4 Rows
    try (PreparedStatement stmt = conn.prepareStatement(INSERT_STATEMENT)) {

      for (int i = 0; i < numRows; ++i) {
        Integer[] integerArray = integerArrays[i];
        Float[] floatArray = floatArrays[i];
        String[] strArray = strArrays[i];

        Array intArray = integerArray != null ? conn.createArrayOf("INT", integerArray) : null;
        Array realArray = floatArray != null ? conn.createArrayOf("REAL", floatArray) : null;
        Array varcharArray = strArray != null ? conn.createArrayOf("VARCHAR", strArray) : null;

        // Insert Arrays of 4 Values in Each Row
        stmt.setInt(1, i);
        stmt.setArray(2, intArray);
        stmt.setArray(3, realArray);
        stmt.setArray(4, varcharArray);

        stmt.executeUpdate();

        if (intArray != null) {
          intArray.free();
        }
        if (realArray != null) {
          realArray.free();
        }
        if (varcharArray != null) {
          varcharArray.free();
        }
      }
    }
  }
}