AbstractSqlTester.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.sql.test;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.runtime.PairList;
import org.apache.calcite.runtime.Utilities;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlUnresolvedFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.parser.SqlParserUtil;
import org.apache.calcite.sql.parser.StringAndPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.util.SqlShuttle;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
import org.apache.calcite.sql2rel.RelFieldTrimmer;
import org.apache.calcite.sql2rel.SqlToRelConverter;
import org.apache.calcite.test.DiffRepository;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.TestUtil;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.hamcrest.Matcher;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.function.Consumer;

import static org.apache.calcite.test.Matchers.relIsValid;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.hasSize;
import static org.junit.jupiter.api.Assertions.assertNotNull;

import static java.util.Objects.requireNonNull;

/**
 * Abstract implementation of {@link SqlTester}
 * that talks to a mock catalog.
 *
 * <p>This is to implement the default behavior: testing is only against the
 * {@link SqlValidator}.
 */
public abstract class AbstractSqlTester implements SqlTester, AutoCloseable {
  /** Value of the {@code line.separator} system property; used to pad plan
   * text before it is compared with the expected plan. */
  private static final String NL = System.getProperty("line.separator");

  /** Creates an AbstractSqlTester. */
  public AbstractSqlTester() {
  }

  /**
   * {@inheritDoc}
   *
   * <p>This default implementation does nothing; this class declares no
   * instance state that would need to be released.
   */
  @Override public void close() {
    // no resources to release
  }

  /**
   * Parses and validates a query and checks the outcome against an expected
   * error pattern.
   *
   * <p>The query is parsed first. If the parser throws, the failure is handed
   * to {@link #checkParseEx} and validation is not attempted. Otherwise the
   * parse tree is validated with a validator from the factory, and whatever
   * the validator threw (possibly nothing) is handed to
   * {@code SqlTests.checkEx} with stage {@code VALIDATE}.
   *
   * @param factory Factory that supplies the parser and the validator
   * @param sap SQL text, plus the error position if one was marked
   * @param expectedMsgPattern Pattern that the error message is checked
   *   against, or null; passed through unchanged to the checkers
   */
  @Override public void assertExceptionIsThrown(SqlTestFactory factory,
      StringAndPos sap, @Nullable String expectedMsgPattern) {
    final SqlNode sqlNode;
    try {
      sqlNode = parseQuery(factory, sap.sql);
    } catch (Throwable e) {
      // Route parse failures through the overridable hook, so that a
      // subclass that customizes checkParseEx actually sees them.
      checkParseEx(e, expectedMsgPattern, sap);
      return;
    }

    final SqlValidator validator = factory.createValidator();
    Throwable thrown = null;
    try {
      validator.validate(sqlNode);
    } catch (Throwable ex) {
      // Remember the failure; checkEx decides whether it was expected.
      thrown = ex;
    }

    SqlTests.checkEx(thrown, expectedMsgPattern, sap, SqlTests.Stage.VALIDATE);
  }

  /**
   * Checks an exception that was thrown while parsing a query.
   *
   * <p>Returns normally only if {@code e} is a {@link SqlParseException},
   * {@code expectedMsgPattern} is not null, and the exception's message
   * (after {@link Util#toLinux}) matches the pattern in its entirety.
   * In every other case a {@link RuntimeException} that wraps {@code e} is
   * thrown.
   *
   * @param e Exception thrown by the parser
   * @param expectedMsgPattern Regular expression that the whole error message
   *   must match, or null if no parse error was expected
   * @param sap SQL that was being parsed; included in failure messages
   * @throws RuntimeException if {@code e} is not a parse exception, if no
   *   error was expected, or if the message does not match the pattern
   */
  protected void checkParseEx(Throwable e, @Nullable String expectedMsgPattern,
      StringAndPos sap) {
    if (!(e instanceof SqlParseException)) {
      // Anything other than a parse exception is unexpected here.
      throw new RuntimeException("Error while parsing query: " + sap, e);
    }
    if (expectedMsgPattern == null) {
      // The caller expected the query to parse cleanly.
      throw new RuntimeException("Error while parsing query: " + sap, e);
    }
    final String errMessage = e.getMessage();
    if (errMessage == null
        || !Util.toLinux(errMessage).matches(expectedMsgPattern)) {
      throw new RuntimeException("Error did not match expected ["
          + expectedMsgPattern + "] while parsing query ["
          + sap + "]", e);
    }
  }

  /**
   * Returns the type of the only column of a query.
   *
   * <p>The query is parsed and validated; an assertion fails if the validated
   * row type does not have exactly one field.
   *
   * @param factory Test factory
   * @param sql Query
   * @return Type of the query's single field
   */
  @Override public RelDataType getColumnType(SqlTestFactory factory,
      String sql) {
    final StringAndPos sap = StringAndPos.of(sql);
    return validateAndApply(factory, sap, (ignored, validator, node) -> {
      final List<RelDataTypeField> fieldList =
          validator.getValidatedNodeType(node).getFieldList();
      assertThat("expected query to return 1 field", fieldList, hasSize(1));
      return fieldList.get(0).getType();
    });
  }

  /**
   * Returns the validated type of a query.
   *
   * @param factory Test factory
   * @param sql Query
   * @return Type that the validator assigns to the validated node
   */
  @Override public RelDataType getResultType(SqlTestFactory factory,
      String sql) {
    final ValidatedNodeFunction<RelDataType> typeOf =
        (sap, validator, node) -> validator.getValidatedNodeType(node);
    return validateAndApply(factory, StringAndPos.of(sql), typeOf);
  }

  /**
   * Parses a query and validates it with a newly created validator.
   *
   * <p>Any failure during parsing is wrapped in a {@link RuntimeException};
   * failures during validation propagate unchanged.
   *
   * @param factory Factory that supplies the parser and the validator
   * @param sql Query text
   * @return The validator that was used, paired with the node returned by
   *   validation
   */
  Pair<SqlValidator, SqlNode> parseAndValidate(SqlTestFactory factory,
      String sql) {
    final SqlNode parsed;
    try {
      parsed = parseQuery(factory, sql);
    } catch (Throwable e) {
      throw new RuntimeException("Error while parsing query: " + sql, e);
    }
    final SqlValidator validator = factory.createValidator();
    final SqlNode validated = validator.validate(parsed);
    return Pair.of(validator, validated);
  }

  /**
   * Parses a query using a parser obtained from the factory.
   *
   * @param factory Factory that creates the parser
   * @param sql Query text
   * @return Parse tree
   * @throws SqlParseException if the parser rejects the text
   */
  @Override public SqlNode parseQuery(SqlTestFactory factory, String sql)
      throws SqlParseException {
    return factory.createParser(sql).parseQuery();
  }

  /**
   * Parses a scalar expression using a parser obtained from the factory.
   *
   * @param factory Factory that creates the parser
   * @param expr Expression text
   * @return Parse tree
   * @throws SqlParseException if the parser rejects the text
   */
  @Override public SqlNode parseExpression(SqlTestFactory factory,
      String expr) throws SqlParseException {
    return factory.createParser(expr).parseExpression();
  }

  /**
   * Checks that a query returns a single column whose type, rendered as a
   * string, equals {@code expected}.
   *
   * @param factory Test factory
   * @param sql Query
   * @param expected Expected type string
   */
  @Override public void checkColumnType(SqlTestFactory factory, String sql,
      String expected) {
    final StringAndPos sap = StringAndPos.of(sql);
    final ValidatedNodeConsumer action = checkColumnTypeAction(is(expected));
    validateAndThen(factory, sap, action);
  }

  /**
   * Creates an action that asserts that a validated query has exactly one
   * field, and that the string form of that field's type satisfies a matcher.
   *
   * @param matcher Matcher applied to the type string
   * @return Action to run on a validated node
   */
  private static ValidatedNodeConsumer checkColumnTypeAction(
      Matcher<String> matcher) {
    return (ignored, validator, node) -> {
      final List<RelDataTypeField> fieldList =
          validator.getValidatedNodeType(node).getFieldList();
      assertThat("expected query to return 1 field", fieldList, hasSize(1));
      final String typeString =
          SqlTests.getTypeString(fieldList.get(0).getType());
      assertThat(typeString, matcher);
    };
  }

  // SqlTester methods

  /**
   * {@inheritDoc}
   *
   * <p>This implementation ignores both arguments.
   */
  @Override public void setFor(
      SqlOperator operator,
      VmName... unimplementedVmNames) {
    // do nothing
  }

  /**
   * Checks an aggregate expression by wrapping it in a query generated by
   * {@code SqlTests.generateAggQuery} and passing that query to
   * {@code check} with {@code SqlTests.ANY_TYPE_CHECKER}.
   *
   * @param factory Test factory
   * @param expr Aggregate expression
   * @param inputValues Input values used to generate the query
   * @param resultChecker Result checker, passed through to {@code check}
   */
  @Override public void checkAgg(SqlTestFactory factory,
      String expr,
      String[] inputValues,
      ResultChecker resultChecker) {
    check(factory, SqlTests.generateAggQuery(expr, inputValues),
        SqlTests.ANY_TYPE_CHECKER, resultChecker);
  }

  /**
   * Checks a windowed aggregate expression by wrapping it in a query
   * generated by {@code SqlTests.generateWinAggQuery} and passing that query
   * to {@code check} with {@code SqlTests.ANY_TYPE_CHECKER}.
   *
   * @param factory Test factory
   * @param expr Aggregate expression
   * @param inputValues Input values used to generate the query
   * @param windowSpec Window specification used to generate the query
   * @param type Not used by this implementation
   * @param resultChecker Result checker, passed through to {@code check}
   */
  @Override public void checkWinAgg(SqlTestFactory factory,
      String expr,
      String[] inputValues,
      String windowSpec,
      String type,
      ResultChecker resultChecker) {
    check(factory, SqlTests.generateWinAggQuery(expr, windowSpec, inputValues),
        SqlTests.ANY_TYPE_CHECKER, resultChecker);
  }

  /**
   * Checks the type of a single-column query and the row type of its
   * parameters.
   *
   * <p>This implementation does NOT check the result; {@code resultChecker}
   * is only required to be non-null.
   *
   * @param factory Test factory
   * @param queryWithCarets Query, which may contain carets
   * @param typeChecker Checks the type of the query's single column
   * @param parameterChecker Checks the parameter row type
   * @param resultChecker Not invoked by this implementation
   */
  @Override public void check(SqlTestFactory factory,
      String queryWithCarets, TypeChecker typeChecker,
      ParameterChecker parameterChecker, ResultChecker resultChecker) {
    requireNonNull(typeChecker, "typeChecker");
    requireNonNull(parameterChecker, "parameterChecker");
    requireNonNull(resultChecker, "resultChecker");

    // Parse and validate (there should be no errors), then check the type
    // of the one and only column.
    final RelDataType columnType = getColumnType(factory, queryWithCarets);
    typeChecker.checkType(() -> "Query: " + queryWithCarets, columnType);

    // Parse and validate again, using the query text as held by StringAndPos,
    // to obtain the row type of the parameters.
    final String sql = StringAndPos.of(queryWithCarets).sql;
    final Pair<SqlValidator, SqlNode> validated =
        parseAndValidate(factory, sql);
    parameterChecker.checkParameters(
        validated.left.getParameterRowType(validated.right));
  }

  /**
   * Parses and validates a query, then hands the outcome to a consumer.
   *
   * @param factory Test factory
   * @param sap Query; its {@code sql} text is what gets parsed
   * @param consumer Receives {@code sap}, the validator, and the node
   *   returned by validation
   */
  @Override public void validateAndThen(SqlTestFactory factory,
      StringAndPos sap, ValidatedNodeConsumer consumer) {
    final Pair<SqlValidator, SqlNode> validated =
        parseAndValidate(factory, sap.sql);
    consumer.accept(sap, validated.left, validated.right);
  }

  /**
   * Parses and validates a query, then applies a function to the outcome.
   *
   * @param factory Test factory
   * @param sap Query; its {@code sql} text is what gets parsed
   * @param function Receives {@code sap}, the validator, and the node
   *   returned by validation
   * @param <R> Result type of the function
   * @return Whatever the function returns
   */
  @Override public <R> R validateAndApply(SqlTestFactory factory,
      StringAndPos sap, ValidatedNodeFunction<R> function) {
    final Pair<SqlValidator, SqlNode> validated =
        parseAndValidate(factory, sap.sql);
    return function.apply(sap, validated.left, validated.right);
  }

  /**
   * Checks that a scalar expression fails.
   *
   * <p>If {@code runtime} is true, this implementation only checks that the
   * expression, wrapped as {@code values (...)}, parses and validates;
   * {@code expectedError} is not consulted. Otherwise the wrapped query,
   * with the carets of {@code sap} preserved, is passed to
   * {@link #checkQueryFails}.
   *
   * @param factory Test factory
   * @param sap Expression, with carets marking the expected error position
   *   if there is one
   * @param expectedError Pattern for the expected error
   * @param runtime Whether the failure is expected at runtime rather than
   *   at prepare time
   */
  @Override public void checkFails(SqlTestFactory factory, StringAndPos sap,
      String expectedError, boolean runtime) {
    if (runtime) {
      // We need to test that the expression fails at runtime.
      // Ironically, that means that it must succeed at prepare time.
      final String sql = buildQuery(sap);
      Pair<SqlValidator, SqlNode> p = parseAndValidate(factory, sql);
      SqlNode n = p.right;
      assertNotNull(n);
    } else {
      // Keep the carets when wrapping the expression. buildQuery(sap) uses
      // the caret-free text, so re-wrapping its result would discard the
      // expected error position before it reaches checkQueryFails.
      final StringAndPos sap1 = buildQueryWithPos(sap);
      checkQueryFails(factory, sap1, expectedError);
    }
  }

  /**
   * Checks that a query fails with the expected error.
   *
   * <p>Delegates to {@link #assertExceptionIsThrown}, so the failure may
   * come from either the parser or the validator.
   *
   * @param factory Test factory
   * @param sap Query, with carets marking the expected error position if
   *   there is one
   * @param expectedError Pattern for the expected error
   */
  @Override public void checkQueryFails(SqlTestFactory factory,
      StringAndPos sap, String expectedError) {
    assertExceptionIsThrown(factory, sap, expectedError);
  }

  /**
   * Checks that an aggregate expression fails.
   *
   * <p>The expression is wrapped in a query generated by
   * {@code SqlTests.generateAggQuery}. If {@code runtime} is true, this
   * implementation only checks that the query parses and validates;
   * otherwise the query is passed to {@link #checkQueryFails}.
   *
   * @param factory Test factory
   * @param expr Aggregate expression
   * @param inputValues Input values used to generate the query
   * @param expectedError Pattern for the expected error
   * @param runtime Whether the failure is expected at runtime rather than
   *   at prepare time
   */
  @Override public void checkAggFails(SqlTestFactory factory,
      String expr,
      String[] inputValues,
      String expectedError,
      boolean runtime) {
    final String sql = SqlTests.generateAggQuery(expr, inputValues);
    if (!runtime) {
      checkQueryFails(factory, StringAndPos.of(sql), expectedError);
      return;
    }
    // A runtime failure implies that preparation must succeed.
    final SqlNode validated = parseAndValidate(factory, sql).right;
    assertNotNull(validated);
  }

  /**
   * Builds a legal query from an expression with carets.
   *
   * <p>The expression is converted using {@code StringAndPos.of}, and the
   * result is passed to {@link #buildQuery(StringAndPos)}.
   *
   * @param expressionWithCarets Scalar expression that may contain carets
   * @return Query of the form {@code values (...)}
   *
   * @deprecated Please use {@link AbstractSqlTester#buildQuery(StringAndPos)}
   */
  @Deprecated
  public static String buildQuery(String expressionWithCarets) {
    return buildQuery(StringAndPos.of(expressionWithCarets));
  }

  /**
   * Wraps the SQL text of an expression as {@code values (...)}.
   *
   * <p>Only {@code sap.sql} is used; any position held by {@code sap} is not
   * reflected in the result.
   *
   * @param sap Scalar expression
   * @return Query of the form {@code values (...)}
   */
  public static String buildQuery(StringAndPos sap) {
    return new StringBuilder("values (")
        .append(sap.sql)
        .append(")")
        .toString();
  }

  /**
   * Wraps an expression as {@code values (...)}, keeping its carets.
   *
   * <p>The text is taken from {@code sap.addCarets()} and the wrapped query
   * is converted back using {@code StringAndPos.of}.
   *
   * @param sap Scalar expression
   * @return Query of the form {@code values (...)}
   */
  public static StringAndPos buildQueryWithPos(StringAndPos sap) {
    final String queryWithCarets =
        new StringBuilder("values (")
            .append(sap.addCarets())
            .append(")")
            .toString();
    return StringAndPos.of(queryWithCarets);
  }

  /**
   * Builds a query that evaluates an expression in the select list of a
   * query over a one-row, one-column table {@code t(x)}, grouped by
   * {@code x}.
   *
   * @param expression Expression to place in the select list
   * @return Query text
   */
  public static String buildQueryAgg(String expression) {
    return new StringBuilder("select ")
        .append(expression)
        .append(" from (values (1)) as t(x) group by x")
        .toString();
  }

  /**
   * Builds a query that extracts all literals as columns in an underlying
   * select.
   *
   * <p>For example,
   *
   * <blockquote>{@code 1 < 5}</blockquote>
   *
   * <p>becomes
   *
   * <blockquote>{@code SELECT p0 < p1
   * FROM (VALUES (1, 5)) AS t(p0, p1)}</blockquote>
   *
   * <p>The example shows the shape of the result. Literals are replaced
   * starting from the end of the text while the parameter counter starts at
   * {@code p0}, so parameter numbers run in the opposite direction to the
   * text. (NOTE(review): this assumes {@code Utilities.compare} is an
   * ascending comparison - confirm.)
   *
   * <p>Null literals don't have enough type information to be extracted.
   * We push down {@code CAST(NULL AS type)} but raw nulls such as
   * {@code CASE 1 WHEN 2 THEN 'a' ELSE NULL END} are left as is.
   *
   * <p>Literals are also left in place if they are symbols, if they occur
   * inside a lambda, if they are operands of LITERAL_CHAIN, LOCALTIME,
   * LOCALTIMESTAMP, CURRENT_TIME or CURRENT_TIMESTAMP, or if they are
   * operands of TRIM other than operand 2.
   *
   * <p>If the expression matches the pattern that looks for
   * {@code percentile_cont(}, {@code percentile_disc(}, {@code convert(},
   * {@code sort_array(} or {@code cast(} (ignoring case), no extraction is
   * attempted and the result is that of {@link #buildQuery(String)}.
   *
   * <p>If no literal is extracted, the underlying VALUES clause has a single
   * column {@code p0} with value {@code 1}.
   *
   * @param factory Test factory
   * @param expressionWithCarets Scalar expression that may contain carets
   * @return Query that evaluates a scalar expression
   */
  protected String buildQuery2(SqlTestFactory factory, String expressionWithCarets) {
    if (expressionWithCarets.matches(
        "(?i).*(percentile_(cont|disc)|convert|sort_array|cast)\\(.*")) {
      // PERCENTILE_CONT requires its argument to be a literal,
      // so converting its argument to a column will cause false errors.
      // Similarly, MSSQL-style CONVERT.
      return buildQuery(expressionWithCarets);
    }
    // "values (1 < 5)"
    // becomes
    // "select p0 < p1 from (values (1, 5)) as t(p0, p1)"
    SqlNode x;
    // Expression may contain escaped carets; unescape them.
    final StringAndPos sap = StringAndPos.of(expressionWithCarets);
    final String sql = "values (" + sap.sql + ")";
    try {
      x = parseQuery(factory, sql);
    } catch (SqlParseException e) {
      throw TestUtil.rethrow(e);
    }
    // Nodes to be replaced by parameters. The shuttle below is used only to
    // walk the tree and fill this set; its return value is discarded.
    final Collection<SqlNode> literalSet = new LinkedHashSet<>();
    x.accept(
        new SqlShuttle() {
          // Operators whose calls are left untouched, operands included.
          private final List<SqlOperator> ops =
              ImmutableList.of(
                  SqlStdOperatorTable.LITERAL_CHAIN,
                  SqlStdOperatorTable.LOCALTIME,
                  SqlStdOperatorTable.LOCALTIMESTAMP,
                  SqlStdOperatorTable.CURRENT_TIME,
                  SqlStdOperatorTable.CURRENT_TIMESTAMP);

          @Override public SqlNode visit(SqlLiteral literal) {
            // Collect every literal except NULL literals and symbols.
            if (!isNull(literal)
                && literal.getTypeName() != SqlTypeName.SYMBOL) {
              literalSet.add(literal);
            }
            return literal;
          }

          @Override public @Nullable SqlNode visit(SqlCall call) {
            SqlOperator operator = call.getOperator();
            if (operator.getKind() == SqlKind.LAMBDA) {
              // Do not descend into lambdas; their literals stay in place.
              return call;
            }
            if (operator instanceof SqlUnresolvedFunction) {
              // The parse tree has not been validated, so try to resolve the
              // function against the standard operator table; if found,
              // re-create the call with the resolved operator so that the
              // identity comparisons below can match.
              final SqlUnresolvedFunction unresolvedFunction =
                  (SqlUnresolvedFunction) operator;
              final SqlOperator lookup =
                  SqlValidatorUtil.lookupSqlFunctionByID(
                      SqlStdOperatorTable.instance(),
                      requireNonNull(unresolvedFunction.getSqlIdentifier()),
                      unresolvedFunction.getFunctionType());
              if (lookup != null) {
                operator = lookup;
                call =
                    operator.createCall(call.getFunctionQuantifier(),
                        call.getParserPosition(), call.getOperandList());
              }
            }
            if (operator == SqlStdOperatorTable.CAST
                && isNull(call.operand(0))) {
              // CAST(NULL AS type): extract the whole call, because the bare
              // NULL literal would not be extracted on its own.
              literalSet.add(call);
              return call;
            } else if (operator == SqlStdOperatorTable.TRIM) {
              // see https://issues.apache.org/jira/projects/CALCITE/issues/CALCITE-6780
              // don't extract trimmed literal for TRIM function
              call.operand(2).accept(this);
              return call;
            } else if (ops.contains(operator)) {
              // "Argument to function 'LOCALTIME' must be a
              // literal"
              return call;
            } else {
              return super.visit(call);
            }
          }

          // Returns whether a node is a literal whose type name is NULL.
          private boolean isNull(SqlNode sqlNode) {
            return sqlNode instanceof SqlLiteral
                && ((SqlLiteral) sqlNode).getTypeName()
                == SqlTypeName.NULL;
          }
        });
    // Order the nodes by start position, last first (both comparisons are
    // negated), so that replacing one node does not move the text of the
    // nodes that are still to be replaced.
    final List<SqlNode> nodes = new ArrayList<>(literalSet);
    nodes.sort((o1, o2) -> {
      final SqlParserPos pos0 = o1.getParserPosition();
      final SqlParserPos pos1 = o2.getParserPosition();
      int c = -Utilities.compare(pos0.getLineNum(), pos1.getLineNum());
      if (c != 0) {
        return c;
      }
      return -Utilities.compare(pos0.getColumnNum(), pos1.getColumnNum());
    });
    // sql2 is the text being rewritten. "values" pairs the original text of
    // each extracted node (left) with the parameter name replacing it (right).
    String sql2 = sql;
    final PairList<String, String> values = PairList.of();
    int p = 0;
    for (SqlNode literal : nodes) {
      final SqlParserPos pos = literal.getParserPosition();
      // Offsets are computed against the original text but applied to sql2;
      // that works only because everything replaced so far lies after "end".
      final int start =
          SqlParserUtil.lineColToIndex(
              sql, pos.getLineNum(), pos.getColumnNum());
      // "+ 1" turns the index of the node's last character into an exclusive
      // bound (presumably the end position is inclusive - confirm).
      final int end =
          SqlParserUtil.lineColToIndex(
              sql,
              pos.getEndLineNum(),
              pos.getEndColumnNum()) + 1;
      String param = "p" + p++;
      values.add(sql2.substring(start, end), param);
      sql2 = sql2.substring(0, start)
          + param
          + sql2.substring(end);
    }
    if (values.isEmpty()) {
      // Nothing was extracted; use a dummy column so that the VALUES clause
      // and the alias list are not empty.
      values.add("1", "p0");
    }
    // Strip the leading "values (" and the trailing ")" from the rewritten
    // text to recover the expression, then select it from the VALUES clause.
    return "select "
        + sql2.substring("values (".length(), sql2.length() - 1)
        + " from (values ("
        + Util.commaList(values.leftList())
        + ")) as t("
        + Util.commaList(values.rightList())
        + ")";
  }

  /**
   * Passes to a consumer each query that can be used to evaluate a scalar
   * expression: first the plain {@code values (...)} form, then the form
   * built by {@link #buildQuery2} with its carets escaped.
   *
   * @param factory Test factory
   * @param expressionWithCarets Scalar expression that may contain carets
   * @param consumer Receives each query in turn
   */
  @Override public void forEachQuery(SqlTestFactory factory,
      String expressionWithCarets, Consumer<String> consumer) {
    // Why not return a list? If there is a syntax error in the expression, the
    // consumer will discover it before we try to parse it to do substitutions
    // on the parse tree.
    final String query1 = "values (" + expressionWithCarets + ")";
    consumer.accept(query1);
    final String query2 =
        SqlParserUtil.escapeCarets(buildQuery2(factory, expressionWithCarets));
    consumer.accept(query2);
  }

  /**
   * Checks that a SQL query or a scalar expression converts to the expected
   * plan.
   *
   * @param factory Test factory
   * @param diffRepos Repository that holds the expected plan
   * @param sql Query or expression
   * @param plan Expected plan
   * @param trim Whether to trim unused fields; ignored for expressions
   * @param expression Whether {@code sql} is a scalar expression rather than
   *   a query
   * @param decorrelate Whether to decorrelate; ignored for expressions
   */
  @Override public void assertConvertsTo(SqlTestFactory factory,
      DiffRepository diffRepos,
      String sql,
      String plan,
      boolean trim,
      boolean expression,
      boolean decorrelate) {
    if (!expression) {
      assertSqlConvertsTo(factory, diffRepos, sql, plan, trim, decorrelate);
      return;
    }
    assertExprConvertsTo(factory, diffRepos, sql, plan);
  }

  /**
   * Converts a scalar expression to a {@link RexNode} and compares its
   * string form, surrounded by line separators, with the expected plan.
   *
   * @param factory Test factory
   * @param diffRepos Repository that expands the expression and holds the
   *   expected plan
   * @param expr Expression
   * @param plan Expected plan
   */
  private void assertExprConvertsTo(SqlTestFactory factory,
      DiffRepository diffRepos, String expr, String plan) {
    final String expanded = diffRepos.expand("sql", expr);
    final RexNode rex = convertExprToRex(factory, expanded);
    assertNotNull(rex);
    // NOTE jvs 28-Mar-2006:  insert leading newline so
    // that plans come out nicely stacked instead of first
    // line immediately after CDATA start
    diffRepos.assertEquals("plan", plan, NL + rex + NL);
  }

  /**
   * Converts a query to a relational expression, optionally trims it, and
   * compares its string form with the expected plan.
   *
   * <p>The relational expression is asserted to be valid after conversion
   * and again after trimming.
   *
   * @param factory Test factory
   * @param diffRepos Repository that expands the query and holds the
   *   expected plan
   * @param sql Query
   * @param plan Expected plan
   * @param trim Whether to trim unused fields
   * @param decorrelate Whether to decorrelate
   */
  private void assertSqlConvertsTo(SqlTestFactory factory,
      DiffRepository diffRepos, String sql, String plan,
      boolean trim,
      boolean decorrelate) {
    final String expanded = diffRepos.expand("sql", sql);
    final Pair<SqlValidator, RelRoot> converted =
        convertSqlToRel2(factory, expanded, decorrelate, trim);
    final SqlValidator validator = converted.left;
    RelNode rel = converted.right.project();

    assertNotNull(rel);
    assertThat(rel, relIsValid());

    if (trim) {
      // Trim again, this time with the trimmer from createFieldTrimmer.
      final RelBuilder relBuilder =
          RelFactories.LOGICAL_BUILDER.create(rel.getCluster(), null);
      rel = createFieldTrimmer(validator, relBuilder).trim(rel);
      assertNotNull(rel);
      assertThat(rel, relIsValid());
    }

    // NOTE jvs 28-Mar-2006:  insert leading newline so
    // that plans come out nicely stacked instead of first
    // line immediately after CDATA start
    diffRepos.assertEquals("plan", plan, NL + RelOptUtil.toString(rel));
  }

  /**
   * Parses, validates and converts a scalar expression.
   *
   * <p>Unchecked exceptions and errors from the parser propagate as they
   * are; any other exception is rethrown via {@code TestUtil.rethrow}.
   *
   * @param factory Factory that supplies the parser and the converter
   * @param expr Expression; must not be null
   * @return Converted expression
   */
  private RexNode convertExprToRex(SqlTestFactory factory, String expr) {
    requireNonNull(expr, "expr");
    final SqlNode parsed;
    try {
      parsed = parseExpression(factory, expr);
    } catch (RuntimeException | Error e) {
      throw e;
    } catch (Exception e) {
      throw TestUtil.rethrow(e);
    }

    final SqlToRelConverter converter = factory.createSqlToRelConverter();
    final SqlValidator validator = requireNonNull(converter.validator);
    return converter.convertExpression(validator.validate(parsed));
  }

  /**
   * Parses, validates and converts a query to a relational expression.
   *
   * <p>After conversion, types are flattened if either {@code decorrelate}
   * or {@code trim} is set; then the tree is decorrelated if
   * {@code decorrelate} is set; then unused fields are trimmed if
   * {@code trim} is set.
   *
   * <p>Unchecked exceptions and errors from the parser propagate as they
   * are; any other exception is rethrown via {@code TestUtil.rethrow}.
   *
   * @param factory Factory that supplies the parser and the converter
   * @param sql Query; must not be null
   * @param decorrelate Whether to decorrelate
   * @param trim Whether to trim unused fields
   * @return The converter's validator, paired with the converted root
   */
  @Override public Pair<SqlValidator, RelRoot> convertSqlToRel2(
      SqlTestFactory factory, String sql, boolean decorrelate,
      boolean trim) {
    requireNonNull(sql, "sql");
    final SqlNode parsed;
    try {
      parsed = parseQuery(factory, sql);
    } catch (RuntimeException | Error e) {
      throw e;
    } catch (Exception e) {
      throw TestUtil.rethrow(e);
    }
    final SqlToRelConverter converter = factory.createSqlToRelConverter();
    final SqlValidator validator = requireNonNull(converter.validator);

    final SqlNode validated = validator.validate(parsed);
    RelRoot root =
        requireNonNull(converter.convertQuery(validated, false, true), "root");
    final boolean flatten = decorrelate || trim;
    if (flatten) {
      root = root.withRel(converter.flattenTypes(root.rel, true));
    }
    if (decorrelate) {
      // Decorrelation is given the parsed query, not the validated one.
      root = root.withRel(converter.decorrelate(parsed, root.rel));
    }
    if (trim) {
      root = root.withRel(converter.trimUnusedFields(true, root.rel));
    }
    return Pair.of(validator, root);
  }

  /**
   * Flattens the types of a relational expression and then trims its unused
   * fields, using a converter obtained from the factory.
   *
   * @param factory Factory that supplies the converter
   * @param relNode Relational expression
   * @return Flattened and trimmed relational expression
   */
  @Override public RelNode trimRelNode(SqlTestFactory factory,
      RelNode relNode) {
    final SqlToRelConverter converter = factory.createSqlToRelConverter();
    return converter.trimUnusedFields(true,
        converter.flattenTypes(relNode, true));
  }

  /**
   * Creates a RelFieldTrimmer.
   *
   * <p>This class calls the method when a query is converted with trimming
   * enabled. The method is public and not final, so a subclass can supply a
   * different trimmer.
   *
   * @param validator Validator
   * @param relBuilder Builder
   * @return Field trimmer
   */
  public RelFieldTrimmer createFieldTrimmer(SqlValidator validator,
      RelBuilder relBuilder) {
    return new RelFieldTrimmer(validator, relBuilder);
  }
}