ArrowFlightPreparedStatementTest.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.driver.jdbc;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertAll;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.flight.sql.FlightSqlUtils;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;
public class ArrowFlightPreparedStatementTest {
public static final MockFlightSqlProducer PRODUCER = CoreMockedSqlProducers.getLegacyProducer();
@RegisterExtension
public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION =
FlightServerTestExtension.createStandardTestExtension(PRODUCER);
private static Connection connection;
@BeforeAll
public static void setup() throws SQLException {
connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false);
}
@AfterAll
public static void tearDown() throws SQLException {
connection.close();
}
@BeforeEach
public void before() {
PRODUCER.clearActionTypeCounter();
}
@Test
public void testSimpleQueryNoParameterBinding() throws SQLException {
final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD;
try (final PreparedStatement preparedStatement = connection.prepareStatement(query);
final ResultSet resultSet = preparedStatement.executeQuery()) {
CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet);
}
}
@Test
public void testSimpleQueryNoParameterBindingWithExecute() throws SQLException {
final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD;
try (final PreparedStatement preparedStatement = connection.prepareStatement(query)) {
boolean isResultSet = preparedStatement.execute();
assertTrue(isResultSet);
final ResultSet resultSet = preparedStatement.getResultSet();
CoreMockedSqlProducers.assertLegacyRegularSqlResultSet(resultSet);
assertFalse(preparedStatement.getMoreResults());
assertEquals(-1, preparedStatement.getUpdateCount());
}
}
@Test
public void testQueryWithParameterBinding() throws SQLException {
final String query = "Fake query with parameters";
final Schema schema =
new Schema(Collections.singletonList(Field.nullable("", Types.MinorType.INT.getType())));
final Schema parameterSchema =
new Schema(
Arrays.asList(
Field.nullable("", ArrowType.Utf8.INSTANCE),
new Field(
"",
FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(Field.nullable("", Types.MinorType.INT.getType())))));
final List<List<Object>> expected =
Collections.singletonList(Arrays.asList(new Text("foo"), new Integer[] {1, 2, null}));
PRODUCER.addSelectQuery(
query,
schema,
Collections.singletonList(
listener -> {
try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
((IntVector) root.getVector(0)).setSafe(0, 10);
root.setRowCount(1);
listener.start(root);
listener.putNext();
} catch (final Throwable throwable) {
listener.error(throwable);
} finally {
listener.completed();
}
}));
PRODUCER.addExpectedParameters(query, parameterSchema, expected);
try (final PreparedStatement preparedStatement = connection.prepareStatement(query)) {
preparedStatement.setString(1, "foo");
preparedStatement.setArray(
2, connection.createArrayOf("INTEGER", new Integer[] {1, 2, null}));
try (final ResultSet resultSet = preparedStatement.executeQuery()) {
resultSet.next();
assert true;
}
}
}
@Test
@Disabled("https://github.com/apache/arrow/issues/34741: flaky test")
public void testPreparedStatementExecutionOnce() throws SQLException {
final PreparedStatement statement =
connection.prepareStatement(CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD);
// Expect that there is one entry in the map -- {prepared statement action type, invocation
// count}.
assertEquals(PRODUCER.getActionTypeCounter().size(), 1);
// Expect that the prepared statement was executed exactly once.
assertEquals(
PRODUCER
.getActionTypeCounter()
.get(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType()),
1);
statement.close();
}
@Test
public void testReturnColumnCount() throws SQLException {
final String query = CoreMockedSqlProducers.LEGACY_REGULAR_SQL_CMD;
try (final PreparedStatement psmt = connection.prepareStatement(query)) {
assertAll(
"Column count is as expected",
() -> assertThat("ID", equalTo(psmt.getMetaData().getColumnName(1))),
() -> assertThat("Name", equalTo(psmt.getMetaData().getColumnName(2))),
() -> assertThat("Age", equalTo(psmt.getMetaData().getColumnName(3))),
() -> assertThat("Salary", equalTo(psmt.getMetaData().getColumnName(4))),
() -> assertThat("Hire Date", equalTo(psmt.getMetaData().getColumnName(5))),
() -> assertThat("Last Sale", equalTo(psmt.getMetaData().getColumnName(6))),
() -> assertThat(6, equalTo(psmt.getMetaData().getColumnCount())));
}
}
@Test
public void testUpdateQuery() throws SQLException {
String query = "Fake update";
PRODUCER.addUpdateQuery(query, /*updatedRows*/ 42);
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
int updated = stmt.executeUpdate();
assertEquals(42, updated);
}
}
@Test
public void testUpdateQueryWithExecute() throws SQLException {
String query = "Fake update with execute";
PRODUCER.addUpdateQuery(query, /*updatedRows*/ 42);
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
boolean isResultSet = stmt.execute();
assertFalse(isResultSet);
int updated = stmt.getUpdateCount();
assertEquals(42, updated);
assertFalse(stmt.getMoreResults());
assertEquals(-1, stmt.getUpdateCount());
}
}
@Test
public void testUpdateQueryWithParameters() throws SQLException {
String query = "Fake update with parameters";
PRODUCER.addUpdateQuery(query, /*updatedRows*/ 42);
PRODUCER.addExpectedParameters(
query,
new Schema(Collections.singletonList(Field.nullable("", ArrowType.Utf8.INSTANCE))),
Collections.singletonList(
Collections.singletonList(new Text("foo".getBytes(StandardCharsets.UTF_8)))));
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
// TODO: make sure this is validated on the server too
stmt.setString(1, "foo");
int updated = stmt.executeUpdate();
assertEquals(42, updated);
}
}
@Test
public void testUpdateQueryWithBatchedParameters() throws SQLException {
String query = "Fake update with batched parameters";
Schema parameterSchema =
new Schema(
Arrays.asList(
Field.nullable("", ArrowType.Utf8.INSTANCE),
new Field(
"",
FieldType.nullable(ArrowType.List.INSTANCE),
Collections.singletonList(Field.nullable("", Types.MinorType.INT.getType())))));
List<List<Object>> expected =
Arrays.asList(
Arrays.asList(new Text("foo"), new Integer[] {1, 2, null}),
Arrays.asList(new Text("bar"), new Integer[] {0, -1, 100000}));
PRODUCER.addUpdateQuery(query, /*updatedRows*/ 42);
PRODUCER.addExpectedParameters(query, parameterSchema, expected);
try (final PreparedStatement stmt = connection.prepareStatement(query)) {
// TODO: make sure this is validated on the server too
stmt.setString(1, "foo");
stmt.setArray(2, connection.createArrayOf("INTEGER", new Integer[] {1, 2, null}));
stmt.addBatch();
stmt.setString(1, "bar");
stmt.setArray(2, connection.createArrayOf("INTEGER", new Integer[] {0, -1, 100000}));
stmt.addBatch();
int[] updated = stmt.executeBatch();
assertEquals(42, updated[0]);
}
}
}