ClientParserRewritableTest.java

// SPDX-License-Identifier: LGPL-2.1-or-later
// Copyright (c) 2012-2014 Monty Program Ab
// Copyright (c) 2015-2025 MariaDB Corporation Ab
package org.mariadb.jdbc.unit.util;

import static org.junit.jupiter.api.Assertions.*;

import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.mariadb.jdbc.util.ClientParser;

public class ClientParserRewritableTest {

  /** SELECT query cannot be rewritable. */
  @Test
  public void selectQuery() {
    // SELECT query cannot be rewritable
    assertFalse(checkRewritable("SELECT * FROM MyTable", 0, 0));
    assertFalse(checkRewritable("SELECT\n * FROM MyTable", 0, 0));
    assertFalse(checkRewritable("SELECT(1)", 0, 0));
    assertFalse(checkRewritable("INSERT MyTable (a) VALUES (1);SELECT(1)", 0, 0));
  }

  /** INSERT FROM SELECT are not be rewritable. */
  @Test
  public void insertSelectQuery() {
    assertFalse(checkRewritable("INSERT INTO MyTable (a) SELECT * FROM seq_1_to_1000", 0, 0));
    assertFalse(checkRewritable("INSERT INTO MyTable (a);SELECT * FROM seq_1_to_1000", 0, 0));
    assertFalse(checkRewritable("INSERT INTO MyTable (a)SELECT * FROM seq_1_to_1000", 0, 0));
    assertFalse(checkRewritable("INSERT INTO MyTable (a) (SELECT * FROM seq_1_to_1000)", 0, 0));
    assertFalse(checkRewritable("INSERT INTO MyTable (a) SELECT\n * FROM seq_1_to_1000", 0, 0));
  }

  /** If parameters exist outside the VALUES() block, not rewritable. */
  @Test
  public void insertParametersOutsideValues() {
    assertFalse(
        checkRewritable("INSERT INTO TABLE(col1) VALUES (?) ON DUPLICATE KEY UPDATE col2=?", 0, 0));
  }

  /** LAST_INSERT_ID is not rewritable. */
  @Test
  public void insertLastInsertId() {
    assertFalse(
        checkRewritable("INSERT INTO TABLE(col1, col2) VALUES (?, LAST_INSERT_ID())", 0, 0));
  }

  /**
   * Insert query that contain table/column name with select keyword, or select in comment can be
   * rewritten.
   */
  @Test
  public void rewritableThatContainSelectQuery() {
    // but 'SELECT' keyword in column/table name can be rewritable
    assertTrue(checkRewritable("INSERT INTO TABLE_SELECT VALUES (?)", 32, 34));
    assertTrue(checkRewritable("INSERT INTO TABLE_SELECT VALUES (?)", 32, 34));
    assertTrue(checkRewritable("INSERT INTO SELECT_TABLE VALUES (?)", 32, 34));
    assertTrue(checkRewritable("INSERT INTO `TABLE SELECT ` VALUES (?)", 35, 37));
    assertTrue(checkRewritable("INSERT INTO TABLE /* SELECT in comment */  VALUES (?)", 50, 52));
    assertTrue(checkRewritable("INSERT INTO TABLE  VALUES (?) //SELECT", 26, 28));
    assertTrue(checkRewritable("INSERT INTO TABLE VALUES ('abc', ?)", 25, 34));
    assertTrue(checkRewritable("INSERT INTO TABLE VALUES (\"a''bc\", ?)", 25, 36));
    assertTrue(checkRewritable("INSERT INTO TABLE VALUES ('\\\\test', ?) /*test* #/ ;`*/", 25, 37));
    assertTrue(checkRewritable("INSERT INTO TABLE VALUES ('\\\\test', ?) # EOL ", 25, 37));
    assertTrue(checkRewritable("INSERT INTO TABLE VALUES ('\\\\test', ?) -- EOL ", 25, 37));
  }

  private boolean checkRewritable(String query, int pos1, int pos2) {
    List<Integer> valuesBracketPositions =
        ClientParser.rewritableParts(query, true).getValuesBracketPositions();
    if (valuesBracketPositions == null) {
      return false;
    } else if (valuesBracketPositions.size() == 2) {
      assertEquals(pos1, valuesBracketPositions.get(0));
      assertEquals(pos2, valuesBracketPositions.get(1));
      return true;
    } else {
      fail("valuesBracketPositions().size() != 2");
      return false; // appeasing the compiler: this line will never be executed.
    }
  }

  static Stream<Arguments> rewriteTestData() {
    return Stream.of(
        Arguments.of("INSERT INTO b VALUES (?)", 1, new int[] {22}, new int[] {21, 23}),
        Arguments.of("UPDATE b set a=?", 1, new int[] {15}, null),
        Arguments.of("INSERT INTO b VALUES (?,? )", 2, new int[] {22, 24}, new int[] {21, 26}),
        Arguments.of("INSERT INTO b (SELECT a FROM b where c=? )", 1, new int[] {39}, null),
        Arguments.of(
            "INSERT INTO b VALUES (?,?), (?,?)",
            4,
            new int[] {22, 24, 29, 31},
            new int[] {21, 32}));
  }

  @ParameterizedTest()
  @MethodSource("rewriteTestData")
  public void rewritableParser(
      String sql, int paramCount, int[] paramPosition, int[] valuesBracketPositions) {
    ClientParser parser = ClientParser.rewritableParts(sql, false);
    assertEquals(parser.getSql(), sql);
    assertEquals(parser.getParamCount(), paramCount);
    assertArrayEquals(
        parser.getParamPositions().stream().mapToInt(Integer::intValue).toArray(), paramPosition);
    if (valuesBracketPositions == null) {
      assertNull(parser.getValuesBracketPositions());
    } else {
      assertArrayEquals(
          parser.getValuesBracketPositions().stream().mapToInt(Integer::intValue).toArray(),
          valuesBracketPositions);
    }
  }
}