SqlCastFunction.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.sql.fun;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeFamily;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlDynamicParam;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.SqlWriter;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.InferTypes;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeMappingRule;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlMonotonicity;
import org.apache.calcite.sql.validate.SqlValidator;

import com.google.common.collect.ImmutableSetMultimap;
import com.google.common.collect.SetMultimap;

import java.text.Collator;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import static com.google.common.base.Preconditions.checkArgument;

import static org.apache.calcite.sql.type.SqlTypeUtil.isArray;
import static org.apache.calcite.sql.type.SqlTypeUtil.isCollection;
import static org.apache.calcite.sql.type.SqlTypeUtil.isMap;
import static org.apache.calcite.sql.type.SqlTypeUtil.isRow;
import static org.apache.calcite.util.Static.RESOURCE;

import static java.util.Objects.requireNonNull;

/**
 * SqlCastFunction. Note that the standard functions are really singleton
 * objects, because they are always fetched via the standard operator table.
 * So you can't store any local info in the class, and hence the return type
 * data is maintained in operand[1] throughout the validation phase.
 *
 * <p>Can be used for both {@link SqlCall} and
 * {@link org.apache.calcite.rex.RexCall}.
 * Note that the {@code SqlCall} has two operands (expression and type),
 * while the {@code RexCall} has one operand (expression) and the type is
 * obtained from {@link org.apache.calcite.rex.RexNode#getType()}.
 *
 * <p>This class serves both {@link SqlKind#CAST} and
 * {@link SqlKind#SAFE_CAST}; the constructor rejects any other kind.
 *
 * @see SqlCastOperator
 */
public class SqlCastFunction extends SqlFunction {
  //~ Instance fields --------------------------------------------------------

  /** Map of all casts that do not preserve monotonicity, keyed by the
   * family of the source type; each value is a target family. */
  private final SetMultimap<SqlTypeFamily, SqlTypeFamily> nonMonotonicCasts =
      ImmutableSetMultimap.<SqlTypeFamily, SqlTypeFamily>builder()
          .put(SqlTypeFamily.EXACT_NUMERIC, SqlTypeFamily.CHARACTER)
          .put(SqlTypeFamily.NUMERIC, SqlTypeFamily.CHARACTER)
          .put(SqlTypeFamily.APPROXIMATE_NUMERIC, SqlTypeFamily.CHARACTER)
          .put(SqlTypeFamily.DATETIME_INTERVAL, SqlTypeFamily.CHARACTER)
          .put(SqlTypeFamily.CHARACTER, SqlTypeFamily.EXACT_NUMERIC)
          .put(SqlTypeFamily.CHARACTER, SqlTypeFamily.NUMERIC)
          .put(SqlTypeFamily.CHARACTER, SqlTypeFamily.APPROXIMATE_NUMERIC)
          .put(SqlTypeFamily.CHARACTER, SqlTypeFamily.DATETIME_INTERVAL)
          .put(SqlTypeFamily.DATETIME, SqlTypeFamily.TIME)
          .put(SqlTypeFamily.TIMESTAMP, SqlTypeFamily.TIME)
          .put(SqlTypeFamily.TIME, SqlTypeFamily.DATETIME)
          .put(SqlTypeFamily.TIME, SqlTypeFamily.TIMESTAMP)
          .build();

  //~ Constructors -----------------------------------------------------------

  /** Creates the standard {@code CAST} function. */
  public SqlCastFunction() {
    this(SqlKind.CAST.toString(), SqlKind.CAST);
  }

  /**
   * Creates a cast function with the given name and kind.
   *
   * @param name name of the function
   * @param kind either {@link SqlKind#CAST} or {@link SqlKind#SAFE_CAST}
   * @throws IllegalArgumentException if {@code kind} is neither of those
   */
  public SqlCastFunction(String name, SqlKind kind) {
    super(name, kind, returnTypeInference(kind == SqlKind.SAFE_CAST),
        InferTypes.FIRST_KNOWN, null, SqlFunctionCategory.SYSTEM);
    checkArgument(kind == SqlKind.CAST || kind == SqlKind.SAFE_CAST, kind);
  }

  //~ Methods ----------------------------------------------------------------

  /**
   * Returns the return-type inference strategy for a cast.
   *
   * <p>The result type is derived from the type of operand 0 (the
   * expression) and operand 1 (the target type) via
   * {@link #deriveType}. When invoked during SQL validation, it also assigns
   * the derived type to operand 0 if that operand is a NULL literal or a
   * dynamic parameter, since those have no type of their own.
   *
   * @param safe whether the cast is a safe cast
   * @return return-type inference strategy
   */
  static SqlReturnTypeInference returnTypeInference(boolean safe) {
    return opBinding -> {
      assert opBinding.getOperandCount() <= 3;
      final RelDataType ret =
          deriveType(opBinding.getTypeFactory(), opBinding.getOperandType(0),
              opBinding.getOperandType(1), safe);

      if (opBinding instanceof SqlCallBinding) {
        final SqlCallBinding callBinding = (SqlCallBinding) opBinding;
        SqlNode operand0 = callBinding.operand(0);

        // dynamic parameters and null constants need their types assigned
        // to them using the type they are cast to.
        if (SqlUtil.isNullLiteral(operand0, false)
            || operand0 instanceof SqlDynamicParam) {
          callBinding.getValidator().setValidatedNodeType(operand0, ret);
        }
      }
      return ret;
    };
  }

  /**
   * Derives the type of "CAST(expression AS targetType)".
   *
   * @param typeFactory type factory
   * @param expressionType type of the expression being cast
   * @param targetType type being cast to
   * @param safe whether the cast is safe; if true, the result is made
   *             nullable (except when the target type is VARIANT, whose
   *             nullability follows the source)
   * @return the target type with nullability derived from the expression
   */
  public static RelDataType deriveType(RelDataTypeFactory typeFactory,
      RelDataType expressionType, RelDataType targetType, boolean safe) {
    return createTypeWithNullabilityFromExpr(typeFactory, expressionType, targetType, safe);
  }

  /**
   * Recursively builds {@code targetType} with nullability taken from
   * {@code expressionType}. For collections, rows and maps, the nullability
   * of each component (element, field, key, value) is derived from the
   * corresponding component of the expression type.
   */
  private static RelDataType createTypeWithNullabilityFromExpr(RelDataTypeFactory typeFactory,
      RelDataType expressionType, RelDataType targetType, boolean safe) {
    boolean isNullable = expressionType.isNullable() || safe;

    if (targetType.getSqlTypeName() == SqlTypeName.VARIANT) {
      // A variant can be cast from any other type, and it inherits
      // the nullability of the source.
      // Note that the order of this test and the next one is important.
      return typeFactory.createTypeWithNullability(targetType, expressionType.isNullable());
    }

    if (expressionType.getSqlTypeName() == SqlTypeName.VARIANT) {
      // A variant can be cast to any other type, but the result
      // is always nullable, like in the case of a safe cast.
      return typeFactory.createTypeWithNullability(targetType, true);
    }

    if (isCollection(expressionType)) {
      RelDataType expressionElementType = expressionType.getComponentType();
      RelDataType targetElementType = targetType.getComponentType();
      requireNonNull(expressionElementType, () -> "componentType of " + expressionType);
      requireNonNull(targetElementType, () -> "componentType of " + targetType);
      RelDataType newElementType =
          createTypeWithNullabilityFromExpr(
              typeFactory, expressionElementType, targetElementType, safe);
      return isArray(targetType)
          ? SqlTypeUtil.createArrayType(typeFactory, newElementType, isNullable)
          : SqlTypeUtil.createMultisetType(typeFactory, newElementType, isNullable);
    }

    if (isRow(expressionType)) {
      // Fields are matched by position; the target's field names are kept.
      final int fieldCount = expressionType.getFieldCount();
      final List<RelDataType> typeList = new ArrayList<>(fieldCount);
      for (int i = 0; i < fieldCount; ++i) {
        RelDataType expressionElementType = expressionType.getFieldList().get(i).getType();
        RelDataType targetElementType = targetType.getFieldList().get(i).getType();
        typeList.add(
            createTypeWithNullabilityFromExpr(typeFactory, expressionElementType,
                targetElementType, safe));
      }
      return typeFactory.createTypeWithNullability(
          typeFactory.createStructType(
              typeList,
              targetType.getFieldNames()), isNullable);
    }

    if (isMap(expressionType)) {
      RelDataType expressionKeyType =
          requireNonNull(expressionType.getKeyType(), () -> "keyType of " + expressionType);
      RelDataType expressionValueType =
          requireNonNull(expressionType.getValueType(), () -> "valueType of " + expressionType);
      RelDataType targetKeyType =
          requireNonNull(targetType.getKeyType(), () -> "keyType of " + targetType);
      RelDataType targetValueType =
          requireNonNull(targetType.getValueType(), () -> "valueType of " + targetType);

      RelDataType keyType =
          createTypeWithNullabilityFromExpr(
              typeFactory, expressionKeyType, targetKeyType, safe);
      RelDataType valueType =
          createTypeWithNullabilityFromExpr(
              typeFactory, expressionValueType, targetValueType, safe);
      // Return the map built from the derived key/value types; previously the
      // result was discarded and the per-component nullability was lost.
      return SqlTypeUtil.createMapType(typeFactory, keyType, valueType, isNullable);
    }

    return typeFactory.createTypeWithNullability(targetType, isNullable);
  }

  /**
   * {@inheritDoc}
   *
   * <p>Placeholders: {@code {0}} is the function name, {@code {1}} the
   * expression, {@code {2}} the target type and {@code {3}} the optional
   * format string.
   */
  @Override public String getSignatureTemplate(final int operandsCount) {
    assert operandsCount <= 3;
    return "{0}({1} AS {2} [FORMAT {3}])";
  }

  /** Accepts 2 operands (expression, type) or 3 (expression, type, format). */
  @Override public SqlOperandCountRange getOperandCountRange() {
    return SqlOperandCountRanges.between(2, 3);
  }

  /**
   * Checks that the expression (operand 0) can be cast to the target type
   * (operand 1) under the validator's type mapping rule, that their
   * character sets are compatible, and that the optional format (operand 2),
   * if present and not NULL, is a character string.
   *
   * <p>A NULL literal or dynamic parameter as operand 0 is always accepted;
   * its type is assigned from the target type during return-type inference.
   *
   * @param callBinding description of the call
   * @param throwOnFailure whether to throw a validation error on a cast or
   *                       character-set mismatch
   * @return whether the operands are valid
   */
  @Override public boolean checkOperandTypes(
      SqlCallBinding callBinding,
      boolean throwOnFailure) {
    final SqlNode left = callBinding.operand(0);
    final SqlNode right = callBinding.operand(1);
    // Without a FORMAT clause, use a NULL literal as a placeholder so the
    // format check at the end trivially succeeds.
    final SqlLiteral format = callBinding.getOperandCount() > 2
        ? (SqlLiteral) callBinding.operand(2) : SqlLiteral.createNull(SqlParserPos.ZERO);

    if (SqlUtil.isNullLiteral(left, false)
        || left instanceof SqlDynamicParam) {
      return true;
    }
    final SqlValidator validator = callBinding.getValidator();
    final RelDataType validatedNodeType =
        validator.getValidatedNodeType(left);
    final RelDataType returnType = SqlTypeUtil.deriveType(callBinding, right);
    final SqlTypeMappingRule mappingRule = validator.getTypeMappingRule();

    if (!SqlTypeUtil.canCastFrom(returnType, validatedNodeType, mappingRule)) {
      if (throwOnFailure) {
        throw callBinding.newError(
            RESOURCE.cannotCastValue(validatedNodeType.getFullTypeString(),
                returnType.getFullTypeString()));
      }
      return false;
    }
    if (SqlTypeUtil.areCharacterSetsMismatched(
        validatedNodeType,
        returnType)) {
      if (throwOnFailure) {
        // Include full type string to indicate character
        // set mismatch.
        throw callBinding.newError(
            RESOURCE.cannotCastValue(validatedNodeType.getFullTypeString(),
                returnType.getFullTypeString()));
      }
      return false;
    }

    // Validate format argument is string type if included.
    // NOTE(review): unlike the checks above, a non-string format returns
    // false without throwing even when throwOnFailure is true — confirm
    // whether callers rely on this.
    return SqlUtil.isNullLiteral(format, false)
        || SqlLiteral.valueMatchesType(format.getValue(), SqlTypeName.CHAR);
  }

  @Override public SqlSyntax getSyntax() {
    return SqlSyntax.SPECIAL;
  }

  /**
   * Unparses as {@code NAME(expr AS type [FORMAT fmt])}; an interval
   * qualifier target type is preceded by the {@code INTERVAL} keyword.
   */
  @Override public void unparse(
      SqlWriter writer,
      SqlCall call,
      int leftPrec,
      int rightPrec) {
    assert call.operandCount() <= 3;
    final SqlWriter.Frame frame = writer.startFunCall(getName());
    call.operand(0).unparse(writer, 0, 0);
    writer.sep("AS");
    if (call.operand(1) instanceof SqlIntervalQualifier) {
      writer.sep("INTERVAL");
    }
    call.operand(1).unparse(writer, 0, 0);
    if (call.getOperandList().size() > 2) {
      writer.sep("FORMAT");
      call.operand(2).unparse(writer, 0, 0);
    }
    writer.endFunCall(frame);
  }

  /**
   * Returns the monotonicity of the cast. The cast is not monotonic if the
   * source and target types use different collators, or if the pair of type
   * families is listed in {@code nonMonotonicCasts}; otherwise it has the
   * monotonicity of the expression being cast.
   */
  @Override public SqlMonotonicity getMonotonicity(SqlOperatorBinding call) {
    final RelDataType castFromType = call.getOperandType(0);
    final RelDataTypeFamily castFromFamily = castFromType.getFamily();
    final Collator castFromCollator = castFromType.getCollation() == null
        ? null
        : castFromType.getCollation().getCollator();
    final RelDataType castToType = call.getOperandType(1);
    final RelDataTypeFamily castToFamily = castToType.getFamily();
    final Collator castToCollator = castToType.getCollation() == null
        ? null
        : castToType.getCollation().getCollator();
    if (!Objects.equals(castFromCollator, castToCollator)) {
      // Cast between types compared with different collators: not monotonic.
      return SqlMonotonicity.NOT_MONOTONIC;
    } else if (castFromFamily instanceof SqlTypeFamily
        && castToFamily instanceof SqlTypeFamily
        && nonMonotonicCasts.containsEntry(castFromFamily, castToFamily)) {
      return SqlMonotonicity.NOT_MONOTONIC;
    } else {
      return call.getOperandMonotonicity(0);
    }
  }
}