AggregateExtractProjectRule.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalTableScan;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.mapping.Mapping;
import org.apache.calcite.util.mapping.MappingType;
import org.apache.calcite.util.mapping.Mappings;

import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.List;

import static com.google.common.collect.ImmutableList.toImmutableList;

/**
 * Rule to extract a {@link org.apache.calcite.rel.core.Project}
 * from an {@link org.apache.calcite.rel.core.Aggregate}
 * and push it down towards the input.
 *
 * <p>What projections can be safely pushed down depends upon which fields the
 * Aggregate uses. A field is considered used if it is a group key, or if an
 * aggregate call refers to it as an argument, as its filter, as one of its
 * distinct keys, or in its collation.
 *
 * <p>To prevent cycles, this rule will not extract a {@code Project} if the
 * {@code Aggregate}'s input is already a {@code Project}.
 */
@Value.Enclosing
public class AggregateExtractProjectRule
    extends RelRule<AggregateExtractProjectRule.Config>
    implements TransformationRule {
  /** Instance of the rule built from {@link Config#DEFAULT}, which matches an
   * {@link Aggregate} whose input is a {@link LogicalTableScan}. */
  public static final AggregateExtractProjectRule SCAN =
      Config.DEFAULT.toRule();

  /** Creates an AggregateExtractProjectRule. */
  protected AggregateExtractProjectRule(Config config) {
    super(config);
  }

  /**
   * Creates an AggregateExtractProjectRule that matches the given aggregate
   * class on top of the given input class.
   *
   * @param aggregateClass Class of aggregate to match
   * @param inputClass Class of the aggregate's input to match
   * @param relBuilderFactory Factory for the builder used to create the new
   *     relational expressions
   */
  @Deprecated // to be removed before 2.0
  public AggregateExtractProjectRule(
      Class<? extends Aggregate> aggregateClass,
      Class<? extends RelNode> inputClass,
      RelBuilderFactory relBuilderFactory) {
    this(Config.DEFAULT
        .withRelBuilderFactory(relBuilderFactory)
        .as(Config.class)
        .withOperandFor(aggregateClass, inputClass));
  }

  /**
   * Creates an AggregateExtractProjectRule that matches exactly the given
   * operand.
   *
   * @param operand Root operand of the rule
   * @param builderFactory Factory for the builder used to create the new
   *     relational expressions
   */
  @Deprecated // to be removed before 2.0
  public AggregateExtractProjectRule(RelOptRuleOperand operand,
      RelBuilderFactory builderFactory) {
    this(Config.DEFAULT
        .withRelBuilderFactory(builderFactory)
        .withOperandSupplier(b -> b.exactly(operand))
        .as(Config.class));
  }

  /**
   * Replaces the matched {@link Aggregate} with an equivalent aggregate over
   * a projection of its input that keeps only the fields the aggregate uses.
   *
   * <p>Group sets and aggregate calls are rewritten through a mapping from
   * the original input ordinals to the ordinals of the projected fields.
   *
   * @param call Rule call; {@code rel(0)} is the aggregate and {@code rel(1)}
   *     is its input
   */
  @Override public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    final RelNode input = call.rel(1);
    final ImmutableBitSet inputFieldsUsed = computeInputFieldsUsed(aggregate);

    final RelBuilder relBuilder = call.builder().push(input);
    final List<RexNode> projects = new ArrayList<>();
    // Maps each used ordinal of the original input to its position in the
    // new, narrower projection.
    final Mapping mapping =
        Mappings.create(MappingType.INVERSE_SURJECTION,
            aggregate.getInput().getRowType().getFieldCount(),
            inputFieldsUsed.cardinality());
    int j = 0;
    for (int i : inputFieldsUsed) {
      projects.add(relBuilder.field(i));
      mapping.set(i, j++);
    }

    relBuilder.project(projects);

    final ImmutableBitSet newGroupSet =
        Mappings.apply(mapping, aggregate.getGroupSet());
    final List<ImmutableBitSet> newGroupSets =
        aggregate.getGroupSets().stream()
            .map(bitSet -> Mappings.apply(mapping, bitSet))
            .collect(toImmutableList());
    final List<RelBuilder.AggCall> newAggCallList =
        aggregate.getAggCallList().stream()
            .map(aggCall -> relBuilder.aggregateCall(aggCall, mapping))
            .collect(toImmutableList());

    final RelBuilder.GroupKey groupKey =
        relBuilder.groupKey(newGroupSet, newGroupSets);
    relBuilder.aggregate(groupKey, newAggCallList);
    call.transformTo(relBuilder.build());
  }

  /**
   * Computes the set of input field ordinals that an aggregate refers to.
   *
   * <p>Every ordinal that an aggregate call refers to must be included, not
   * just its arguments: the call is later rewritten as a whole through a
   * mapping that only has targets for the fields returned here, so a field
   * referenced solely by a distinct key or by the call's collation would
   * otherwise have no target.
   *
   * @param aggregate Aggregate to inspect
   * @return Ordinals of the fields of the aggregate's input that are used
   */
  private static ImmutableBitSet computeInputFieldsUsed(Aggregate aggregate) {
    // 1. group fields are always used
    final ImmutableBitSet.Builder fieldsUsed =
        aggregate.getGroupSet().rebuild();
    // 2. agg functions
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      // 2a. arguments
      for (int i : aggCall.getArgList()) {
        fieldsUsed.set(i);
      }
      // 2b. filter; a negative ordinal means the call has no filter
      if (aggCall.filterArg >= 0) {
        fieldsUsed.set(aggCall.filterArg);
      }
      // 2c. distinct keys, if the call has any
      if (aggCall.distinctKeys != null) {
        fieldsUsed.addAll(aggCall.distinctKeys);
      }
      // 2d. fields the call's collation sorts on
      aggCall.getCollation().getFieldCollations()
          .forEach(fieldCollation ->
              fieldsUsed.set(fieldCollation.getFieldIndex()));
    }
    return fieldsUsed.build();
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    /** Default configuration: matches any {@link Aggregate} whose input is a
     * {@link LogicalTableScan}. */
    Config DEFAULT = ImmutableAggregateExtractProjectRule.Config.of()
        .withOperandFor(Aggregate.class, LogicalTableScan.class);

    @Override default AggregateExtractProjectRule toRule() {
      return new AggregateExtractProjectRule(this);
    }

    /** Defines an operand tree for the given classes. */
    default Config withOperandFor(Class<? extends Aggregate> aggregateClass,
        Class<? extends RelNode> inputClass) {
      return withOperandSupplier(b0 ->
          b0.operand(aggregateClass).oneInput(b1 ->
              b1.operand(inputClass)
                  // Predicate prevents matching against an Aggregate whose
                  // input is already a Project. Prevents this rule firing
                  // repeatedly.
                  .predicate(r -> !(r instanceof Project)).anyInputs()))
          .as(Config.class);
    }
  }
}