TestingArrowFlightClientHandler.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.plugin.arrow.testingConnector;

import com.facebook.airlift.json.JsonCodec;
import com.facebook.plugin.arrow.ArrowFlightConfig;
import com.facebook.plugin.arrow.ArrowTableHandle;
import com.facebook.plugin.arrow.ArrowTableLayoutHandle;
import com.facebook.plugin.arrow.BaseArrowFlightClientHandler;
import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest;
import com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.SchemaTableName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.CallOption;
import org.apache.arrow.flight.CallOptions;
import org.apache.arrow.flight.FlightClient;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.auth2.BearerCredentialWriter;
import org.apache.arrow.flight.grpc.CredentialCallOption;
import org.apache.arrow.memory.BufferAllocator;

import javax.inject.Inject;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.common.Utils.checkArgument;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

/**
 * Test-only {@link BaseArrowFlightClientHandler} that talks to the testing Arrow Flight server.
 * Requests and responses are encoded with the injected {@link JsonCodec}s for
 * {@link TestingArrowFlightRequest} and {@link TestingArrowFlightResponse}.
 */
public class TestingArrowFlightClientHandler
        extends BaseArrowFlightClientHandler
{
    /** Name of the Flight action used for schema/table discovery requests. */
    private static final String DISCOVERY_ACTION = "discovery";

    private final JsonCodec<TestingArrowFlightRequest> requestCodec;
    private final JsonCodec<TestingArrowFlightResponse> responseCodec;

    @Inject
    public TestingArrowFlightClientHandler(
            BufferAllocator allocator,
            ArrowFlightConfig config,
            JsonCodec<TestingArrowFlightRequest> requestCodec,
            JsonCodec<TestingArrowFlightResponse> responseCodec)
    {
        super(allocator, config);
        this.requestCodec = requireNonNull(requestCodec, "requestCodec is null");
        this.responseCodec = requireNonNull(responseCodec, "responseCodec is null");
    }

    /**
     * Returns the options attached to every Flight call made by this handler: a bearer
     * credential writer constructed with a {@code null} token, and a 300 second timeout.
     */
    @Override
    public CallOption[] getCallOptions(ConnectorSession connectorSession)
    {
        return new CallOption[] {
                new CredentialCallOption(new BearerCredentialWriter(null)),
                CallOptions.timeout(300, TimeUnit.SECONDS)
        };
    }

    /**
     * Builds a command descriptor carrying a "describe table" request for the given table.
     */
    @Override
    public FlightDescriptor getFlightDescriptorForSchema(ConnectorSession session, String schemaName, String tableName)
    {
        TestingArrowFlightRequest request = TestingArrowFlightRequest.createDescribeTableRequest(schemaName, tableName);
        return FlightDescriptor.command(requestCodec.toBytes(request));
    }

    /**
     * Lists schema names reported by the server through a discovery action.
     *
     * @return an immutable list of schema names, lower-cased with {@link java.util.Locale#ENGLISH}
     * @throws IllegalArgumentException if the server returns a null response or null schema list
     */
    @Override
    public List<String> listSchemaNames(ConnectorSession session)
    {
        ImmutableList.Builder<String> names = ImmutableList.builder();
        for (TestingArrowFlightResponse response : doDiscoveryAction(session, TestingArrowFlightRequest.createListSchemaRequest())) {
            checkArgument(response.getSchemaNames() != null, "response.getSchemaNames() is null");
            for (String schema : response.getSchemaNames()) {
                names.add(schema.toLowerCase(ENGLISH));
            }
        }
        return names.build();
    }

    /**
     * Lists tables in {@code schemaName} reported by the server through a discovery action.
     * An absent schema is sent to the server as the empty string.
     *
     * @return the tables, with schema and table names lower-cased with {@link java.util.Locale#ENGLISH}
     * @throws IllegalArgumentException if the server returns a null response or null table list
     */
    @Override
    public List<SchemaTableName> listTables(ConnectorSession session, Optional<String> schemaName)
    {
        String schemaValue = schemaName.orElse("");
        // NOTE(review): when schemaName is absent, this builds SchemaTableName("", ...);
        // SchemaTableName may reject an empty schema name — confirm the intended behavior.
        String lowerCaseSchema = schemaValue.toLowerCase(ENGLISH);
        List<SchemaTableName> tables = new ArrayList<>();
        for (TestingArrowFlightResponse response : doDiscoveryAction(session, TestingArrowFlightRequest.createListTablesRequest(schemaValue))) {
            checkArgument(response.getTableNames() != null, "response.getTableNames() is null");
            for (String table : response.getTableNames()) {
                tables.add(new SchemaTableName(lowerCaseSchema, table.toLowerCase(ENGLISH)));
            }
        }
        return tables;
    }

    /**
     * Builds a command descriptor carrying a query request for the scanned table.
     * The SQL text is generated from the layout's columns and tuple domain.
     */
    @Override
    public FlightDescriptor getFlightDescriptorForTableScan(ConnectorSession session, ArrowTableLayoutHandle tableLayoutHandle)
    {
        ArrowTableHandle tableHandle = tableLayoutHandle.getTable();
        String query = new TestingArrowQueryBuilder().buildSql(
                tableHandle.getSchema(),
                tableHandle.getTable(),
                tableLayoutHandle.getColumnHandles(), ImmutableMap.of(),
                tableLayoutHandle.getTupleDomain());
        TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), tableHandle.getTable(), query);
        return FlightDescriptor.command(requestCodec.toBytes(request));
    }

    /**
     * Sends {@code request} as a "discovery" action and decodes every result body.
     * A fresh client is opened for the call and closed before returning.
     *
     * @return all decoded responses, in the order the server produced them
     * @throws IllegalArgumentException if any result body decodes to null
     * @throws RuntimeException wrapping an {@link InterruptedException} raised while closing the
     *         client; the thread's interrupt status is restored first
     */
    private List<TestingArrowFlightResponse> doDiscoveryAction(ConnectorSession session, TestingArrowFlightRequest request)
    {
        ImmutableList.Builder<TestingArrowFlightResponse> responses = ImmutableList.builder();
        try (FlightClient client = createFlightClient()) {
            Iterator<Result> iterator = client.doAction(new Action(DISCOVERY_ACTION, requestCodec.toJsonBytes(request)), getCallOptions(session));
            while (iterator.hasNext()) {
                Result result = iterator.next();
                TestingArrowFlightResponse response = responseCodec.fromJson(result.getBody());
                checkArgument(response != null, "response is null");
                responses.add(response);
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        return responses.build();
    }
}