TestingArrowProducer.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.plugin.arrow.testingServer;
import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.google.common.collect.ImmutableList;
import org.apache.arrow.adapter.jdbc.ArrowVectorIterator;
import org.apache.arrow.adapter.jdbc.JdbcToArrow;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;
import java.util.concurrent.ThreadLocalRandom;
import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.common.Utils.checkArgument;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema;
public class TestingArrowProducer
implements FlightProducer
{
private final BufferAllocator allocator;
private final Connection connection;
private static final Logger logger = Logger.get(TestingArrowProducer.class);
private final JsonCodec<TestingArrowFlightRequest> requestCodec;
private final JsonCodec<TestingArrowFlightResponse> responseCodec;
public TestingArrowProducer(BufferAllocator allocator) throws Exception
{
this.allocator = allocator;
String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1";
TestingH2DatabaseSetup.setup(h2JdbcUrl);
this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", "");
this.requestCodec = jsonCodec(TestingArrowFlightRequest.class);
this.responseCodec = jsonCodec(TestingArrowFlightResponse.class);
}
@Override
public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener)
{
try (Statement stmt = connection.createStatement()) {
TestingArrowFlightRequest request = requestCodec.fromJson(ticket.getBytes());
checkArgument(request != null, "Request is null");
checkArgument(request.getQuery().isPresent(), "Query is missing");
// Extract and validate the SQL query
String query = request.getQuery().get();
if (query.trim().isEmpty()) {
throw new IllegalArgumentException("Query cannot be empty.");
}
logger.debug("Executing query: %s", query);
try (ResultSet resultSet = stmt.executeQuery(query.toUpperCase())) {
JdbcToArrowConfig config = new JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048)
.setCalendar(Calendar.getInstance(TimeZone.getDefault())).build();
Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), config);
try (VectorSchemaRoot streamRoot = VectorSchemaRoot.create(schema, allocator)) {
VectorLoader loader = new VectorLoader(streamRoot);
serverStreamListener.start(streamRoot);
ArrowVectorIterator iterator = JdbcToArrow.sqlToArrowVectorIterator(resultSet, config);
while (iterator.hasNext()) {
try (VectorSchemaRoot iteratorRoot = iterator.next()) {
VectorUnloader vectorUnloader = new VectorUnloader(iteratorRoot);
try (ArrowRecordBatch batch = vectorUnloader.getRecordBatch()) {
loader.load(batch);
serverStreamListener.putNext();
}
streamRoot.clear();
}
}
}
serverStreamListener.completed();
}
}
// Handle Arrow processing errors
catch (IOException e) {
logger.error("Arrow data processing failed", e);
serverStreamListener.error(e);
throw new RuntimeException("Failed to process Arrow data", e);
}
// Handle all other exceptions, including parsing errors
catch (Exception e) {
logger.error("Ticket processing failed", e);
serverStreamListener.error(e);
throw new RuntimeException("Failed to process the ticket", e);
}
}
@Override
public void listFlights(CallContext callContext, Criteria criteria, StreamListener<FlightInfo> streamListener)
{
throw new UnsupportedOperationException("This operation is not supported");
}
@Override
public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor)
{
try {
TestingArrowFlightRequest request = requestCodec.fromJson(flightDescriptor.getCommand());
checkArgument(request != null, "Request is null");
checkArgument(request.getSchema().isPresent(), "Schema is missing");
String schemaName = request.getSchema().get();
Optional<String> tableName = request.getTable();
String selectStatement = request.getQuery().orElse(null);
List<Field> fields = new ArrayList<>();
if (tableName.isPresent()) {
String query = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " +
"WHERE TABLE_SCHEMA='" + schemaName.toUpperCase() + "' " +
"AND TABLE_NAME='" + tableName.get().toUpperCase() + "'";
try (ResultSet rs = connection.createStatement().executeQuery(query)) {
while (rs.next()) {
String columnName = rs.getString("COLUMN_NAME");
String dataType = rs.getString("TYPE_NAME");
String charMaxLength = rs.getString("CHARACTER_MAXIMUM_LENGTH");
int precision = rs.getInt("NUMERIC_PRECISION");
int scale = rs.getInt("NUMERIC_SCALE");
ArrowType arrowType = convertSqlTypeToArrowType(dataType, precision, scale);
Map<String, String> metaDataMap = new HashMap<>();
metaDataMap.put("columnNativeType", dataType);
if (charMaxLength != null) {
metaDataMap.put("columnLength", charMaxLength);
}
FieldType fieldType = new FieldType(true, arrowType, null, metaDataMap);
Field field = new Field(columnName, fieldType, null);
fields.add(field);
}
}
}
else if (selectStatement != null) {
selectStatement = selectStatement.toUpperCase();
logger.debug("Executing SELECT query: %s", selectStatement);
try (ResultSet rs = connection.createStatement().executeQuery(selectStatement)) {
ResultSetMetaData metaData = rs.getMetaData();
int columnCount = metaData.getColumnCount();
for (int i = 1; i <= columnCount; i++) {
String columnName = metaData.getColumnName(i);
String columnType = metaData.getColumnTypeName(i);
int precision = metaData.getPrecision(i);
int scale = metaData.getScale(i);
ArrowType arrowType = convertSqlTypeToArrowType(columnType, precision, scale);
Field field = new Field(columnName, FieldType.nullable(arrowType), null);
fields.add(field);
}
}
}
else {
throw new IllegalArgumentException("Either schema_name/table_name or select_statement must be provided.");
}
Schema schema = new Schema(fields);
FlightEndpoint endpoint = new FlightEndpoint(new Ticket(flightDescriptor.getCommand()));
return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1);
}
catch (Exception e) {
logger.error(e);
throw new RuntimeException("Failed to retrieve FlightInfo", e);
}
}
@Override
public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener<PutResult> streamListener)
{
throw new UnsupportedOperationException("This operation is not supported");
}
@Override
public void doAction(CallContext callContext, Action action, StreamListener<Result> streamListener)
{
try {
TestingArrowFlightRequest request = requestCodec.fromJson(action.getBody());
Optional<String> schemaName = request.getSchema();
String query;
if (!schemaName.isPresent()) {
query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA";
}
else {
query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='" + schemaName.get().toUpperCase() + "'";
}
ResultSet rs = connection.createStatement().executeQuery(query);
List<String> names = new ArrayList<>();
while (rs.next()) {
names.add(rs.getString(1));
}
TestingArrowFlightResponse response;
if (!schemaName.isPresent()) {
response = new TestingArrowFlightResponse(names, ImmutableList.of());
}
else {
response = new TestingArrowFlightResponse(ImmutableList.of(), names);
}
streamListener.onNext(new Result(responseCodec.toJsonBytes(response)));
streamListener.onCompleted();
}
catch (Exception e) {
streamListener.onError(e);
}
}
@Override
public void listActions(CallContext callContext, StreamListener<ActionType> streamListener)
{
throw new UnsupportedOperationException("This operation is not supported");
}
private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int scale)
{
switch (sqlType.toUpperCase()) {
case "VARCHAR":
case "CHAR":
case "CHARACTER VARYING":
case "CHARACTER":
case "CLOB":
return new ArrowType.Utf8();
case "INTEGER":
case "INT":
return new ArrowType.Int(32, true);
case "BIGINT":
return new ArrowType.Int(64, true);
case "SMALLINT":
return new ArrowType.Int(16, true);
case "TINYINT":
return new ArrowType.Int(8, true);
case "DOUBLE":
case "DOUBLE PRECISION":
case "FLOAT":
return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
case "REAL":
return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
case "BOOLEAN":
return new ArrowType.Bool();
case "DATE":
return new ArrowType.Date(DateUnit.DAY);
case "TIMESTAMP":
return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null);
case "TIME":
return new ArrowType.Time(TimeUnit.MILLISECOND, 32);
case "DECIMAL":
case "NUMERIC":
return new ArrowType.Decimal(precision, scale);
case "BINARY":
case "VARBINARY":
return new ArrowType.Binary();
case "NULL":
return new ArrowType.Null();
default:
throw new IllegalArgumentException("Unsupported SQL type: " + sqlType);
}
}
}