RelMdFunctionalDependency.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.metadata;

import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Calc;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.SetOp;
import org.apache.calcite.rel.core.TableScan;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.List;

/**
 * Default implementation of
 * {@link RelMetadataQuery#determines(RelNode, int, int)}
 * for the standard logical algebra.
 *
 * <p>The goal of this provider is to determine whether column {@code column}
 * is functionally dependent on column {@code key}, that is, whether
 * {@code key} determines {@code column}: any two rows that agree on
 * {@code key} also agree on {@code column}.
 *
 * <p>If the functional dependency cannot be determined, we return false.
 */
public class RelMdFunctionalDependency
    implements MetadataHandler<BuiltInMetadata.FunctionalDependency> {
  public static final RelMetadataProvider SOURCE =
      ReflectiveRelMetadataProvider.reflectiveSource(
          new RelMdFunctionalDependency(), BuiltInMetadata.FunctionalDependency.Handler.class);

  //~ Constructors -----------------------------------------------------------

  protected RelMdFunctionalDependency() {}

  //~ Methods ----------------------------------------------------------------

  @Override public MetadataDef<BuiltInMetadata.FunctionalDependency> getDef() {
    return BuiltInMetadata.FunctionalDependency.DEF;
  }

  // Handlers for relational expressions whose output columns are input
  // columns at known positions (see determinesImpl2).

  public @Nullable Boolean determines(RelNode rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl2(rel, mq, key, column);
  }

  public @Nullable Boolean determines(SetOp rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl2(rel, mq, key, column);
  }

  public @Nullable Boolean determines(Join rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl2(rel, mq, key, column);
  }

  public @Nullable Boolean determines(Correlate rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl2(rel, mq, key, column);
  }

  // Handlers for relational expressions that compute their output columns
  // from input expressions (see determinesImpl).

  public @Nullable Boolean determines(Aggregate rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl(rel, mq, key, column);
  }

  public @Nullable Boolean determines(Calc rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl(rel, mq, key, column);
  }

  public @Nullable Boolean determines(Project rel, RelMetadataQuery mq,
      int key, int column) {
    return determinesImpl(rel, mq, key, column);
  }

  /**
   * Checks if a column is functionally determined by a key column through
   * expression analysis, by mapping both output columns back to the input
   * columns they reference and asking the same question of the input.
   *
   * @param rel The relation (a Project, Calc or Aggregate)
   * @param mq Metadata query instance
   * @param key Index of the determinant output column
   * @param column Index of the dependent output column
   * @return TRUE if column is determined by key,
   *         FALSE if not determined or if it cannot be proven
   * @throws IndexOutOfBoundsException if key or column is out of range
   */
  private static @Nullable Boolean determinesImpl(RelNode rel, RelMetadataQuery mq,
      int key, int column) {
    if (preCheck(rel, key, column)) {
      return true;
    }

    final ImmutableBitSet keyInputIndices;
    final ImmutableBitSet columnInputIndices;
    if (rel instanceof Project || rel instanceof Calc) {
      final List<RexNode> exprs;
      if (rel instanceof Project) {
        Project project = (Project) rel;
        exprs = project.getProjects();
      } else {
        Calc calc = (Calc) rel;
        final RexProgram program = calc.getProgram();
        exprs = program.expandList(program.getProjectList());
      }

      final RexNode keyExpr = exprs.get(key);
      final RexNode columnExpr = exprs.get(column);

      // TODO: Supports dependency analysis for all types of expressions
      if (!(columnExpr instanceof RexInputRef)) {
        return false;
      }

      // Identical expressions imply functional dependency
      if (keyExpr.equals(columnExpr)) {
        return true;
      }

      // NOTE(review): a key expression referencing several inputs (e.g. "a + b")
      // is treated as determining the column when every referenced input does;
      // this is exact only when the key expression is injective.
      keyInputIndices = extractDeterministicRefs(keyExpr);
      columnInputIndices = extractDeterministicRefs(columnExpr);
    } else if (rel instanceof Aggregate) {
      final Aggregate aggregate = (Aggregate) rel;

      // With GROUPING SETS / ROLLUP / CUBE, group columns are null-filled in the
      // super-aggregate rows, so dependencies among input columns do not carry over.
      if (!Aggregate.isSimple(aggregate)) {
        return false;
      }

      final int groupByCnt = aggregate.getGroupCount();
      if (key < groupByCnt && column >= groupByCnt) {
        return false;
      }

      keyInputIndices = extractDeterministicRefs(aggregate, key);
      columnInputIndices = extractDeterministicRefs(aggregate, column);
    } else {
      throw new UnsupportedOperationException("Unsupported RelNode type: "
          + rel.getClass().getSimpleName());
    }

    // Early return if invalid cases (non-deterministic or constant expressions,
    // aggregate calls without arguments such as COUNT(*))
    if (keyInputIndices.isEmpty() || columnInputIndices.isEmpty()) {
      return false;
    }

    // Every input column referenced by the key must be proven to determine every
    // input column referenced by the column. An unknown (null) answer from the
    // input is not a proof, so it yields false.
    final RelNode input = rel.getInput(0);
    for (int columnRef : columnInputIndices) {
      for (int keyRef : keyInputIndices) {
        if (!Boolean.TRUE.equals(mq.determines(input, keyRef, columnRef))) {
          return false;
        }
      }
    }

    return true;
  }

  /**
   * determinesImpl2 is similar to determinesImpl, but it doesn't need to handle
   * the mapping between output and input columns through expressions: output
   * columns are either base-table columns or input columns at known positions.
   *
   * @return TRUE if column is determined by key, FALSE if not determined or if
   *         it cannot be proven, or the (possibly null) answer of the input
   *         the question is delegated to
   * @throws IndexOutOfBoundsException if key or column is out of range
   */
  private static @Nullable Boolean determinesImpl2(RelNode rel, RelMetadataQuery mq,
      int key, int column) {
    if (preCheck(rel, key, column)) {
      return true;
    }

    if (rel instanceof TableScan) {
      TableScan tableScan = (TableScan) rel;
      RelOptTable table = tableScan.getTable();
      List<ImmutableBitSet> keys = table.getKeys();
      // A single-column unique key determines every column of the table.
      return keys != null && keys.contains(ImmutableBitSet.of(key));
    } else if (rel instanceof Join) {
      Join join = (Join) rel;
      // TODO Considering column mapping based on equality conditions in join
      int leftFieldCnt = join.getLeft().getRowType().getFieldCount();
      if (key < leftFieldCnt && column < leftFieldCnt) {
        return mq.determines(join.getLeft(), key, column);
      } else if (key >= leftFieldCnt && column >= leftFieldCnt) {
        return mq.determines(join.getRight(), key - leftFieldCnt, column - leftFieldCnt);
      }
      return false;
    } else if (rel instanceof Correlate) {
      // TODO Support Correlate.
      return false;
    } else if (rel instanceof SetOp) {
      // TODO Support SetOp
      return false;
    }

    // Any other relational expression is assumed to pass its first input's
    // columns through at the same positions.
    if (rel.getInputs().isEmpty()) {
      // A leaf other than TableScan (e.g. Values) has no input to ask.
      return false;
    }
    final RelNode input = rel.getInput(0);
    final int inputFieldCnt = input.getRowType().getFieldCount();
    if (key >= inputFieldCnt || column >= inputFieldCnt) {
      // The column is produced by this node itself rather than passed through,
      // so the input cannot answer for it.
      return false;
    }
    return mq.determines(input, key, column);
  }

  /**
   * Validates both indices and reports the trivial case.
   *
   * @return true if key and column are the same column (which trivially
   *         determines itself)
   * @throws IndexOutOfBoundsException if either index is out of range
   */
  private static boolean preCheck(RelNode rel, int key, int column) {
    verifyIndex(rel, key, column);

    // Equal index values indicate the same expression reference
    return key == column;
  }

  /**
   * Ensures every index is a valid field position of {@code rel}'s row type.
   *
   * @throws IndexOutOfBoundsException if any index is out of range
   */
  private static void verifyIndex(RelNode rel, int... indices) {
    final int fieldCount = rel.getRowType().getFieldCount();
    for (int index : indices) {
      if (index < 0 || index >= fieldCount) {
        throw new IndexOutOfBoundsException(
            "Column index " + index + " is out of bounds. "
                + "Valid range is [0, " + fieldCount + ")");
      }
    }
  }

  /**
   * Extracts input indices referenced by an output column in an Aggregate.
   * For group-by columns, returns the input column that the group key is
   * taken from. For aggregate function columns, returns the input column
   * indices used as arguments of the aggregate call.
   *
   * @param aggregate The Aggregate relational expression to analyze
   * @param index Index of the output column in the Aggregate (0-based)
   * @return ImmutableBitSet of input column indices referenced by the output column.
   *         For group-by columns, returns a singleton set of the input column index.
   *         For aggregate columns, returns the argument indices of the aggregate call
   *         (empty for calls without arguments, such as COUNT(*)).
   */
  private static ImmutableBitSet extractDeterministicRefs(Aggregate aggregate, int index) {
    final int groupByCnt = aggregate.getGroupCount();
    if (index < groupByCnt) {
      // The i-th output group column is the i-th set bit of the group set; it is
      // only equal to i when the group set is {0, ..., groupByCnt - 1}.
      return ImmutableBitSet.of(aggregate.getGroupSet().nth(index));
    }

    final List<AggregateCall> aggCalls = aggregate.getAggCallList();
    final AggregateCall call = aggCalls.get(index - groupByCnt);
    return ImmutableBitSet.of(call.getArgList());
  }

  /**
   * Extracts input indices referenced by a deterministic RexNode expression.
   *
   * @param rex The expression to analyze
   * @return referenced input indices if deterministic, otherwise an empty set
   */
  private static ImmutableBitSet extractDeterministicRefs(RexNode rex) {
    if (rex instanceof RexCall && !RexUtil.isDeterministic(rex)) {
      return ImmutableBitSet.of();
    }
    return RelOptUtil.InputFinder.bits(rex);
  }
}