Aggregate.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.core;

import org.apache.calcite.linq4j.Ord;
import org.apache.calcite.linq4j.function.Experimental;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptCost;
import org.apache.calcite.plan.RelOptPlanner;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelInput;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.hint.Hintable;
import org.apache.calcite.rel.hint.RelHint;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.runtime.CalciteException;
import org.apache.calcite.runtime.Resources;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.SqlUtil;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidatorException;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Litmus;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;

import com.google.common.collect.ImmutableList;
import com.google.common.math.IntMath;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import static com.google.common.base.Preconditions.checkArgument;

import static java.util.Objects.requireNonNull;

/**
 * Relational operator that eliminates
 * duplicates and computes totals.
 *
 * <p>It corresponds to the {@code GROUP BY} operator in a SQL query
 * statement, together with the aggregate functions in the {@code SELECT}
 * clause.
 *
 * <p>Rules:
 *
 * <ul>
 * <li>{@link org.apache.calcite.rel.rules.AggregateProjectPullUpConstantsRule}
 * <li>{@link org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}
 * <li>{@link org.apache.calcite.rel.rules.AggregateReduceFunctionsRule}
 * </ul>
 */
public abstract class Aggregate extends SingleRel implements Hintable {

  /** Hints of this relational expression; an immutable copy of the hints
   * passed to the constructor. */
  protected final ImmutableList<RelHint> hints;

  /**
   * Returns whether an aggregate is a simple {@code GROUP BY}; that is,
   * whether it has exactly one grouping set and that set equals its
   * group set.
   *
   * @param aggregate Aggregate
   * @return Whether the group type of {@code aggregate} is
   *   {@link Group#SIMPLE}
   */
  public static boolean isSimple(Aggregate aggregate) {
    return aggregate.getGroupType() == Group.SIMPLE;
  }

  @SuppressWarnings("Guava")
  @Deprecated // to be converted to Java Predicate before 2.0
  public static final com.google.common.base.Predicate<Aggregate> IS_SIMPLE =
      Aggregate::isSimple;

  @SuppressWarnings("Guava")
  @Deprecated // to be converted to Java Predicate before 2.0
  public static final com.google.common.base.Predicate<Aggregate> NO_INDICATOR =
      Aggregate::noIndicator;

  @SuppressWarnings("Guava")
  @Deprecated // to be converted to Java Predicate before 2.0
  public static final com.google.common.base.Predicate<Aggregate>
      IS_NOT_GRAND_TOTAL = Aggregate::isNotGrandTotal;

  /** Used internally; will be removed when {@link #indicator} is removed,
   * before 2.0.
   *
   * @param indicator Legacy indicator flag; must be false
   * @throws IllegalArgumentException if {@code indicator} is true
   */
  @Experimental
  public static void checkIndicator(boolean indicator) {
    checkArgument(!indicator,
        "indicator is no longer supported; use GROUPING function instead");
  }

  //~ Instance fields --------------------------------------------------------

  @Deprecated // unused field, to be removed before 2.0
  public final boolean indicator = false;

  /** Calls to aggregate functions; an immutable copy of the list passed to
   * the constructor. */
  protected final List<AggregateCall> aggCalls;
  /** Bit set of grouping fields (ordinals of fields of the input). */
  protected final ImmutableBitSet groupSet;
  /** All grouping sets, each a subset of {@link #groupSet}. For a simple
   * {@code GROUP BY} this is a singleton list containing {@link #groupSet}. */
  public final ImmutableList<ImmutableBitSet> groupSets;

  //~ Constructors -----------------------------------------------------------

  /**
   * Creates an Aggregate.
   *
   * <p>All members of {@code groupSets} must be sub-sets of {@code groupSet}.
   * For a simple {@code GROUP BY}, {@code groupSets} is a singleton list
   * containing {@code groupSet}.
   *
   * <p>It is allowed for {@code groupSet} to contain bits that are not in any
   * of the {@code groupSets}, even if this does not correspond to valid SQL.
   * See discussion in
   * {@link org.apache.calcite.tools.RelBuilder#groupKey(ImmutableBitSet, Iterable)}.
   *
   * <p>If {@code GROUP BY} is not specified,
   * or equivalently if {@code GROUP BY ()} is specified,
   * {@code groupSet} will be the empty set,
   * and {@code groupSets} will have one element, that empty set.
   *
   * <p>If {@code CUBE}, {@code ROLLUP} or {@code GROUPING SETS} are
   * specified, {@code groupSets} will have additional elements,
   * but they must each be a subset of {@code groupSet},
   * and they must be sorted by inclusion:
   * {@code (0, 1, 2), (1), (0, 2), (0), ()}.
   *
   * @param cluster  Cluster
   * @param traitSet Trait set
   * @param hints    Hints of this relational expression
   * @param input    Input relational expression
   * @param groupSet Bit set of grouping fields
   * @param groupSets List of all grouping sets; null for just {@code groupSet}
   * @param aggCalls Collection of calls to aggregate functions
   * @throws IllegalArgumentException if an aggregate call has a filter
   *   argument whose type is not {@code BOOLEAN NOT NULL}
   */
  @SuppressWarnings("method.invocation.invalid")
  protected Aggregate(
      RelOptCluster cluster,
      RelTraitSet traitSet,
      List<RelHint> hints,
      RelNode input,
      ImmutableBitSet groupSet,
      @Nullable List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    super(cluster, traitSet, input);
    this.hints = ImmutableList.copyOf(hints);
    this.aggCalls = ImmutableList.copyOf(aggCalls);
    this.groupSet = requireNonNull(groupSet, "groupSet");
    if (groupSets == null) {
      this.groupSets = ImmutableList.of(groupSet);
    } else {
      this.groupSets = ImmutableList.copyOf(groupSets);
      assert ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets) : groupSets;
      for (ImmutableBitSet set : groupSets) {
        assert groupSet.contains(set);
      }
    }
    assert groupSet.length() <= input.getRowType().getFieldCount();
    for (AggregateCall aggCall : aggCalls) {
      assert typeMatchesInferred(aggCall, Litmus.THROW);
      checkArgument(aggCall.filterArg < 0
          || isPredicate(input, aggCall.filterArg),
          "filter must be BOOLEAN NOT NULL");
    }
  }

  @Deprecated // to be removed before 2.0
  protected Aggregate(
      RelOptCluster cluster,
      RelTraitSet traitSet,
      RelNode input,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    this(cluster, traitSet, ImmutableList.of(), input, groupSet, groupSets,
        aggCalls);
  }

  @Deprecated // to be removed before 2.0
  protected Aggregate(
      RelOptCluster cluster,
      RelTraitSet traits,
      RelNode child,
      boolean indicator,
      ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets,
      List<AggregateCall> aggCalls) {
    this(cluster, traits, ImmutableList.of(), child, groupSet, groupSets, aggCalls);
    checkIndicator(indicator);
  }

  /**
   * Returns whether an aggregate has at least one grouping field; that is,
   * whether it is not a grand total over all input rows.
   *
   * @param aggregate Aggregate
   * @return Whether the group count of {@code aggregate} is positive
   */
  public static boolean isNotGrandTotal(Aggregate aggregate) {
    return aggregate.getGroupCount() > 0;
  }

  @Deprecated // to be removed before 2.0
  public static boolean noIndicator(Aggregate aggregate) {
    return true;
  }

  /** Returns whether field {@code index} of {@code input} has type
   * {@code BOOLEAN NOT NULL}, as required for the argument of a
   * {@code FILTER} clause. */
  private static boolean isPredicate(RelNode input, int index) {
    final RelDataType type =
        input.getRowType().getFieldList().get(index).getType();
    return type.getSqlTypeName() == SqlTypeName.BOOLEAN
        && !type.isNullable();
  }

  /**
   * Creates an Aggregate by parsing serialized output.
   */
  protected Aggregate(RelInput input) {
    this(input.getCluster(), input.getTraitSet(), ImmutableList.of(),
        input.getInput(), input.getBitSet("group"),
        input.getBitSetList("groups"), input.getAggregateCalls("aggs"));
  }

  //~ Methods ----------------------------------------------------------------

  @Override public final RelNode copy(RelTraitSet traitSet,
      List<RelNode> inputs) {
    return copy(traitSet, sole(inputs), groupSet, groupSets, aggCalls);
  }

  /** Creates a copy of this aggregate.
   *
   * @param traitSet Traits
   * @param input Input
   * @param groupSet Bit set of grouping fields
   * @param groupSets List of all grouping sets; null for just {@code groupSet}
   * @param aggCalls Collection of calls to aggregate functions
   * @return New {@code Aggregate} if any parameter differs from the value of
   *   this {@code Aggregate}, or just {@code this} if all the parameters are
   *   the same
   *
   * @see #copy(org.apache.calcite.plan.RelTraitSet, java.util.List)
   */
  public abstract Aggregate copy(RelTraitSet traitSet, RelNode input,
      ImmutableBitSet groupSet,
      @Nullable List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls);

  @Deprecated // to be removed before 2.0
  public Aggregate copy(RelTraitSet traitSet, RelNode input,
      boolean indicator, ImmutableBitSet groupSet,
      List<ImmutableBitSet> groupSets, List<AggregateCall> aggCalls) {
    checkIndicator(indicator);
    return copy(traitSet, input, groupSet, groupSets, aggCalls);
  }

  /**
   * Returns a list of calls to aggregate functions.
   *
   * @return list of calls to aggregate functions
   */
  public List<AggregateCall> getAggCallList() {
    return aggCalls;
  }

  /**
   * Returns a list of calls to aggregate functions together with their output
   * field names.
   *
   * @return list of calls to aggregate functions and their output field names
   */
  public List<Pair<AggregateCall, String>> getNamedAggCalls() {
    // Output fields are the grouping fields followed by one field per call.
    final int offset = getGroupCount();
    return Pair.zip(aggCalls, Util.skip(getRowType().getFieldNames(), offset));
  }

  /**
   * Returns the number of grouping fields.
   * These grouping fields are the leading fields in both the input and output
   * records.
   *
   * <p>NOTE: The {@link #getGroupSet()} data structure allows for the
   * grouping fields to not be on the leading edge. New code should, if
   * possible, assume that grouping fields are in arbitrary positions in the
   * input relational expression.
   *
   * @return number of grouping fields
   */
  public int getGroupCount() {
    return groupSet.cardinality();
  }

  /**
   * Returns whether one of the grouping sets is the empty set, as it is for
   * {@code GROUP BY ()} or when no {@code GROUP BY} is specified.
   *
   * @return Whether {@link #groupSets} contains the empty bit set
   */
  public boolean hasEmptyGroup() {
    return groupSets.contains(ImmutableBitSet.of());
  }

  /**
   * Returns the number of indicator fields.
   *
   * <p>Always zero.
   *
   * @return number of indicator fields, always zero
   */
  @Deprecated // to be removed before 2.0
  public int getIndicatorCount() {
    return 0;
  }

  /**
   * Returns a bit set of the grouping fields.
   *
   * @return bit set of ordinals of grouping fields
   */
  public ImmutableBitSet getGroupSet() {
    return groupSet;
  }

  /**
   * Returns the list of grouping sets computed by this Aggregate.
   *
   * @return List of all grouping sets
   */
  public ImmutableList<ImmutableBitSet> getGroupSets() {
    return groupSets;
  }

  @Override public RelWriter explainTerms(RelWriter pw) {
    // We skip the "groups" element if it is a singleton of "group".
    super.explainTerms(pw)
        .item("group", groupSet)
        .itemIf("groups", groupSets, getGroupType() != Group.SIMPLE)
        .itemIf("aggs", aggCalls, pw.nest());
    if (!pw.nest()) {
      for (Ord<AggregateCall> ord : Ord.zip(aggCalls)) {
        pw.item(Util.first(ord.e.name, "agg#" + ord.i), ord.e);
      }
    }
    return pw;
  }

  @Override public double estimateRowCount(RelMetadataQuery mq) {
    // Assume that each sort column has 50% of the value count.
    // Therefore one sort column has .5 * rowCount,
    // 2 sort columns give .75 * rowCount.
    // Zero sort columns yield exactly one row; this estimate is 1 even
    // though an empty input would strictly produce 0 rows.
    final int groupCount = groupSet.cardinality();
    if (groupCount == 0) {
      return 1;
    } else {
      double rowCount = super.estimateRowCount(mq);
      rowCount *= 1.0 - Math.pow(.5, groupCount);
      return rowCount;
    }
  }

  @Override public @Nullable RelOptCost computeSelfCost(RelOptPlanner planner,
      RelMetadataQuery mq) {
    // REVIEW jvs 24-Aug-2008:  This is bogus, but no more bogus
    // than what's currently in Join.
    double rowCount = mq.getRowCount(this);
    // Aggregates with more aggregate functions cost a bit more
    float multiplier = 1f + (float) aggCalls.size() * 0.125f;
    for (AggregateCall aggCall : aggCalls) {
      if (aggCall.getAggregation().getName().equals("SUM")) {
        // Pretend that SUM costs a little bit more than $SUM0,
        // to make things deterministic.
        multiplier += 0.0125f;
      }
    }
    return planner.getCostFactory().makeCost(rowCount * multiplier, 0, 0);
  }

  @Override protected RelDataType deriveRowType() {
    return deriveRowType(getCluster().getTypeFactory(), getInput().getRowType(),
        false, groupSet, groupSets, aggCalls);
  }

  /**
   * Computes the row type of an {@code Aggregate} before it exists.
   *
   * <p>The row type consists of the grouping fields, in bit order, followed
   * by one field per aggregate call.
   *
   * @param typeFactory Type factory
   * @param inputRowType Input row type
   * @param indicator Deprecated, always false
   * @param groupSet Bit set of grouping fields
   * @param groupSets List of all grouping sets; null for just {@code groupSet}
   * @param aggCalls Collection of calls to aggregate functions
   * @return Row type of the aggregate
   * @throws IllegalArgumentException if {@code indicator} is true
   */
  public static RelDataType deriveRowType(RelDataTypeFactory typeFactory,
      final RelDataType inputRowType, boolean indicator,
      ImmutableBitSet groupSet, @Nullable List<ImmutableBitSet> groupSets,
      final List<AggregateCall> aggCalls) {
    final List<Integer> groupList = groupSet.asList();
    assert groupList.size() == groupSet.cardinality();
    final RelDataTypeFactory.Builder builder = typeFactory.builder();
    final List<RelDataTypeField> fieldList = inputRowType.getFieldList();
    final Set<String> containedNames = new HashSet<>();
    for (int groupKey : groupList) {
      final RelDataTypeField field = fieldList.get(groupKey);
      containedNames.add(field.getName());
      builder.add(field);
      // A grouping field that is missing from some grouping set is null in
      // the rows produced for that set, so it must be nullable.
      if (groupSets != null && !ImmutableBitSet.allContain(groupSets, groupKey)) {
        builder.nullable(true);
      }
    }
    checkIndicator(indicator);
    for (Ord<AggregateCall> aggCall : Ord.zip(aggCalls)) {
      final String base;
      if (aggCall.e.name != null) {
        base = aggCall.e.name;
      } else {
        base = "$f" + (groupList.size() + aggCall.i);
      }
      // Make the name unique among fields added so far by appending
      // "_0", "_1", ... as needed.
      String name = base;
      int i = 0;
      while (containedNames.contains(name)) {
        name = base + "_" + i++;
      }
      containedNames.add(name);
      builder.add(name, aggCall.e.type);
    }
    return builder.build();
  }

  /** {@inheritDoc}
   *
   * <p>In addition, checks that the field names of the row type are
   * distinct. */
  @Override public boolean isValid(Litmus litmus, @Nullable Context context) {
    return super.isValid(litmus, context)
        && litmus.check(Util.isDistinct(getRowType().getFieldNames()),
            "distinct field names: {}", getRowType());
  }

  /**
   * Returns whether the inferred type of an {@link AggregateCall} matches the
   * type it was given when it was created.
   *
   * @param aggCall Aggregate call
   * @param litmus What to do if an error is detected (types do not match)
   * @return Whether the inferred and declared types match
   */
  private boolean typeMatchesInferred(
      final AggregateCall aggCall,
      final Litmus litmus) {
    SqlAggFunction aggFunction = aggCall.getAggregation();
    AggCallBinding callBinding = aggCall.createBinding(this);
    RelDataType type = aggFunction.inferReturnType(callBinding);
    RelDataType expectedType = aggCall.type;
    return RelOptUtil.eq("aggCall type",
        expectedType,
        "inferred type",
        type,
        litmus);
  }

  /**
   * Returns whether any of the aggregates are DISTINCT.
   *
   * @return Whether any of the aggregates are DISTINCT
   */
  public boolean containsDistinctCall() {
    for (AggregateCall call : aggCalls) {
      if (call.isDistinct()) {
        return true;
      }
    }
    return false;
  }

  @Override public ImmutableList<RelHint> getHints() {
    return hints;
  }

  /**
   * Returns the type of roll-up.
   *
   * @return Type of roll-up
   */
  public Group getGroupType() {
    return Group.induce(groupSet, groupSets);
  }

  /** Describes the kind of roll-up. */
  public enum Group {
    /** A single grouping set equal to the group set. */
    SIMPLE,
    /** Grouping sets form a roll-up; see
     * {@link #isRollup(ImmutableBitSet, List)}. */
    ROLLUP,
    /** Grouping sets are every subset of the group set. */
    CUBE,
    /** Any other combination of grouping sets. */
    OTHER;

    /**
     * Deduces the kind of roll-up from a group set and its grouping sets.
     *
     * <p>{@link #CUBE} is tested before {@link #ROLLUP}, so a group set with
     * one bit and grouping sets {@code [{i}, {}]} is classified as
     * {@code CUBE}.
     *
     * @param groupSet Bit set of grouping fields
     * @param groupSets Grouping sets; must be strictly sorted
     * @return Kind of roll-up
     * @throws IllegalArgumentException if {@code groupSets} is not strictly
     *   sorted
     */
    public static Group induce(ImmutableBitSet groupSet,
        List<ImmutableBitSet> groupSets) {
      if (!ImmutableBitSet.ORDERING.isStrictlyOrdered(groupSets)) {
        throw new IllegalArgumentException("must be sorted: " + groupSets);
      }
      if (groupSets.size() == 1 && groupSets.get(0).equals(groupSet)) {
        return SIMPLE;
      }
      // Grouping sets are distinct (strictly ordered) and are expected to be
      // subsets of groupSet, so 2^n of them means every subset is present.
      // Guard the exponent: 2^n does not fit in an int for n >= 31, and
      // IntMath.pow wraps to 0 for n >= 32, which would wrongly classify an
      // empty list of grouping sets as CUBE.
      final int cardinality = groupSet.cardinality();
      if (cardinality < Integer.SIZE - 1
          && groupSets.size() == IntMath.pow(2, cardinality)) {
        return CUBE;
      }
      if (isRollup(groupSet, groupSets)) {
        return ROLLUP;
      }
      return OTHER;
    }

    /** Returns whether a list of sets is a rollup.
     *
     * <p>For example, if {@code groupSet} is <code>{2, 4, 5}</code>, then
     * <code>[{2, 4, 5}, {2, 5}, {5}, {}]</code> is a rollup. The first item is
     * equal to {@code groupSet}, and each subsequent item is a subset with one
     * fewer bit than the previous.
     *
     * @see #getRollup(List) */
    public static boolean isRollup(ImmutableBitSet groupSet,
        List<ImmutableBitSet> groupSets) {
      if (groupSets.size() != groupSet.cardinality() + 1) {
        return false;
      }
      ImmutableBitSet g = null;
      for (ImmutableBitSet bitSet : groupSets) {
        if (g == null) {
          // First item must equal groupSet
          if (!bitSet.equals(groupSet)) {
            return false;
          }
        } else {
          // Each subsequent item must be a subset with one fewer bit than the
          // previous item
          if (!g.contains(bitSet)
              || g.cardinality() - bitSet.cardinality() != 1) {
            return false;
          }
        }
        g = bitSet;
      }
      requireNonNull(g, "groupSet must not be empty");
      checkArgument(g.isEmpty());
      return true;
    }

    /** Returns the ordered list of bits in a rollup.
     *
     * <p>For example, given a {@code groupSets} value
     * <code>[{2, 4, 5}, {2, 5}, {5}, {}]</code>, returns the list
     * {@code [5, 2, 4]}, which are the succession of bits
     * added to each of the sets starting with the empty set.
     *
     * <p>Returns an empty list if {@code groupSets} is empty.
     *
     * @see #isRollup(ImmutableBitSet, List) */
    public static List<Integer> getRollup(List<ImmutableBitSet> groupSets) {
      // Clamp the capacity: an empty list would otherwise give -1, and
      // ArrayList rejects a negative initial capacity.
      final List<Integer> rollUpBits =
          new ArrayList<>(Math.max(0, groupSets.size() - 1));
      ImmutableBitSet g = null;
      for (ImmutableBitSet bitSet : groupSets) {
        if (g != null) {
          // Each subsequent item must be a subset of the previous item with
          // exactly one fewer bit; record the bit that was dropped.
          ImmutableBitSet diff = g.except(bitSet);
          assert diff.cardinality() == 1;
          rollUpBits.add(diff.nth(0));
        }
        g = bitSet;
      }
      // Bits were collected from the largest set downwards; report them in
      // the order they are added starting from the empty set.
      Collections.reverse(rollUpBits);
      return ImmutableList.copyOf(rollUpBits);
    }
  }

  //~ Inner Classes ----------------------------------------------------------

  /**
   * Implementation of the {@link SqlOperatorBinding} interface for an
   * {@link AggregateCall aggregate call} applied to a set of operands in the
   * context of a {@link org.apache.calcite.rel.logical.LogicalAggregate}.
   *
   * <p>Operand ordinals cover the pre-operands first, then the operands.
   */
  public static class AggCallBinding extends SqlOperatorBinding {
    private final List<RelDataType> preOperands;
    private final List<RelDataType> operands;
    @Deprecated // to be removed before 2.0
    private final int groupCount;
    private final boolean filter;
    private final boolean hasEmptyGroup;

    /**
     * Creates an AggCallBinding.
     *
     * @param typeFactory  Type factory
     * @param aggFunction  Aggregate function
     * @param preOperands  Data types of pre-operands
     * @param operands     Data types of operands
     * @param groupCount   Number of columns in the GROUP BY clause
     * @param filter       Whether the aggregate function has a FILTER clause
     * @throws IllegalArgumentException if {@code groupCount} is negative
     *
     * @deprecated Use
     * {@link #AggCallBinding(RelDataTypeFactory, SqlAggFunction, List, List, boolean, boolean)}
     */
    @Deprecated // to be removed before 2.0
    public AggCallBinding(RelDataTypeFactory typeFactory,
        SqlAggFunction aggFunction, List<RelDataType> preOperands,
        List<RelDataType> operands, int groupCount,
        boolean filter) {
      super(typeFactory, aggFunction);
      this.preOperands = requireNonNull(preOperands, "preOperands");
      this.operands =
          requireNonNull(operands,
              "operands of aggregate call should not be null");
      this.groupCount = groupCount;
      // Zero group-by columns means the aggregate has an empty group.
      this.hasEmptyGroup = groupCount == 0;
      this.filter = filter;
      checkArgument(groupCount >= 0,
          "number of group by columns should not be negative in "
              + "aggregate call. Got %s", groupCount);
    }

    /**
     * Creates an AggCallBinding.
     *
     * @param typeFactory   Type factory
     * @param aggFunction   Aggregate function
     * @param preOperands   Data types of pre-operands
     * @param operands      Data types of operands
     * @param hasEmptyGroup Whether the aggregate has an empty group
     * @param filter        Whether the aggregate function has a FILTER clause
     */
    public AggCallBinding(RelDataTypeFactory typeFactory,
        SqlAggFunction aggFunction, List<RelDataType> preOperands,
        List<RelDataType> operands, boolean hasEmptyGroup,
        boolean filter) {
      super(typeFactory, aggFunction);
      this.preOperands = requireNonNull(preOperands, "preOperands");
      this.operands =
          requireNonNull(operands,
              "operands of aggregate call should not be null");
      this.filter = filter;
      this.hasEmptyGroup = hasEmptyGroup;
      // The exact group count is unknown here; only zero vs. non-zero is
      // derivable from hasEmptyGroup.
      this.groupCount = hasEmptyGroup ? 0 : 1;
    }

    @Deprecated // to be removed before 2.0
    public AggCallBinding(RelDataTypeFactory typeFactory,
        SqlAggFunction aggFunction, List<RelDataType> operands, int groupCount,
        boolean filter) {
      this(typeFactory, aggFunction, ImmutableList.of(), operands, groupCount,
          filter);
    }

    @Deprecated // to be removed before 2.0
    @Override public int getGroupCount() {
      return groupCount;
    }

    @Override public boolean hasEmptyGroup() {
      return hasEmptyGroup;
    }

    @Override public boolean hasFilter() {
      return filter;
    }

    @Override public int getPreOperandCount() {
      return preOperands.size();
    }

    @Override public int getOperandCount() {
      return preOperands.size() + operands.size();
    }

    @Override public RelDataType getOperandType(int ordinal) {
      return ordinal < preOperands.size()
          ? preOperands.get(ordinal)
          : operands.get(ordinal - preOperands.size());
    }

    @Override public CalciteException newError(
        Resources.ExInst<SqlValidatorException> e) {
      return SqlUtil.newContextException(SqlParserPos.ZERO, e);
    }
  }

  /** Used for PERCENTILE_DISC return type inference. */
  public static class PercentileDiscAggCallBinding extends AggCallBinding {
    private final RelDataType collationType;

    PercentileDiscAggCallBinding(RelDataTypeFactory typeFactory, SqlAggFunction aggFunction,
        List<RelDataType> operands, RelDataType collationType, boolean hasEmptyGroup,
        boolean filter) {
      super(typeFactory, aggFunction, ImmutableList.of(), operands, hasEmptyGroup, filter);
      assert aggFunction.isPercentile();
      this.collationType = collationType;
    }

    @Override public RelDataType getCollationType() {
      return collationType;
    }
  }
}