ArrowFlightStatementExecuteUpdateTest.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.arrow.driver.jdbc;

import static java.lang.String.format;
import static org.hamcrest.CoreMatchers.allOf;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLFeatureNotSupportedException;
import java.sql.Statement;
import java.util.Collections;
import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.Types.MinorType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaUtils;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

/** Tests for {@link ArrowFlightStatement#executeUpdate}. */
public class ArrowFlightStatementExecuteUpdateTest {
  private static final String UPDATE_SAMPLE_QUERY =
      "UPDATE sample_table SET sample_col = sample_val WHERE sample_condition";
  private static final int UPDATE_SAMPLE_QUERY_AFFECTED_COLS = 10;
  private static final String LARGE_UPDATE_SAMPLE_QUERY =
      "UPDATE large_sample_table SET large_sample_col = large_sample_val WHERE large_sample_condition";
  private static final long LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS = (long) Integer.MAX_VALUE + 1;
  private static final String REGULAR_QUERY_SAMPLE = "SELECT * FROM NOT_UPDATE_QUERY";
  private static final Schema REGULAR_QUERY_SCHEMA =
      new Schema(
          Collections.singletonList(Field.nullable("placeholder", MinorType.VARCHAR.getType())));
  private static final MockFlightSqlProducer PRODUCER = new MockFlightSqlProducer();

  @RegisterExtension
  public static final FlightServerTestExtension FLIGHT_SERVER_TEST_EXTENSION =
      FlightServerTestExtension.createStandardTestExtension(PRODUCER);

  public Connection connection;
  public Statement statement;

  @BeforeAll
  public static void setUpBeforeClass() {
    PRODUCER.addUpdateQuery(UPDATE_SAMPLE_QUERY, UPDATE_SAMPLE_QUERY_AFFECTED_COLS);
    PRODUCER.addUpdateQuery(LARGE_UPDATE_SAMPLE_QUERY, LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS);
    PRODUCER.addSelectQuery(
        REGULAR_QUERY_SAMPLE,
        REGULAR_QUERY_SCHEMA,
        Collections.singletonList(
            listener -> {
              try (final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
                  final VectorSchemaRoot root =
                      VectorSchemaRoot.create(REGULAR_QUERY_SCHEMA, allocator)) {
                listener.start(root);
                listener.putNext();
              } catch (final Throwable throwable) {
                listener.error(throwable);
              } finally {
                listener.completed();
              }
            }));
  }

  @BeforeEach
  public void setUp() throws SQLException {
    connection = FLIGHT_SERVER_TEST_EXTENSION.getConnection(false);
    statement = connection.createStatement();
  }

  @AfterEach
  public void tearDown() throws Exception {
    AutoCloseables.close(statement, connection);
  }

  @AfterAll
  public static void tearDownAfterClass() throws Exception {
    AutoCloseables.close(PRODUCER);
  }

  @Test
  public void testExecuteUpdateShouldReturnNumColsAffectedForNumRowsFittingInt()
      throws SQLException {
    assertThat(statement.executeUpdate(UPDATE_SAMPLE_QUERY), is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS));
  }

  @Test
  public void testExecuteUpdateShouldReturnSaturatedNumColsAffectedIfDoesNotFitInInt()
      throws SQLException {
    final long result = statement.executeUpdate(LARGE_UPDATE_SAMPLE_QUERY);
    final long expectedRowCountRaw = LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS;
    assertThat(
        result,
        is(
            allOf(
                not(equalTo(expectedRowCountRaw)),
                equalTo(
                    (long)
                        AvaticaUtils.toSaturatedInt(
                            expectedRowCountRaw))))); // Because of long-to-integer overflow.
  }

  @Test
  public void testExecuteLargeUpdateShouldReturnNumColsAffected() throws SQLException {
    assertThat(
        statement.executeLargeUpdate(LARGE_UPDATE_SAMPLE_QUERY),
        is(LARGE_UPDATE_SAMPLE_QUERY_AFFECTED_COLS));
  }

  @Test
  // TODO Implement `Statement#executeUpdate(String, int)`
  public void testExecuteUpdateUnsupportedWithDriverFlag() throws SQLException {
    assertThrows(
        SQLFeatureNotSupportedException.class,
        () -> {
          assertThat(
              statement.executeUpdate(UPDATE_SAMPLE_QUERY, Statement.NO_GENERATED_KEYS),
              is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS));
        });
  }

  @Test
  // TODO Implement `Statement#executeUpdate(String, int[])`
  public void testExecuteUpdateUnsupportedWithArrayOfInts() throws SQLException {
    assertThrows(
        SQLFeatureNotSupportedException.class,
        () -> {
          assertThat(
              statement.executeUpdate(UPDATE_SAMPLE_QUERY, new int[0]),
              is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS));
        });
  }

  @Test
  // TODO Implement `Statement#executeUpdate(String, String[])`
  public void testExecuteUpdateUnsupportedWithArraysOfStrings() throws SQLException {
    assertThrows(
        SQLFeatureNotSupportedException.class,
        () -> {
          assertThat(
              statement.executeUpdate(UPDATE_SAMPLE_QUERY, new String[0]),
              is(UPDATE_SAMPLE_QUERY_AFFECTED_COLS));
        });
  }

  @Test
  public void testExecuteShouldExecuteUpdateQueryAutomatically() throws SQLException {
    assertThat(
        statement.execute(UPDATE_SAMPLE_QUERY), is(false)); // Meaning there was an update query.
    assertThat(
        statement.execute(REGULAR_QUERY_SAMPLE), is(true)); // Meaning there was a select query.
  }

  @Test
  public void testShouldFailToPrepareStatementForNullQuery() {
    int count = 0;
    try {
      assertThat(statement.execute(null), is(false));
    } catch (final SQLException e) {
      count++;
      assertThat(e.getCause(), is(instanceOf(NullPointerException.class)));
    }
    assertThat(count, is(1));
  }

  @Test
  public void testShouldFailToPrepareStatementForClosedStatement() throws SQLException {
    statement.close();
    assertThat(statement.isClosed(), is(true));
    int count = 0;
    try {
      statement.execute(UPDATE_SAMPLE_QUERY);
    } catch (final SQLException e) {
      count++;
      assertThat(e.getMessage(), is("Statement closed"));
    }
    assertThat(count, is(1));
  }

  @Test
  public void testShouldFailToPrepareStatementForBadStatement() {
    final String badQuery = "BAD INVALID STATEMENT";
    int count = 0;
    try {
      statement.execute(badQuery);
    } catch (final SQLException e) {
      count++;
      /*
       * The error message is up to whatever implementation of `FlightSqlProducer`
       * the driver is communicating with. However, for the purpose of this test,
       * we simply throw an `IllegalArgumentException` for queries not registered
       * in our `MockFlightSqlProducer`.
       */
      assertThat(
          e.getMessage(),
          is(format("Error while executing SQL \"%s\": Query not found", badQuery)));
    }
    assertThat(count, is(1));
  }
}