AggregateGroupingSetsToUnionRule.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.rel.rules;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Values;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import com.google.common.collect.ImmutableList;
import org.immutables.value.Value;
import java.util.ArrayList;
import java.util.List;
/**
 * Rule that converts a {@link org.apache.calcite.rel.core.Aggregate} with
 * {@code GROUPING SETS} into a {@code UNION ALL} of simpler aggregates,
 * one per grouping set.
 *
 * <p>Example transformation:
 * <pre>{@code
 * SELECT a, b, c FROM t GROUP BY GROUPING SETS ((a,b), (a,c))
 * }</pre>
 *
 * <p>Transformed to:
 *
 * <pre>{@code
 * SELECT a, b, NULL AS c FROM t GROUP BY a, b
 * UNION ALL
 * SELECT a, NULL AS b, c FROM t GROUP BY a, c
 * }</pre>
 *
 * <p>Each branch aggregates the original input by a single grouping set and
 * then projects its output back to the column layout of the original
 * aggregate: group keys missing from the grouping set become {@code NULL}
 * literals, and {@code GROUPING(...)} calls become constant literals.
 *
 * <p>The rule does not fire if the aggregate contains a {@code GROUP_ID} or
 * {@code GROUPING_ID} call, or a {@code GROUPING} call with so many arguments
 * that its result would not fit in a {@code long}.
 */
@Value.Enclosing
public class AggregateGroupingSetsToUnionRule
    extends RelRule<AggregateGroupingSetsToUnionRule.Config>
    implements SubstitutionRule {

  /** Creates an AggregateGroupingSetsToUnionRule. */
  protected AggregateGroupingSetsToUnionRule(Config config) {
    super(config);
  }

  //~ Methods ----------------------------------------------------------------

  /**
   * Rewrites the matched aggregate as a {@code UNION ALL} with one input per
   * grouping set, or leaves the plan untouched if the aggregate is simple or
   * contains an aggregate call this rule cannot handle.
   *
   * @param call rule call whose first operand is the aggregate to rewrite
   */
  @Override public void onMatch(RelOptRuleCall call) {
    final Aggregate aggregate = call.rel(0);
    // Decide up front whether the rewrite is possible, so that no partial
    // plan is built and no exception escapes into the planner.
    if (Aggregate.isSimple(aggregate) || !canRewrite(aggregate)) {
      return;
    }

    final RelBuilder relBuilder = call.builder();
    final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
    final RelNode input = aggregate.getInput();
    final List<String> fieldNames = aggregate.getRowType().getFieldNames();
    final ImmutableBitSet fullGroupSet = aggregate.getGroupSet();

    final List<RelNode> unionInputs = new ArrayList<>();
    for (ImmutableBitSet subGroupSet : aggregate.getGroupSets()) {
      relBuilder.push(input);

      // Row type of an aggregate that groups by subGroupSet only; used to
      // type the references to the group keys of the branch.
      final RelDataType subAggregateType =
          Aggregate.deriveRowType(relBuilder.getTypeFactory(),
              relBuilder.peek().getRowType(), false, subGroupSet,
              ImmutableList.of(subGroupSet), ImmutableList.of());

      final List<RexNode> subProjects = new ArrayList<>();

      // Group keys, in the order of the original aggregate's output.
      for (int groupKey : fullGroupSet) {
        if (subGroupSet.get(groupKey)) {
          subProjects.add(
              RexInputRef.of(subGroupSet.indexOf(groupKey), subAggregateType));
        } else {
          // If the groupKey is not in the grouping set, use null as a
          // placeholder.
          subProjects.add(
              rexBuilder.makeNullLiteral(
                  relBuilder.field(groupKey).getType()));
        }
      }

      // Aggregate calls. GROUPING is constant within one grouping set, so
      // it becomes a literal; every other call is evaluated by the branch.
      final List<AggregateCall> subAggCalls = new ArrayList<>();
      for (AggregateCall aggCall : aggregate.getAggCallList()) {
        switch (aggCall.getAggregation().getKind()) {
        case GROUPING:
          final long groupingValue =
              evaluateGroupingFunction(subGroupSet, aggCall.getArgList());
          subProjects.add(
              rexBuilder.makeLiteral(groupingValue, aggCall.getType(), true));
          break;
        default:
          // The call's output follows the group keys of the branch.
          subProjects.add(
              new RexInputRef(
                  subGroupSet.cardinality() + subAggCalls.size(),
                  aggCall.getType()));
          subAggCalls.add(aggCall);
          break;
        }
      }

      relBuilder.aggregate(relBuilder.groupKey(subGroupSet), subAggCalls)
          .project(subProjects, fieldNames);
      unionInputs.add(relBuilder.build());
    }

    relBuilder.pushAll(unionInputs)
        .union(true, unionInputs.size());
    call.transformTo(relBuilder.build());
  }

  /**
   * Returns whether every aggregate call of {@code aggregate} can be handled
   * by this rule.
   *
   * @param aggregate aggregate to inspect
   * @return {@code false} if the aggregate has a {@code GROUP_ID} or
   *     {@code GROUPING_ID} call, or a {@code GROUPING} call with
   *     {@link Long#SIZE} or more arguments; {@code true} otherwise
   */
  private static boolean canRewrite(Aggregate aggregate) {
    for (AggregateCall aggCall : aggregate.getAggCallList()) {
      switch (aggCall.getAggregation().getKind()) {
      case GROUP_ID:
        // GROUP_ID is removed during RelNode conversion, no handling needed
        // here.
        return false;
      case GROUPING_ID:
        // The GROUPING_ID aggregate function has been marked as deprecated
        // and is no longer supported.
        return false;
      case GROUPING:
        // One bit per argument; the result must stay a non-negative long.
        if (aggCall.getArgList().size() >= Long.SIZE) {
          return false;
        }
        break;
      default:
        break;
      }
    }
    return true;
  }

  /**
   * Computes the value of {@code GROUPING(args)} for one grouping set.
   *
   * <p>The result is a bitmap with one bit per argument; the last argument
   * maps to the least significant bit. A bit is 1 if the corresponding
   * column is <em>not</em> part of {@code groupSet}.
   *
   * @param groupSet grouping set being evaluated
   * @param argIndices input ordinals of the arguments of the GROUPING call
   * @return the bitmap value of the GROUPING call for {@code groupSet}
   * @throws IllegalArgumentException if there are {@link Long#SIZE} or more
   *     arguments
   */
  private static long evaluateGroupingFunction(ImmutableBitSet groupSet,
      List<Integer> argIndices) {
    final int argCount = argIndices.size();
    if (argCount >= Long.SIZE) {
      throw new IllegalArgumentException(
          "Too many grouping keys. Maximum is " + (Long.SIZE - 1) + " for grouping functions.");
    }
    long result = 0L;
    for (int k = 0; k < argCount; k++) {
      final int index = argIndices.get(argCount - 1 - k);
      if (!groupSet.get(index)) {
        // Use a long shift so that bits 31 and above are not lost.
        result |= 1L << k;
      }
    }
    return result;
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    Config DEFAULT = ImmutableAggregateGroupingSetsToUnionRule.Config.of()
        .withOperandFor(Aggregate.class, Values.class);

    @Override default AggregateGroupingSetsToUnionRule toRule() {
      return new AggregateGroupingSetsToUnionRule(this);
    }

    /**
     * Defines an operand tree for the given classes.
     *
     * <p>The operand matches any {@code aggregateClass} whose group type is
     * not {@link Aggregate.Group#SIMPLE}, with any inputs.
     *
     * @param aggregateClass class of aggregate to match
     * @param valuesClass not referenced by the operand tree built here; the
     *     parameter is retained so that existing callers keep compiling
     * @return a configuration with the operand tree set
     */
    default Config withOperandFor(Class<? extends Aggregate> aggregateClass,
        Class<? extends Values> valuesClass) {
      return withOperandSupplier(b0 ->
          b0.operand(aggregateClass)
              .predicate(aggregate -> aggregate.getGroupType() != Aggregate.Group.SIMPLE)
              .anyInputs())
          .as(Config.class);
    }
  }
}