ArrowFlightStatementExecuteTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.driver.jdbc;

import static org.hamcrest.CoreMatchers.allOf;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.CoreMatchers.nullValue;
import static org.hamcrest.MatcherAssert.assertThat;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.UInt1Vector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/** Tests for {@link ArrowFlightStatement#execute}. */
public class ArrowFlightStatementExecuteTest {
  private static final String SAMPLE_QUERY_CMD = "SELECT * FROM this_test";
  private static final int SAMPLE_QUERY_ROWS = Byte.MAX_VALUE;
  private static final String VECTOR_NAME = "Unsigned Byte";
  private static final Schema SAMPLE_QUERY_SCHEMA =
      new Schema(Collections.singletonList(Field.nullable(VECTOR_NAME, MinorType.UINT1.getType())));
  private static final String SAMPLE_UPDATE_QUERY =
      "UPDATE this_table SET this_field = that_field FROM this_test WHERE this_condition";
  private static final long SAMPLE_UPDATE_COUNT = 100L;
  private static final String SAMPLE_LARGE_UPDATE_QUERY =
      "UPDATE this_large_table SET this_large_field = that_large_field FROM this_large_test WHERE this_large_condition";
  private static final long SAMPLE_LARGE_UPDATE_COUNT = Long.MAX_VALUE;
  private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();

  @RegisterExtension
  public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION =
      FlightServerTestExtension.createStandardTestExtension(PRODUCER);

  private Connection connection;
  private Statement statement;

  @BeforeAll
  public static void setUpBeforeClass() {
    PRODUCER.addSelectQuery(
        SAMPLE_QUERY_CMD,
        SAMPLE_QUERY_SCHEMA,
        Collections.singletonList(
            listener -> {
              try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
                  final VectorSchemaRoot root =
                      VectorSchemaRoot.create(SAMPLE_QUERY_SCHEMA, allocator)) {
                final UInt1Vector vector = (UInt1Vector) root.getVector(VECTOR_NAME);
                IntStream.range(0, SAMPLE_QUERY_ROWS)
                    .forEach(index -> vector.setSafe(index, index));
                vector.setValueCount(SAMPLE_QUERY_ROWS);
                root.setRowCount(SAMPLE_QUERY_ROWS);
                listener.start(root);
                listener.putNext();
              } catch (final Throwable throwable) {
                listener.error(throwable);
              } finally {
                listener.completed();
              }
            }));
    PRODUCER.addUpdateQuery(SAMPLE_UPDATE_QUERY, SAMPLE_UPDATE_COUNT);
    PRODUCER.addUpdateQuery(SAMPLE_LARGE_UPDATE_QUERY, SAMPLE_LARGE_UPDATE_COUNT);
  }

  @BeforeEach
  public void setUp() throws SQLException {
    connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false);
    statement = connection.createStatement();
  }

  @AfterEach
  public void tearDown() throws Exception {
    AutoCloseables.close(statement, connection);
  }

  @AfterAll
  public static void tearDownAfterClass() throws Exception {
    AutoCloseables.close(PRODUCER);
  }

  @Test
  public void testExecuteShouldRunSelectQuery() throws SQLException {
    assertThat(statement.execute(SAMPLE_QUERY_CMD), is(true)); // Means this is a SELECT query.
    final Set<Byte> numbers =
        IntStream.range(0, SAMPLE_QUERY_ROWS)
            .boxed()
            .map(Integer::byteValue)
            .collect(Collectors.toCollection(HashSet::new));
    try (final ResultSet resultSet = statement.getResultSet()) {
      final int columnCount = resultSet.getMetaData().getColumnCount();
      assertThat(columnCount, is(1));
      int rowCount = 0;
      for (; resultSet.next(); rowCount++) {
        assertThat(numbers.remove(resultSet.getByte(1)), is(true));
      }
      assertThat(rowCount, is(equalTo(SAMPLE_QUERY_ROWS)));
    }
    assertThat(numbers, is(Collections.emptySet()));
    assertThat(
        (long) statement.getUpdateCount(),
        is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(-1L))));
  }

  @Test
  public void testExecuteShouldRunUpdateQueryForSmallUpdate() throws SQLException {
    assertThat(statement.execute(SAMPLE_UPDATE_QUERY), is(false)); // Means this is an UPDATE query.
    assertThat(
        (long) statement.getUpdateCount(),
        is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(SAMPLE_UPDATE_COUNT))));
    assertThat(statement.getResultSet(), is(nullValue()));
  }

  @Test
  public void testExecuteShouldRunUpdateQueryForLargeUpdate() throws SQLException {
    assertThat(statement.execute(SAMPLE_LARGE_UPDATE_QUERY), is(false)); // UPDATE query.
    final long updateCountSmall = statement.getUpdateCount();
    final long updateCountLarge = statement.getLargeUpdateCount();
    assertThat(updateCountLarge, is(equalTo(SAMPLE_LARGE_UPDATE_COUNT)));
    assertThat(
        updateCountSmall,
        is(
            allOf(
                equalTo((long) AvaticaUtils.toSaturatedInt(updateCountLarge)),
                not(equalTo(updateCountLarge)))));
    assertThat(statement.getResultSet(), is(nullValue()));
  }

  @Test
  public void testUpdateCountShouldStartOnZero() throws SQLException {
    assertThat(
        (long) statement.getUpdateCount(),
        is(allOf(equalTo(statement.getLargeUpdateCount()), equalTo(0L))));
    assertThat(statement.getResultSet(), is(nullValue()));
  }
}