JdbcToArrowVectorIteratorTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.adapter.jdbc.h2;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getBinaryValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getBooleanValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getCharArray;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getDecimalValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getDoubleValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getFloatValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getIntValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getListValues;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper.getLongValues;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.io.IOException;
import java.math.BigDecimal;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Calendar;
import java.util.List;
import org.apache.arrow.adapter.jdbc.ArrowVectorIterator;
import org.apache.arrow.adapter.jdbc.JdbcToArrow;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.adapter.jdbc.JdbcToArrowTestHelper;
import org.apache.arrow.adapter.jdbc.Table;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.TimeMilliVector;
import org.apache.arrow.vector.TimeStampMilliTZVector;
import org.apache.arrow.vector.TimeStampMilliVector;
import org.apache.arrow.vector.TimeStampVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
public class JdbcToArrowVectorIteratorTest extends JdbcToArrowTest {
/**
 * Converts the result of the table's query to Arrow in batches of at most three rows and checks
 * every column through {@code validate}.
 *
 * <p>The iterator is held in a try-with-resources statement so that it is closed even when one
 * of the assertions in {@code validate} fails, in the same way {@code testTimeStampConsumer}
 * manages its iterators.
 *
 * @param table test fixture used to initialize the database and to supply the query
 */
@ParameterizedTest
@MethodSource("getTestData")
@Override
public void testJdbcToArrowValues(Table table)
    throws SQLException, IOException, ClassNotFoundException {
  this.initializeDatabase(table);

  JdbcToArrowConfig config =
      new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())
          .setTargetBatchSize(3)
          .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP)
          .build();

  try (ArrowVectorIterator iterator =
      JdbcToArrow.sqlToArrowVectorIterator(
          conn.createStatement().executeQuery(table.getQuery()), config)) {
    // validate() drains the iterator and closes every root it collected.
    validate(iterator);
  }
}
/**
 * Checks that all batches share one {@link VectorSchemaRoot} when reuse is enabled and that each
 * batch gets its own root when reuse is disabled, while verifying the int column (index 0) and
 * the list column (index 18) of every batch.
 *
 * <p>When reuse is disabled this test owns every returned root, so each root is closed as soon
 * as its successor has been compared against it. Only the most recent root is kept open, in
 * order to verify that closing the iterator leaves already-returned vectors intact.
 *
 * @param table test fixture used to initialize the database and to supply the query
 * @param reuseVectorSchemaRoot whether the iterator is configured to reuse a single root
 */
@ParameterizedTest
@MethodSource("getTestData")
public void testVectorSchemaRootReuse(Table table, boolean reuseVectorSchemaRoot)
    throws SQLException, IOException, ClassNotFoundException {
  this.initializeDatabase(table);

  // expected values of column 0; one inner array per batch (target batch size is 3)
  Integer[][] intValues = {
    {101, 102, 103},
    {104, null, null},
    {107, 108, 109},
    {110}
  };
  // expected values of the list column; one outer entry per batch
  Integer[][][] listValues = {
    {{1, 2, 3}, {1, 2}, {1}},
    {{2, 3, 4}, {2, 3}, {2}},
    {{3, 4, 5}, {3, 4}, {3}},
    {{}}
  };

  JdbcToArrowConfig config =
      new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())
          .setTargetBatchSize(3)
          .setReuseVectorSchemaRoot(reuseVectorSchemaRoot)
          .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP)
          .build();

  int batchCount = 0;
  VectorSchemaRoot cur = null;
  // try-with-resources closes the iterator on the success path (as before) and also when an
  // assertion inside the loop fails.
  try (ArrowVectorIterator iterator =
      JdbcToArrow.sqlToArrowVectorIterator(
          conn.createStatement().executeQuery(table.getQuery()), config)) {
    VectorSchemaRoot prev = null;
    while (iterator.hasNext()) {
      cur = iterator.next();
      assertNotNull(cur);

      // verify the first column, which may contain nulls.
      List<IntVector> intVectors = new ArrayList<>();
      intVectors.add((IntVector) cur.getVector(0));
      assertIntVectorValues(intVectors, intValues[batchCount].length, intValues[batchCount]);

      // verify arrays are handled correctly
      List<ListVector> listVectors = new ArrayList<>();
      listVectors.add((ListVector) cur.getVector(18));
      assertListVectorValues(listVectors, listValues[batchCount].length, listValues[batchCount]);

      if (prev != null) {
        // skip the first iteration
        if (reuseVectorSchemaRoot) {
          // when reuse is enabled, different iterations are based on the same vector schema root.
          assertTrue(prev == cur);
        } else {
          // when reuse is disabled, a new vector schema root is created in each iteration.
          assertFalse(prev == cur);
          // The previous root has been fully verified and is owned by this test, so release it.
          // Closing prev (instead of cur) also releases the very first batch and does not depend
          // on the total number of batches; the latest root stays open for the check below.
          prev.close();
        }
      }

      prev = cur;
      batchCount += 1;
    }
  }

  // the iterator has been closed at this point
  if (!reuseVectorSchemaRoot) {
    assertNotNull(cur);
    // test that closing the iterator does not close the vectors held by the consumers
    assertNotEquals(cur.getVector(0).getValueCount(), 0);
    cur.close();
  }
  // make sure we have at least two batches, so the above test paths are actually covered
  assertTrue(batchCount > 1);
}
/**
 * Converts the result of the table's query to Arrow with
 * {@code JdbcToArrowConfig.NO_LIMIT_BATCH_SIZE} as the target batch size and checks every column
 * through {@code validate}.
 *
 * <p>The iterator is held in a try-with-resources statement so that it is closed even when one
 * of the assertions in {@code validate} fails.
 *
 * @param table test fixture used to initialize the database and to supply the query
 */
@ParameterizedTest
@MethodSource("getTestData")
public void testJdbcToArrowValuesNoLimit(Table table)
    throws SQLException, IOException, ClassNotFoundException {
  this.initializeDatabase(table);

  JdbcToArrowConfig config =
      new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())
          .setTargetBatchSize(JdbcToArrowConfig.NO_LIMIT_BATCH_SIZE)
          .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP)
          .build();

  try (ArrowVectorIterator iterator =
      JdbcToArrow.sqlToArrowVectorIterator(
          conn.createStatement().executeQuery(table.getQuery()), config)) {
    // validate() drains the iterator and closes every root it collected.
    validate(iterator);
  }
}
/**
 * Checks which timestamp vector type is produced for {@code timestamp_field11}: a
 * {@link TimeStampMilliTZVector} when the config carries a calendar, and a
 * {@link TimeStampMilliVector} when the calendar is {@code null}.
 *
 * @param table test fixture used to initialize the database
 * @param reuseVectorSchemaRoot value passed to {@code setReuseVectorSchemaRoot} of both configs
 */
@ParameterizedTest
@MethodSource("getTestData")
public void testTimeStampConsumer(Table table, boolean reuseVectorSchemaRoot)
    throws SQLException, IOException, ClassNotFoundException {
  this.initializeDatabase(table);

  final String sql = "select timestamp_field11 from table1";

  // Case 1: a calendar is configured, so the vector carries time zone info.
  JdbcToArrowConfig zonedConfig =
      new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())
          .setTargetBatchSize(3)
          .setReuseVectorSchemaRoot(reuseVectorSchemaRoot)
          .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP)
          .build();
  assertNotNull(zonedConfig.getCalendar());
  assertSingleColumnOfType(sql, zonedConfig, TimeStampMilliTZVector.class);

  // Case 2: no calendar is configured, so the vector has no time zone info.
  JdbcToArrowConfig plainConfig =
      new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), null)
          .setTargetBatchSize(3)
          .setReuseVectorSchemaRoot(reuseVectorSchemaRoot)
          .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP)
          .build();
  assertNull(plainConfig.getCalendar());
  assertSingleColumnOfType(sql, plainConfig, TimeStampMilliVector.class);
}

/**
 * Runs {@code sql}, reads the first batch, and asserts that it has exactly one column whose
 * vector is an instance of {@code expectedType}. The iterator is closed before returning.
 */
private void assertSingleColumnOfType(
    String sql, JdbcToArrowConfig config, Class<?> expectedType)
    throws SQLException, IOException {
  try (ArrowVectorIterator batches =
      JdbcToArrow.sqlToArrowVectorIterator(conn.createStatement().executeQuery(sql), config)) {
    VectorSchemaRoot firstBatch = batches.next();
    assertEquals(1, firstBatch.getFieldVectors().size());
    assertTrue(expectedType.isInstance(firstBatch.getVector(0)));
  }
}
/**
 * Drains {@code iterator}, gathering for every column the vector of each batch (in batch order),
 * then asserts that the batches taken together hold the row count and values reported by the
 * {@code table} member, and finally closes every root that was returned.
 *
 * <p>NOTE(review): {@code table} is not declared in this class body; presumably it is set by
 * {@code initializeDatabase} in a superclass - confirm.
 */
private void validate(ArrowVectorIterator iterator) throws SQLException, IOException {
  // per-column accumulators, one entry per batch
  List<BigIntVector> bigInts = new ArrayList<>();
  List<TinyIntVector> tinyInts = new ArrayList<>();
  List<IntVector> ints = new ArrayList<>();
  List<SmallIntVector> smallInts = new ArrayList<>();
  List<VarBinaryVector> binaries = new ArrayList<>();
  List<VarBinaryVector> blobs = new ArrayList<>();
  List<VarCharVector> clobs = new ArrayList<>();
  List<VarCharVector> varChars = new ArrayList<>();
  List<VarCharVector> chars = new ArrayList<>();
  List<BitVector> bits = new ArrayList<>();
  List<BitVector> bools = new ArrayList<>();
  List<DateDayVector> dates = new ArrayList<>();
  List<TimeMilliVector> times = new ArrayList<>();
  List<TimeStampVector> timestamps = new ArrayList<>();
  List<DecimalVector> decimals = new ArrayList<>();
  List<Float4Vector> reals = new ArrayList<>();
  List<Float8Vector> doubles = new ArrayList<>();
  List<ListVector> lists = new ArrayList<>();
  // every returned root is remembered so it can be closed once all checks are done
  List<VectorSchemaRoot> batches = new ArrayList<>();

  while (iterator.hasNext()) {
    VectorSchemaRoot batch = iterator.next();
    batches.add(batch);

    JdbcToArrowTestHelper.assertFieldMetadataIsEmpty(batch);

    bigInts.add((BigIntVector) batch.getVector(BIGINT));
    tinyInts.add((TinyIntVector) batch.getVector(TINYINT));
    ints.add((IntVector) batch.getVector(INT));
    smallInts.add((SmallIntVector) batch.getVector(SMALLINT));
    binaries.add((VarBinaryVector) batch.getVector(BINARY));
    blobs.add((VarBinaryVector) batch.getVector(BLOB));
    clobs.add((VarCharVector) batch.getVector(CLOB));
    varChars.add((VarCharVector) batch.getVector(VARCHAR));
    chars.add((VarCharVector) batch.getVector(CHAR));
    bits.add((BitVector) batch.getVector(BIT));
    bools.add((BitVector) batch.getVector(BOOL));
    dates.add((DateDayVector) batch.getVector(DATE));
    times.add((TimeMilliVector) batch.getVector(TIME));
    timestamps.add((TimeStampVector) batch.getVector(TIMESTAMP));
    decimals.add((DecimalVector) batch.getVector(DECIMAL));
    reals.add((Float4Vector) batch.getVector(REAL));
    doubles.add((Float8Vector) batch.getVector(DOUBLE));
    lists.add((ListVector) batch.getVector(LIST));
  }

  // integer columns
  assertBigIntVectorValues(
      bigInts, table.getRowCount(), getLongValues(table.getValues(), BIGINT));
  assertTinyIntVectorValues(
      tinyInts, table.getRowCount(), getIntValues(table.getValues(), TINYINT));
  assertIntVectorValues(ints, table.getRowCount(), getIntValues(table.getValues(), INT));
  assertSmallIntVectorValues(
      smallInts, table.getRowCount(), getIntValues(table.getValues(), SMALLINT));

  // binary and character columns
  assertBinaryVectorValues(
      binaries, table.getRowCount(), getBinaryValues(table.getValues(), BINARY));
  assertBinaryVectorValues(blobs, table.getRowCount(), getBinaryValues(table.getValues(), BLOB));
  assertVarCharVectorValues(clobs, table.getRowCount(), getCharArray(table.getValues(), CLOB));
  assertVarCharVectorValues(
      varChars, table.getRowCount(), getCharArray(table.getValues(), VARCHAR));
  assertVarCharVectorValues(chars, table.getRowCount(), getCharArray(table.getValues(), CHAR));

  // bit and boolean columns
  assertBitVectorValues(bits, table.getRowCount(), getIntValues(table.getValues(), BIT));
  assertBooleanVectorValues(
      bools, table.getRowCount(), getBooleanValues(table.getValues(), BOOL));

  // temporal columns
  assertDateDayVectorValues(dates, table.getRowCount(), getLongValues(table.getValues(), DATE));
  assertTimeMilliVectorValues(
      times, table.getRowCount(), getLongValues(table.getValues(), TIME));
  assertTimeStampVectorValues(
      timestamps, table.getRowCount(), getLongValues(table.getValues(), TIMESTAMP));

  // decimal and floating point columns
  assertDecimalVectorValues(
      decimals, table.getRowCount(), getDecimalValues(table.getValues(), DECIMAL));
  assertFloat4VectorValues(reals, table.getRowCount(), getFloatValues(table.getValues(), REAL));
  assertFloat8VectorValues(
      doubles, table.getRowCount(), getDoubleValues(table.getValues(), DOUBLE));

  // list column
  assertListVectorValues(lists, table.getRowCount(), getListValues(table.getValues(), LIST));

  for (VectorSchemaRoot batch : batches) {
    batch.close();
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they match {@code values} within a tolerance of 0.01.
 */
private void assertFloat8VectorValues(List<Float8Vector> vectors, int rowCount, Double[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (Float8Vector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      double expected = values[cursor++].doubleValue();
      assertEquals(expected, batch.get(row), 0.01);
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they match {@code values} within a tolerance of 0.01.
 */
private void assertFloat4VectorValues(List<Float4Vector> vectors, int rowCount, Float[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (Float4Vector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      float expected = values[cursor++].floatValue();
      assertEquals(expected, batch.get(row), 0.01);
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values, that no element is null, and
 * that each element equals the corresponding entry of {@code values} when both are compared as
 * doubles with a delta of 0.
 */
private void assertDecimalVectorValues(
    List<DecimalVector> vectors, int rowCount, BigDecimal[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (DecimalVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      BigDecimal actual = batch.getObject(row);
      assertNotNull(actual);
      double expected = values[cursor++].doubleValue();
      assertEquals(expected, actual.doubleValue(), 0);
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they equal {@code values}.
 */
private void assertTimeStampVectorValues(
    List<TimeStampVector> vectors, int rowCount, Long[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (TimeStampVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      long expected = values[cursor++].longValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they equal {@code values}.
 */
private void assertTimeMilliVectorValues(
    List<TimeMilliVector> vectors, int rowCount, Long[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (TimeMilliVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      long expected = values[cursor++].longValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they equal {@code values}.
 */
private void assertDateDayVectorValues(List<DateDayVector> vectors, int rowCount, Long[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (DateDayVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      long expected = values[cursor++].longValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that each stored bit, read
 * in list order, equals the corresponding integer in {@code values}.
 */
private void assertBitVectorValues(List<BitVector> vectors, int rowCount, Integer[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (BitVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      int expected = values[cursor++].intValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that each stored bit, read
 * in list order and interpreted as {@code bit == 1}, equals the corresponding entry of
 * {@code values}.
 */
private void assertBooleanVectorValues(List<BitVector> vectors, int rowCount, Boolean[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (BitVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      Boolean expected = values[cursor++];
      boolean actual = batch.get(row) == 1;
      assertEquals(expected, actual);
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that the bytes of each
 * element, read in list order, equal the corresponding array in {@code values}.
 */
private void assertVarCharVectorValues(
    List<VarCharVector> vectors, int rowCount, byte[][] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (VarCharVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      byte[] expected = values[cursor++];
      assertArrayEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that the bytes of each
 * element, read in list order, equal the corresponding array in {@code values}.
 */
private void assertBinaryVectorValues(
    List<VarBinaryVector> vectors, int rowCount, byte[][] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (VarBinaryVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      byte[] expected = values[cursor++];
      assertArrayEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they equal {@code values}.
 */
private void assertSmallIntVectorValues(
    List<SmallIntVector> vectors, int rowCount, Integer[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (SmallIntVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      int expected = values[cursor++].intValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they equal {@code values}.
 */
private void assertTinyIntVectorValues(
    List<TinyIntVector> vectors, int rowCount, Integer[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (TinyIntVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      int expected = values[cursor++].intValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they equal {@code values}.
 */
private void assertBigIntVectorValues(List<BigIntVector> vectors, int rowCount, Long[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (BigIntVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++) {
      long expected = values[cursor++].longValue();
      assertEquals(expected, batch.get(row));
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that, read in list order,
 * they match {@code values}; a {@code null} expected entry requires the vector slot to be null.
 */
private void assertIntVectorValues(List<IntVector> vectors, int rowCount, Integer[] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (IntVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++, cursor++) {
      Integer expected = values[cursor];
      if (expected != null) {
        assertEquals(expected.longValue(), batch.get(row));
      } else {
        assertTrue(batch.isNull(row));
      }
    }
  }
}
/**
 * Asserts that the vectors together hold {@code rowCount} values and that each list element,
 * read in list order, equals the corresponding array in {@code values}; a {@code null} expected
 * entry requires the vector slot to be null.
 */
public static void assertListVectorValues(
    List<ListVector> vectors, int rowCount, Integer[][] values) {
  int total = 0;
  for (ValueVector batch : vectors) {
    total += batch.getValueCount();
  }
  assertEquals(rowCount, total);

  // cursor walks the expected values across batch boundaries
  int cursor = 0;
  for (ListVector batch : vectors) {
    for (int row = 0; row < batch.getValueCount(); row++, cursor++) {
      Integer[] expected = values[cursor];
      if (expected == null) {
        assertTrue(batch.isNull(row));
      } else {
        // compared through Object.equals, so no cast of the returned list is needed
        assertEquals(Arrays.asList(expected), batch.getObject(row));
      }
    }
  }
}
/**
 * Executes {@code select real_field8 from table1} with the given config and returns the only
 * column of the first batch.
 *
 * <p>Neither the iterator nor the enclosing root is closed here; the caller receives just the
 * vector.
 */
private FieldVector getQueryResult(JdbcToArrowConfig config) throws SQLException, IOException {
  ArrowVectorIterator batches =
      JdbcToArrow.sqlToArrowVectorIterator(
          conn.createStatement().executeQuery("select real_field8 from table1"), config);

  VectorSchemaRoot firstBatch = batches.next();

  // the select statement names a single column, so exactly one vector is expected
  int columnCount = firstBatch.getFieldVectors().size();
  assertEquals(1, columnCount);

  FieldVector column = firstBatch.getVector(0);

  // guard against an empty result: some data must actually have been read
  assertTrue(column.getValueCount() > 0);

  return column;
}
/**
 * Checks that a custom JDBC-to-Arrow type converter set on the builder overrides the default
 * mapping: {@code real_field8} is read as a {@link Float4Vector} by default and as a
 * {@link Float8Vector} once the converter maps {@code Types.REAL} to a double-precision type.
 *
 * @param table test fixture used to initialize the database
 * @param reuseVectorSchemaRoot value passed to {@code setReuseVectorSchemaRoot} of the builder
 */
@ParameterizedTest
@MethodSource("getTestData")
public void testJdbcToArrowCustomTypeConversion(Table table, boolean reuseVectorSchemaRoot)
    throws SQLException, IOException, ClassNotFoundException {
  this.initializeDatabase(table);

  JdbcToArrowConfigBuilder builder =
      new JdbcToArrowConfigBuilder(new RootAllocator(Integer.MAX_VALUE), Calendar.getInstance())
          .setTargetBatchSize(JdbcToArrowConfig.NO_LIMIT_BATCH_SIZE)
          .setReuseVectorSchemaRoot(reuseVectorSchemaRoot)
          .setArraySubTypeByColumnNameMap(ARRAY_SUB_TYPE_BY_COLUMN_NAME_MAP);

  // Case 1: no converter has been set on the builder, so the default mapping applies.
  JdbcToArrowConfig defaultConfig = builder.build();
  try (FieldVector column = getQueryResult(defaultConfig)) {
    // the default converter translates real to float4
    assertTrue(column instanceof Float4Vector);
  }

  // Case 2: install a converter that differs from the default for REAL only; it returns null
  // for every other JDBC type.
  builder.setJdbcToArrowTypeConverter(
      fieldInfo ->
          fieldInfo.getJdbcType() == Types.REAL
              ? new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
              : null);
  JdbcToArrowConfig customConfig = builder.build();
  try (FieldVector column = getQueryResult(customConfig)) {
    // the customized converter translates real to float8
    assertTrue(column instanceof Float8Vector);
  }
}
}