VariantNonNull.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.runtime.variant;

import org.apache.calcite.linq4j.tree.Primitive;
import org.apache.calcite.linq4j.tree.UnsignedType;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.runtime.rtti.BasicSqlTypeRtti;
import org.apache.calcite.runtime.rtti.RowSqlTypeRtti;
import org.apache.calcite.runtime.rtti.RuntimeTypeInformation;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.joou.UByte;
import org.joou.UInteger;
import org.joou.ULong;
import org.joou.UShort;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.UUID;

import static org.apache.calcite.runtime.rtti.RuntimeTypeInformation.RuntimeSqlTypeName.NAME;

import static java.util.Objects.requireNonNull;

/** A VARIANT value that contains a non-null value. */
public class VariantNonNull extends VariantSqlValue {
  final RoundingMode roundingMode;
  /** Actual value - can have any SQL type. */
  final Object value;

  /** Creates a non-null VARIANT value.
   *
   * <p>Scalar values are stored as supplied (after an assertion on their Java
   * class for the types that have one). Collection values (MAP, ROW, ARRAY,
   * MULTISET) are converted recursively, so that every key, field and element
   * is itself wrapped through {@code VariantSqlValue.create}.
   *
   * @param roundingMode rounding mode kept for later numeric casts and passed
   *                     on to nested values
   * @param value        the value to wrap
   * @param runtimeType  runtime type information describing {@code value}
   */
  VariantNonNull(RoundingMode roundingMode, Object value, RuntimeTypeInformation runtimeType) {
    super(runtimeType.getTypeName());
    this.roundingMode = roundingMode;
    // Sanity-check the Java representation and compute the object to store.
    final Object stored;
    switch (runtimeType.getTypeName()) {
    case UUID:
      assert value instanceof UUID;
      stored = value;
      break;
    case NAME:
      assert value instanceof String;
      stored = value;
      break;
    case BOOLEAN:
      assert value instanceof Boolean;
      stored = value;
      break;
    case TINYINT:
      assert value instanceof Byte;
      stored = value;
      break;
    case UTINYINT:
      assert value instanceof UByte;
      stored = value;
      break;
    case SMALLINT:
      assert value instanceof Short;
      stored = value;
      break;
    case USMALLINT:
      assert value instanceof UShort;
      stored = value;
      break;
    case INTEGER:
      assert value instanceof Integer;
      stored = value;
      break;
    case UINTEGER:
      assert value instanceof UInteger;
      stored = value;
      break;
    case BIGINT:
      assert value instanceof Long;
      stored = value;
      break;
    case UBIGINT:
      assert value instanceof ULong;
      stored = value;
      break;
    case DECIMAL:
      assert value instanceof BigDecimal;
      stored = value;
      break;
    case REAL:
      assert value instanceof Float;
      stored = value;
      break;
    case DOUBLE:
      assert value instanceof Double;
      stored = value;
      break;
    case VARCHAR:
      assert value instanceof String;
      stored = value;
      break;
    // The following types are stored without any check on the Java class.
    case DATE:
    case TIME:
    case TIME_WITH_LOCAL_TIME_ZONE:
    case TIME_TZ:
    case TIMESTAMP:
    case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
    case TIMESTAMP_TZ:
    case INTERVAL_LONG:
    case INTERVAL_SHORT:
    case VARBINARY:
    case GEOMETRY:
    case VARIANT:
      stored = value;
      break;
    case MAP: {
      // Wrap every key and every value, preserving the iteration order.
      RuntimeTypeInformation keyRtti = runtimeType.asGeneric().getTypeArgument(0);
      RuntimeTypeInformation valueRtti = runtimeType.asGeneric().getTypeArgument(1);
      assert value instanceof Map<?, ?>;
      Map<?, ?> source = (Map<?, ?>) value;
      LinkedHashMap<VariantValue, VariantValue> entries = new LinkedHashMap<>(source.size());
      for (Map.Entry<?, ?> entry : source.entrySet()) {
        entries.put(
            VariantSqlValue.create(roundingMode, entry.getKey(), keyRtti),
            VariantSqlValue.create(roundingMode, entry.getValue(), valueRtti));
      }
      stored = entries;
      break;
    }
    case ROW: {
      // A row is stored as an ordered map from field name (of type NAME)
      // to the wrapped field value.
      assert value instanceof Object[];
      Object[] fields = (Object[]) value;
      assert runtimeType instanceof RowSqlTypeRtti;
      RowSqlTypeRtti rowRtti = (RowSqlTypeRtti) runtimeType;
      LinkedHashMap<VariantValue, VariantValue> entries = new LinkedHashMap<>(fields.length);
      RuntimeTypeInformation nameRtti = new BasicSqlTypeRtti(NAME);
      for (int index = 0; index < fields.length; index++) {
        Map.Entry<String, RuntimeTypeInformation> field = rowRtti.getField(index);
        entries.put(
            VariantSqlValue.create(roundingMode, field.getKey(), nameRtti),
            VariantSqlValue.create(roundingMode, fields[index], field.getValue()));
      }
      stored = entries;
      break;
    }
    case MULTISET:
    case ARRAY: {
      // Wrap every element, preserving the order.
      RuntimeTypeInformation elementRtti = runtimeType.asGeneric().getTypeArgument(0);
      assert value instanceof List<?>;
      List<?> source = (List<?>) value;
      List<VariantValue> elements = new ArrayList<>(source.size());
      for (Object item : source) {
        elements.add(VariantSqlValue.create(roundingMode, item, elementRtti));
      }
      stored = elements;
      break;
    }
    case NULL:
    default:
      throw new RuntimeException("Unreachable");
    }
    this.value = stored;
  }

  /** Two non-null variants are equal when they have exactly the same class,
   * equal wrapped values and the same runtime type. */
  @Override public boolean equals(@Nullable Object o) {
    if (o == this) {
      return true;
    }
    if (o == null) {
      return false;
    }
    if (o.getClass() != getClass()) {
      return false;
    }
    final VariantNonNull that = (VariantNonNull) o;
    if (!Objects.equals(this.value, that.value)) {
      return false;
    }
    return this.runtimeType == that.runtimeType;
  }

  /** Hash code combining the wrapped value and the runtime type, consistent
   * with {@link #equals(Object)}. */
  @Override public int hashCode() {
    final int valueHash = Objects.hashCode(value);
    final int typeHash = runtimeType.hashCode();
    return 31 * valueHash + typeHash;
  }

  /** Casts this value to the specified type.
   *
   * <p>The rules currently implemented are:
   * <ul>
   * <li>A scalar value whose runtime type is exactly the target type is
   *   returned unchanged.
   * <li>A numeric scalar (TINYINT, SMALLINT, INTEGER, BIGINT, DECIMAL, REAL,
   *   DOUBLE and the unsigned integer types) is converted to any other
   *   numeric target type; conversions to signed integer and floating point
   *   targets use this value's rounding mode.
   * <li>An ARRAY is converted to an ARRAY target by casting every element to
   *   the target element type.
   * <li>A MAP is converted to a MAP target by casting every key and value, or
   *   to a ROW target by looking up each field name of the row in the map;
   *   fields that are not found become null.
   * <li>In every other case a SQL NULL (Java {@code null}) is returned.
   * </ul>
   *
   * <p>Conversions that go through {@code longValueExact} (DECIMAL and UBIGINT
   * to smaller unsigned types) throw {@link ArithmeticException} when the
   * value cannot be represented exactly as a {@code long}.
   *
   * @param type the target type
   * @return the converted value, or null if the conversion is not supported
   */
  // This method is invoked from {@link RexToLixTranslator} VARIANT_CAST
  @Override public @Nullable Object cast(RuntimeTypeInformation type) {
    if (this.runtimeType.isScalar()) {
      if (this.runtimeType == type.getTypeName()) {
        return this.value;
      } else {
        // Convert numeric values
        @Nullable Primitive target = type.asPrimitive();
        switch (this.runtimeType) {
        case TINYINT: {
          byte b = (byte) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(b, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(b);
          case UTINYINT:
            return UnsignedType.toUByte(b);
          case USMALLINT:
            return UnsignedType.toUShort(b);
          case UINTEGER:
            return UnsignedType.toUInteger(b);
          case UBIGINT:
            return UnsignedType.toULong(b);
          default:
            break;
          }
          break;
        }
        case SMALLINT: {
          short s = (short) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(s, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(s);
          case UTINYINT:
            return UnsignedType.toUByte(s);
          case USMALLINT:
            return UnsignedType.toUShort(s);
          case UINTEGER:
            return UnsignedType.toUInteger(s);
          case UBIGINT:
            return UnsignedType.toULong(s);
          default:
            break;
          }
          break;
        }
        case INTEGER: {
          int i = (int) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(i, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(i);
          case UTINYINT:
            return UnsignedType.toUByte(i);
          case USMALLINT:
            return UnsignedType.toUShort(i);
          case UINTEGER:
            return UnsignedType.toUInteger(i);
          case UBIGINT:
            return UnsignedType.toULong(i);
          default:
            break;
          }
          break;
        }
        case BIGINT: {
          // BIGINT values are stored as java.lang.Long; unboxing through
          // (int) would attempt a cast to Integer and fail.
          long l = (long) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(l, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(l);
          case UTINYINT:
            return UnsignedType.toUByte(l);
          case USMALLINT:
            return UnsignedType.toUShort(l);
          case UINTEGER:
            return UnsignedType.toUInteger(l);
          case UBIGINT:
            return UnsignedType.toULong(l);
          default:
            break;
          }
          break;
        }
        case DECIMAL: {
          BigDecimal d = (BigDecimal) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(d, roundingMode);
          case DECIMAL:
            return d;
          case UTINYINT:
            return UnsignedType.toUByte(d.longValueExact());
          case USMALLINT:
            return UnsignedType.toUShort(d.longValueExact());
          case UINTEGER:
            return UnsignedType.toUInteger(d.longValueExact());
          case UBIGINT:
            return UnsignedType.toULong(d.longValueExact());
          default:
            break;
          }
          break;
        }
        case REAL: {
          float f = (float) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(f, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(f);
          case UTINYINT:
            return UnsignedType.toUByte(f);
          case USMALLINT:
            return UnsignedType.toUShort(f);
          case UINTEGER:
            return UnsignedType.toUInteger(f);
          case UBIGINT:
            return UnsignedType.toULong(f);
          default:
            break;
          }
          break;
        }
        case DOUBLE: {
          double d = (double) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(d, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(d);
          case UTINYINT:
            return UnsignedType.toUByte(d);
          case USMALLINT:
            return UnsignedType.toUShort(d);
          case UINTEGER:
            return UnsignedType.toUInteger(d);
          case UBIGINT:
            return UnsignedType.toULong(d);
          default:
            break;
          }
          break;
        }
        case UTINYINT: {
          UByte b = (UByte) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(b, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(b.intValue());
          case UTINYINT:
            return UnsignedType.toUByte(b.intValue());
          case USMALLINT:
            return UnsignedType.toUShort(b.intValue());
          case UINTEGER:
            return UnsignedType.toUInteger(b.intValue());
          case UBIGINT:
            return UnsignedType.toULong(b.intValue());
          default:
            break;
          }
          break;
        }
        case USMALLINT: {
          UShort s = (UShort) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(s, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(s.intValue());
          case UTINYINT:
            return UnsignedType.toUByte(s.intValue());
          case USMALLINT:
            return UnsignedType.toUShort(s.intValue());
          case UINTEGER:
            return UnsignedType.toUInteger(s.intValue());
          case UBIGINT:
            return UnsignedType.toULong(s.intValue());
          default:
            break;
          }
          break;
        }
        case UINTEGER: {
          UInteger b = (UInteger) value;
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(b, roundingMode);
          case DECIMAL:
            return BigDecimal.valueOf(b.longValue());
          case UTINYINT:
            return UnsignedType.toUByte(b.longValue());
          case USMALLINT:
            return UnsignedType.toUShort(b.longValue());
          case UINTEGER:
            return UnsignedType.toUInteger(b.longValue());
          case UBIGINT:
            return UnsignedType.toULong(b.longValue());
          default:
            break;
          }
          break;
        }
        case UBIGINT: {
          // Go through BigInteger so that values above Long.MAX_VALUE are
          // not misread as negative numbers.
          BigInteger i = UnsignedType.toBigInteger((ULong) value);
          switch (type.getTypeName()) {
          case TINYINT:
          case SMALLINT:
          case INTEGER:
          case BIGINT:
          case REAL:
          case DOUBLE:
            return requireNonNull(target, "target").numberValue(i, roundingMode);
          case DECIMAL:
            return new BigDecimal(i);
          case UTINYINT:
            return UnsignedType.toUByte(i.longValueExact());
          case USMALLINT:
            return UnsignedType.toUShort(i.longValueExact());
          case UINTEGER:
            return UnsignedType.toUInteger(i.longValueExact());
          case UBIGINT:
            return value;
          default:
            break;
          }
          break;
        }
        default:
          break;
        }
        // Non-numeric scalar, or unsupported target type
        return null;
      }
    } else {
      switch (this.runtimeType) {
      case ARRAY:
        if (type.getTypeName() == RuntimeTypeInformation.RuntimeSqlTypeName.ARRAY) {
          // Convert array to array: cast the elements recursively
          RuntimeTypeInformation elementType = type.asGeneric().getTypeArgument(0);
          assert value instanceof List;
          List<VariantSqlValue> list = (List<VariantSqlValue>) value;
          List<@Nullable Object> result = new ArrayList<>(list.size());
          for (VariantSqlValue o : list) {
            @Nullable Object converted = o.cast(elementType);
            result.add(converted);
          }
          return result;
        }
        break;
      case MAP:
        assert value instanceof Map;
        Map<VariantSqlValue, VariantSqlValue> map = (Map<VariantSqlValue, VariantSqlValue>) value;
        if (type.getTypeName() == RuntimeTypeInformation.RuntimeSqlTypeName.MAP) {
          // Convert map to map: cast keys and values recursively.
          // Type argument 0 is the key type, type argument 1 is the value type.
          RuntimeTypeInformation keyType = type.asGeneric().getTypeArgument(0);
          RuntimeTypeInformation valueType = type.asGeneric().getTypeArgument(1);
          LinkedHashMap<@Nullable Object, @Nullable Object> result =
              new LinkedHashMap<>(map.size());
          for (Map.Entry<VariantSqlValue, VariantSqlValue> e : map.entrySet()) {
            @Nullable Object key = e.getKey().cast(keyType);
            @Nullable Object val = e.getValue().cast(valueType);
            result.put(key, val);
          }
          return result;
        } else if (type.getTypeName() == RuntimeTypeInformation.RuntimeSqlTypeName.ROW) {
          // Convert map to row: lookup the row's fields in the map
          RowSqlTypeRtti rowType = (RowSqlTypeRtti) type;
          @Nullable Object [] result = new Object[rowType.size()];
          for (int i = 0; i < rowType.size(); i++) {
            Map.Entry<String, RuntimeTypeInformation> field = rowType.getField(i);
            Object fieldValue = null;
            VariantValue v = this.item(field.getKey());
            if (v != null) {
              fieldValue = v.cast(field.getValue());
            }
            result[i] = fieldValue;
          }
          return result;
        }
        break;
      default:
        break;
      }
    }
    return null;
  }

  // Implementation of the array index operator for VARIANT values.
  // Returns null when this value cannot be indexed, when the index has an
  // unsupported Java type, or when the lookup does not produce a VARIANT.
  @Override public @Nullable VariantValue item(Object index) {
    final boolean integerIndex = index instanceof Integer;
    // Key actually used for the lookup; a VARIANT index is left unchanged.
    Object lookupKey = index;
    switch (this.runtimeType) {
    case ROW:
      // Row fields are keyed by VARIANT values of type NAME
      if (index instanceof String) {
        lookupKey =
            VariantSqlValue.create(roundingMode, index,
                new BasicSqlTypeRtti(RuntimeTypeInformation.RuntimeSqlTypeName.NAME));
      }
      break;
    case MAP:
      // Map keys are VARIANT values; wrap plain strings and integers
      if (index instanceof String) {
        lookupKey =
            VariantSqlValue.create(roundingMode, index,
                new BasicSqlTypeRtti(RuntimeTypeInformation.RuntimeSqlTypeName.VARCHAR));
      } else if (integerIndex) {
        lookupKey =
            VariantSqlValue.create(roundingMode, index,
                new BasicSqlTypeRtti(RuntimeTypeInformation.RuntimeSqlTypeName.INTEGER));
      }
      break;
    case ARRAY:
      // Arrays only support integer indexes
      if (!integerIndex) {
        return null;
      }
      break;
    default:
      // No other type can be indexed
      return null;
    }

    final Object found = SqlFunctions.itemOptional(this.value, lookupKey);
    // Only VARIANT results are returned; a null or any other object yields null
    return found instanceof VariantValue ? (VariantValue) found : null;
  }

  // This method is called by the testing code.
  // Rows are printed as "{v1, v2, ...}" without field names; time, timestamp,
  // interval, VARCHAR and VARBINARY values are wrapped in double quotes; all
  // other values are printed using their own string representation.
  @Override public String toString() {
    final boolean isRow = this.runtimeType == RuntimeTypeInformation.RuntimeSqlTypeName.ROW;
    if (isRow && value instanceof Map<?, ?>) {
      // Do not print field names, only their values
      final StringBuilder text = new StringBuilder("{");
      String separator = "";
      for (Map.Entry<?, ?> field : ((Map<?, ?>) value).entrySet()) {
        text.append(separator);
        separator = ", ";
        final Object fieldValue = field.getValue();
        // A field value is expected to always be present
        if (fieldValue != null) {
          text.append(fieldValue);
        }
      }
      return text.append("}").toString();
    }

    final String quote;
    switch (this.runtimeType) {
    case TIME:
    case TIME_WITH_LOCAL_TIME_ZONE:
    case TIME_TZ:
    case TIMESTAMP:
    case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
    case TIMESTAMP_TZ:
    case INTERVAL_LONG:
    case INTERVAL_SHORT:
    case VARCHAR:
    case VARBINARY:
      // At least in Snowflake VARIANT values that are strings
      // are printed with double quotes
      // https://docs.snowflake.com/en/sql-reference/data-types-semistructured
      quote = "\"";
      break;
    default:
      quote = "";
      break;
    }
    return quote + value + quote;
  }
}