AggregateReduceFunctionsOnGroupKeysRule.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelCollations;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.tools.RelBuilder;

import org.checkerframework.checker.nullness.qual.Nullable;
import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.List;

/**
 * Planner rule that eliminates aggregate functions of GROUP BY keys.
 *
 * <p>For example,
 * {@code SELECT sal, max(sal) FROM emp GROUP BY sal}
 * can be simplified to
 * {@code SELECT sal, sal FROM emp GROUP BY sal}.
 *
 * <p>Currently supports the following single-argument aggregate functions
 * when their argument is a column of the aggregate's group set, or is a
 * deterministic, non-windowed expression (computed by a {@link Project}
 * directly below the aggregate) involving only group set columns and
 * constants:
 * <ul>
 *   <li>{@code MAX}</li>
 *   <li>{@code MIN}</li>
 *   <li>{@code AVG}</li>
 *   <li>{@code ANY_VALUE}</li>
 * </ul>
 *
 * <p>An aggregate call is left untouched if it has a {@code FILTER} clause,
 * distinct keys, or a {@code WITHIN GROUP} ordering. The aggregate itself
 * must be simple (no {@code GROUPING SETS}, {@code CUBE} or {@code ROLLUP})
 * and must have at least one group key.
 *
 * <p>Note: Within a group, every row has the same value for an expression
 * over the group keys, so {@code MAX}, {@code MIN}, {@code AVG} and
 * {@code ANY_VALUE} all return that value, or NULL if that value is NULL.
 * This does not hold for an aggregate without group keys: it emits a row even
 * when its input is empty, and on that row these functions return NULL rather
 * than the value of the (constant) argument. Such aggregates are therefore
 * never rewritten.
 *
 * @see CoreRules#AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS
 */
@Value.Enclosing
public class AggregateReduceFunctionsOnGroupKeysRule
    extends RelRule<AggregateReduceFunctionsOnGroupKeysRule.Config>
    implements TransformationRule {

  /** Creates an AggregateReduceFunctionsOnGroupKeysRule. */
  protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {
    super(config);
  }

  /**
   * Replaces every reducible aggregate call with an expression over the group
   * keys, and puts a {@link Project} on top of the slimmed-down aggregate so
   * that the field order and field names of the original aggregate are kept.
   *
   * <p>Does nothing if no aggregate call can be reduced.
   */
  @Override public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final List<AggregateCall> oldCalls = aggregate.getAggCallList();
    final int groupCount = aggregate.getGroupCount();
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelBuilder relBuilder = call.builder();

    final List<AggregateCall> newCalls = new ArrayList<>();
    final List<RexNode> projects = new ArrayList<>();
    final List<String> fieldNames =
        new ArrayList<>(aggregate.getRowType().getFieldNames());

    // Pass through group keys.
    for (int i = 0; i < groupCount; i++) {
      projects.add(rexBuilder.makeInputRef(aggregate, i));
    }

    for (AggregateCall oldCall : oldCalls) {
      final @Nullable RexNode reduced = reduce(aggregate, oldCall, rexBuilder);
      if (reduced != null) {
        projects.add(reduced);
      } else {
        // The call survives. In the new aggregate it sits after the group
        // keys and after the calls that have survived so far.
        projects.add(
            rexBuilder.makeInputRef(
                oldCall.getType(), groupCount + newCalls.size()));
        newCalls.add(oldCall);
      }
    }

    if (newCalls.size() == oldCalls.size()) {
      // Nothing was reduced; leave the plan alone.
      return;
    }

    final RelNode newAggregate =
        aggregate.copy(
            aggregate.getTraitSet(),
            aggregate.getInput(),
            aggregate.getGroupSet(),
            aggregate.getGroupSets(),
            newCalls);
    relBuilder.push(newAggregate);
    relBuilder.project(projects, fieldNames);
    call.transformTo(relBuilder.build());
  }

  /**
   * Tries to reduce an aggregate call to a reference to a group-by key
   * or to an expression involving only group-by keys and constants.
   *
   * <p>The returned expression is written in terms of the output fields of
   * {@code aggregate} (whose leading fields are the group keys) and has the
   * same type as {@code call}.
   *
   * @param aggregate Aggregate that contains the call
   * @param call Aggregate call to reduce
   * @param rexBuilder Builder used to create a cast if the types differ
   * @return the reduced expression, or null if cannot reduce
   */
  private static @Nullable RexNode reduce(
      Aggregate aggregate,
      AggregateCall call,
      RexBuilder rexBuilder) {
    if (!Aggregate.isSimple(aggregate)) {
      return null;
    }
    if (aggregate.getGroupCount() == 0) {
      // Without group keys the aggregate emits a row even for empty input,
      // and MAX/MIN/AVG/ANY_VALUE are NULL on that row. A constant argument
      // (e.g. MAX(1)) must therefore not be replaced by its value.
      return null;
    }
    if (call.hasFilter()
        || call.distinctKeys != null
        || !RelCollations.EMPTY.equals(call.collation)) {
      return null;
    }
    final SqlKind kind = call.getAggregation().getKind();
    switch (kind) {
    case AVG:
    case MAX:
    case MIN:
    case ANY_VALUE:
      break;
    default:
      return null;
    }
    final List<Integer> argList = call.getArgList();
    if (argList.size() != 1) {
      return null;
    }
    final int arg = argList.get(0);

    // Case 1: argument directly references a group-by key
    if (aggregate.getGroupSet().get(arg)) {
      final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
      final RexNode ref =
          RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
      return castToCallType(rexBuilder, call, ref);
    }

    // Case 2: argument is an expression in a Project below the Aggregate
    RelNode input = aggregate.getInput();
    if (input instanceof HepRelVertex) {
      input = ((HepRelVertex) input).getCurrentRel();
    }
    if (!(input instanceof Project)) {
      return null;
    }
    final Project project = (Project) input;
    if (arg < 0 || arg >= project.getProjects().size()) {
      return null;
    }
    final RexNode expr = project.getProjects().get(arg);
    if (!RexUtil.isDeterministic(expr)) {
      return null;
    }
    if (RexOver.containsOver(expr)) {
      // A windowed aggregate depends on the rows in its window, not only on
      // its operands; evaluated above the Aggregate it would see one row per
      // group and might yield a different value.
      return null;
    }
    // Check that all columns referenced in the expression are group-by keys.
    // This ensures that the expression value is constant within each group.
    final @Nullable RexNode translated =
        translateToGroupRefs(expr, project, aggregate);
    if (translated == null) {
      return null;
    }
    return castToCallType(rexBuilder, call, translated);
  }

  /**
   * Returns {@code node} unchanged if it already has exactly the type of
   * {@code call}, otherwise wraps it in a cast to that type.
   */
  private static RexNode castToCallType(
      RexBuilder rexBuilder, AggregateCall call, RexNode node) {
    if (node.getType().equals(call.getType())) {
      return node;
    }
    return rexBuilder.makeCast(call.getParserPosition(), call.getType(), node);
  }

  /**
   * Translates an expression so that its {@link RexInputRef}s reference
   * the group keys of the aggregate rather than the input to the project.
   *
   * @param expr Expression taken from {@code project}
   * @param project Project directly below {@code aggregate}
   * @param aggregate Aggregate whose output fields the result refers to
   * @return the translated expression, or null if the expression references
   * columns that are not group-by keys
   */
  private static @Nullable RexNode translateToGroupRefs(
      RexNode expr, Project project, Aggregate aggregate) {
    final List<RexNode> projects = project.getProjects();
    final GroupRefTranslator translator = new GroupRefTranslator(projects, aggregate);
    final RexNode result = expr.accept(translator);
    return translator.failed ? null : result;
  }

  /**
   * Shuttle that translates input refs to aggregate group key refs.
   *
   * <p>The expression being visited belongs to the project below the
   * aggregate, so its input refs point at fields of the project's input.
   * Such a ref can be translated only if the project passes the same input
   * field through unchanged as one of the aggregate's group keys; the ref is
   * then replaced by a ref to the corresponding output field of the
   * aggregate. If any ref cannot be translated, {@link #failed} is set and
   * the result of the traversal must be discarded.
   *
   * <p>This ensures the expression references only columns that are constant
   * within each group.
   */
  private static class GroupRefTranslator extends RexShuttle {
    private final List<RexNode> projects;
    private final Aggregate aggregate;
    private boolean failed = false;

    GroupRefTranslator(List<RexNode> projects, Aggregate aggregate) {
      this.projects = projects;
      this.aggregate = aggregate;
    }

    @Override public RexNode visitInputRef(RexInputRef inputRef) {
      if (failed) {
        return inputRef;
      }
      final int inputIndex = inputRef.getIndex();
      // Walk the group keys in ascending order; the ordinal of a key in that
      // order is the index of its field in the aggregate's output. Look for
      // a key whose project expression is a direct pass-through of the input
      // field. For example, if the project has SAL=[$5], SAL is a group key,
      // and the expression references $5, map $5 to SAL's group field.
      int groupIndex = 0;
      for (int key : aggregate.getGroupSet()) {
        if (key < projects.size()) {
          final RexNode projExpr = projects.get(key);
          if (projExpr instanceof RexInputRef
              && ((RexInputRef) projExpr).getIndex() == inputIndex) {
            return RexInputRef.of(groupIndex,
                aggregate.getRowType().getFieldList());
          }
        }
        groupIndex++;
      }
      // The input column is not exposed as a group key, so its value may
      // vary within a group and the optimization cannot proceed safely.
      failed = true;
      return inputRef;
    }
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    /** Default configuration: matches a simple {@link LogicalAggregate} with
     * any inputs, and builds logical relational expressions. */
    Config DEFAULT = ImmutableAggregateReduceFunctionsOnGroupKeysRule.Config.of()
        .withRelBuilderFactory(RelFactories.LOGICAL_BUILDER)
        .withOperandFor(LogicalAggregate.class);

    @Override default AggregateReduceFunctionsOnGroupKeysRule toRule() {
      return new AggregateReduceFunctionsOnGroupKeysRule(this);
    }

    /** Defines an operand tree for the given class. */
    default Config withOperandFor(Class<? extends Aggregate> aggregateClass) {
      return withOperandSupplier(b ->
          b.operand(aggregateClass)
              .predicate(Aggregate::isSimple)
              .anyInputs())
          .as(Config.class);
    }
  }
}