// TestFlightSqlStreams.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.flight.sql.test;
import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static org.apache.arrow.flight.sql.util.FlightStreamUtils.getResults;
import static org.apache.arrow.util.AutoCloseables.close;
import static org.apache.arrow.vector.types.Types.MinorType.INT;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertAll;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.Any;
import com.google.protobuf.Message;
import java.util.Collections;
import java.util.List;
import org.apache.arrow.flight.CallStatus;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.Location;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.sql.BasicFlightSqlProducer;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.flight.sql.FlightSqlProducer;
import org.apache.arrow.flight.sql.impl.FlightSql;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.Text;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
/**
 * Tests the Flight SQL metadata and statement streams (GetTables, GetTableTypes, GetSqlInfo,
 * GetXdbcTypeInfo and a single fixed query) by round-tripping them through a {@link FlightServer}
 * started on localhost and a {@link FlightSqlClient} connected to it.
 */
public class TestFlightSqlStreams {

  /**
   * A limited {@link FlightSqlProducer} for testing GetTables, GetTableTypes, GetSqlInfo,
   * GetXdbcTypeInfo, and limited SQL commands.
   *
   * <p>Note that for simplicity the getStream* implementations are blocking, but a proper
   * FlightSqlProducer should have non-blocking implementations of getStream*.
   */
  private static class FlightSqlTestProducer extends BasicFlightSqlProducer {
    /** The only SQL statement this producer accepts. */
    private static final String FIXED_QUERY = "SELECT 1 AS c1 FROM test_table";

    /** Result schema of {@link #FIXED_QUERY}: a single nullable INT column named {@code c1}. */
    private static final Schema FIXED_SCHEMA =
        new Schema(asList(Field.nullable("c1", Types.MinorType.INT.getType())));

    /** Allocator used to create every {@link VectorSchemaRoot} streamed by this producer. */
    private final BufferAllocator allocator;

    /**
     * Creates the producer.
     *
     * @param allocator allocator used for the roots created by the getStream* methods
     */
    FlightSqlTestProducer(BufferAllocator allocator) {
      this.allocator = allocator;
    }

    /**
     * Builds the single endpoint returned for a supported request.
     *
     * <p>Metadata commands are packed into the ticket as-is. A {@link
     * FlightSql.CommandStatementQuery} is accepted only when its query equals {@link
     * #FIXED_QUERY}, and is converted to a {@link FlightSql.TicketStatementQuery} first.
     *
     * @param request the Flight SQL command being described
     * @param flightDescriptor descriptor of the request (not used here)
     * @param schema schema of the result (not used here)
     * @return a one-element list holding the endpoint for {@code request}
     * @throws org.apache.arrow.flight.FlightRuntimeException with status UNIMPLEMENTED for any
     *     other request
     */
    @Override
    protected <T extends Message> List<FlightEndpoint> determineEndpoints(
        T request, FlightDescriptor flightDescriptor, Schema schema) {
      if (request instanceof FlightSql.CommandGetTables
          || request instanceof FlightSql.CommandGetTableTypes
          || request instanceof FlightSql.CommandGetXdbcTypeInfo
          || request instanceof FlightSql.CommandGetSqlInfo) {
        return Collections.singletonList(
            new FlightEndpoint(new Ticket(Any.pack(request).toByteArray())));
      } else if (request instanceof FlightSql.CommandStatementQuery
          && ((FlightSql.CommandStatementQuery) request).getQuery().equals(FIXED_QUERY)) {
        // Tickets from CommandStatementQuery requests should be built using TicketStatementQuery
        // and then packed into a ticket. The content of the statement handle is specific to the
        // FlightSqlProducer. It does not need to be the query; it can be a query ID, for example.
        FlightSql.TicketStatementQuery ticketStatementQuery =
            FlightSql.TicketStatementQuery.newBuilder()
                .setStatementHandle(((FlightSql.CommandStatementQuery) request).getQueryBytes())
                .build();
        return Collections.singletonList(
            new FlightEndpoint(new Ticket(Any.pack(ticketStatementQuery).toByteArray())));
      }
      throw CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException();
    }

    /**
     * Describes the result of a statement query; the reported schema is always {@link
     * #FIXED_SCHEMA}.
     */
    @Override
    public FlightInfo getFlightInfoStatement(
        FlightSql.CommandStatementQuery command,
        CallContext context,
        FlightDescriptor descriptor) {
      return generateFlightInfo(command, descriptor, FIXED_SCHEMA);
    }

    /**
     * Streams the single row {@code c1 = 1} when the ticket's statement handle is {@link
     * #FIXED_QUERY}; any other handle is reported to the listener as UNIMPLEMENTED.
     */
    @Override
    public void getStreamStatement(
        FlightSql.TicketStatementQuery ticket, CallContext context, ServerStreamListener listener) {
      final String query = ticket.getStatementHandle().toStringUtf8();
      if (!query.equals(FIXED_QUERY)) {
        listener.error(
            CallStatus.UNIMPLEMENTED.withDescription("Not implemented.").toRuntimeException());
        // The stream has been failed; do not go on to send data on it.
        return;
      }
      try (VectorSchemaRoot root = VectorSchemaRoot.create(FIXED_SCHEMA, allocator)) {
        root.setRowCount(1);
        ((IntVector) root.getVector("c1")).setSafe(0, 1);
        listener.start(root);
        listener.putNext();
        listener.completed();
      }
    }

    /** Streams one batch with zero rows using the GetSqlInfo schema. */
    @Override
    public void getStreamSqlInfo(
        FlightSql.CommandGetSqlInfo command, CallContext context, ServerStreamListener listener) {
      try (VectorSchemaRoot root =
          VectorSchemaRoot.create(Schemas.GET_SQL_INFO_SCHEMA, allocator)) {
        root.setRowCount(0);
        listener.start(root);
        listener.putNext();
        listener.completed();
      }
    }

    /** Streams a single row describing an "Integer" type using the GetXdbcTypeInfo schema. */
    @Override
    public void getStreamTypeInfo(
        FlightSql.CommandGetXdbcTypeInfo request,
        CallContext context,
        ServerStreamListener listener) {
      try (VectorSchemaRoot root =
          VectorSchemaRoot.create(Schemas.GET_TYPE_INFO_SCHEMA, allocator)) {
        root.setRowCount(1);
        ((VarCharVector) root.getVector("type_name")).setSafe(0, new Text("Integer"));
        ((IntVector) root.getVector("data_type")).setSafe(0, INT.ordinal());
        ((IntVector) root.getVector("column_size")).setSafe(0, 400);
        root.getVector("literal_prefix").setNull(0);
        root.getVector("literal_suffix").setNull(0);
        root.getVector("create_params").setNull(0);
        ((IntVector) root.getVector("nullable"))
            .setSafe(0, FlightSql.Nullable.NULLABILITY_NULLABLE.getNumber());
        ((BitVector) root.getVector("case_sensitive")).setSafe(0, 1);
        // The Searchable value belongs in the "searchable" column; writing it to "nullable"
        // would overwrite the nullability set above and leave "searchable" unset.
        ((IntVector) root.getVector("searchable"))
            .setSafe(0, FlightSql.Searchable.SEARCHABLE_FULL.getNumber());
        ((BitVector) root.getVector("unsigned_attribute")).setSafe(0, 1);
        root.getVector("fixed_prec_scale").setNull(0);
        ((BitVector) root.getVector("auto_increment")).setSafe(0, 1);
        ((VarCharVector) root.getVector("local_type_name")).setSafe(0, new Text("Integer"));
        root.getVector("minimum_scale").setNull(0);
        root.getVector("maximum_scale").setNull(0);
        ((IntVector) root.getVector("sql_data_type")).setSafe(0, INT.ordinal());
        root.getVector("datetime_subcode").setNull(0);
        ((IntVector) root.getVector("num_prec_radix")).setSafe(0, 10);
        root.getVector("interval_precision").setNull(0);
        listener.start(root);
        listener.putNext();
        listener.completed();
      }
    }

    /**
     * Streams a single row for table {@code test_table} of type {@code TABLE}, with null catalog
     * and schema names, using the GetTables schema that omits the table schema column.
     */
    @Override
    public void getStreamTables(
        FlightSql.CommandGetTables command, CallContext context, ServerStreamListener listener) {
      try (VectorSchemaRoot root =
          VectorSchemaRoot.create(Schemas.GET_TABLES_SCHEMA_NO_SCHEMA, allocator)) {
        root.setRowCount(1);
        root.getVector("catalog_name").setNull(0);
        root.getVector("db_schema_name").setNull(0);
        ((VarCharVector) root.getVector("table_name")).setSafe(0, new Text("test_table"));
        ((VarCharVector) root.getVector("table_type")).setSafe(0, new Text("TABLE"));
        listener.start(root);
        listener.putNext();
        listener.completed();
      }
    }

    /** Streams a single row holding the table type {@code TABLE}. */
    @Override
    public void getStreamTableTypes(CallContext context, ServerStreamListener listener) {
      try (VectorSchemaRoot root =
          VectorSchemaRoot.create(Schemas.GET_TABLE_TYPES_SCHEMA, allocator)) {
        root.setRowCount(1);
        ((VarCharVector) root.getVector("table_type")).setSafe(0, new Text("TABLE"));
        listener.start(root);
        listener.putNext();
        listener.completed();
      }
    }
  }

  private static BufferAllocator allocator;
  private static FlightServer server;
  private static FlightSqlClient sqlClient;

  /**
   * Starts the server on localhost with port 0 and connects the shared client to the port the
   * server reports once started.
   */
  @BeforeAll
  public static void setUp() throws Exception {
    allocator = new RootAllocator(Integer.MAX_VALUE);

    final Location serverLocation = Location.forGrpcInsecure("localhost", 0);
    server =
        FlightServer.builder(allocator, serverLocation, new FlightSqlTestProducer(allocator))
            .build()
            .start();

    final Location clientLocation = Location.forGrpcInsecure("localhost", server.getPort());
    sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build());
  }

  /** Closes the client and server, then any remaining child allocators, then the allocator. */
  @AfterAll
  public static void tearDown() throws Exception {
    close(sqlClient, server);

    // Manually close all child allocators.
    allocator.getChildAllocators().forEach(BufferAllocator::close);
    close(allocator);
  }

  /** GetTables without the table schema column returns the producer's single table row. */
  @Test
  public void testGetTablesResultNoSchema() throws Exception {
    try (final FlightStream stream =
        sqlClient.getStream(
            sqlClient.getTables(null, null, null, null, false).getEndpoints().get(0).getTicket())) {
      assertAll(
          () ->
              assertThat(stream.getSchema())
                  .isEqualTo(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA),
          () -> {
            final List<List<String>> results = getResults(stream);
            final List<List<String>> expectedResults =
                ImmutableList.of(
                    // catalog_name | db_schema_name | table_name | table_type
                    asList(null, null, "test_table", "TABLE"));
            assertThat(results).isEqualTo(expectedResults);
          });
    }
  }

  /** GetTableTypes returns the producer's single table type. */
  @Test
  public void testGetTableTypesResult() throws Exception {
    try (final FlightStream stream =
        sqlClient.getStream(sqlClient.getTableTypes().getEndpoints().get(0).getTicket())) {
      assertAll(
          () ->
              assertThat(stream.getSchema())
                  .isEqualTo(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA),
          () -> {
            final List<List<String>> tableTypes = getResults(stream);
            final List<List<String>> expectedTableTypes =
                ImmutableList.of(
                    // table_type
                    singletonList("TABLE"));
            assertThat(tableTypes).isEqualTo(expectedTableTypes);
          });
    }
  }

  /** GetSqlInfo returns the expected schema and no rows. */
  @Test
  public void testGetSqlInfoResults() throws Exception {
    final FlightInfo info = sqlClient.getSqlInfo();
    try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) {
      assertAll(
          () ->
              assertThat(stream.getSchema())
                  .isEqualTo(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA),
          () -> assertThat(getResults(stream)).isEqualTo(emptyList()));
    }
  }

  /** GetXdbcTypeInfo returns the single "Integer" type row written by the producer. */
  @Test
  public void testGetTypeInfo() throws Exception {
    FlightInfo flightInfo = sqlClient.getXdbcTypeInfo();
    try (FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket())) {
      final List<List<String>> results = getResults(stream);

      // Derive the enum-backed expectations from the same constants the producer writes.
      final String nullable =
          String.valueOf(FlightSql.Nullable.NULLABILITY_NULLABLE.getNumber());
      final String searchable = String.valueOf(FlightSql.Searchable.SEARCHABLE_FULL.getNumber());

      // NOTE(review): the expected row has no entry for create_params (set to null by the
      // producer); presumably getResults omits null list values -- confirm.
      final List<List<String>> expectedResults =
          ImmutableList.of(
              asList(
                  "Integer", // type_name
                  "4", // data_type
                  "400", // column_size
                  null, // literal_prefix
                  null, // literal_suffix
                  nullable, // nullable
                  "true", // case_sensitive
                  searchable, // searchable
                  "true", // unsigned_attribute
                  null, // fixed_prec_scale
                  "true", // auto_increment
                  "Integer", // local_type_name
                  null, // minimum_scale
                  null, // maximum_scale
                  "4", // sql_data_type
                  null, // datetime_subcode
                  "10", // num_prec_radix
                  null)); // interval_precision
      assertThat(results).isEqualTo(expectedResults);
    }
  }

  /** Executing the producer's fixed query returns its schema and the single row {@code 1}. */
  @Test
  public void testExecuteQuery() throws Exception {
    try (final FlightStream stream =
        sqlClient.getStream(
            sqlClient
                .execute(FlightSqlTestProducer.FIXED_QUERY)
                .getEndpoints()
                .get(0)
                .getTicket())) {
      assertAll(
          () -> assertThat(stream.getSchema()).isEqualTo(FlightSqlTestProducer.FIXED_SCHEMA),
          () -> assertThat(getResults(stream)).isEqualTo(singletonList(singletonList("1"))));
    }
  }
}