IntersectToDistinctRule.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Intersect;
import org.apache.calcite.rel.logical.LogicalIntersect;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilder.AggCall;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Util;

import org.immutables.value.Value;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;

import static org.apache.calcite.util.Util.skipLast;

/**
 * Planner rule that translates a distinct
 * {@link org.apache.calcite.rel.core.Intersect}
 * (<code>all</code> = <code>false</code>)
 * into a group of operators composed of
 * {@link org.apache.calcite.rel.core.Union},
 * {@link org.apache.calcite.rel.core.Aggregate}, etc.
 *
 * <p>The rule has a configuration option to control whether it should also perform
 * a (partial) aggregation pushdown in the union branches (default behavior).
 *
 * @see org.apache.calcite.rel.rules.UnionToDistinctRule
 * @see CoreRules#INTERSECT_TO_DISTINCT
 * @see CoreRules#INTERSECT_TO_DISTINCT_NO_AGGREGATE_PUSHDOWN
 */
@Value.Enclosing
public class IntersectToDistinctRule
    extends RelRule<IntersectToDistinctRule.Config>
    implements TransformationRule {

  /** Creates an IntersectToDistinctRule. */
  protected IntersectToDistinctRule(Config config) {
    super(config);
  }

  @Deprecated // to be removed before 2.0
  public IntersectToDistinctRule(Class<? extends Intersect> intersectClass,
      RelBuilderFactory relBuilderFactory) {
    this(Config.DEFAULT.withRelBuilderFactory(relBuilderFactory)
        .as(Config.class)
        .withOperandFor(intersectClass));
  }

  //~ Methods ----------------------------------------------------------------

  /** Dispatches to one of the two rewrite strategies, depending on
   * {@link Config#isAggregatePushdown()}. */
  @Override public void onMatch(RelOptRuleCall call) {
    if (config.isAggregatePushdown()) {
      onMatchAggregatePushdown(call);
    } else {
      onMatchAggregateOnUnion(call);
    }
  }

  /**
   * Variant not performing a partial aggregation pushdown.
   *
   * <p>Original query:
   * <pre>{@code
   * SELECT job FROM "scott".emp WHERE deptno = 10
   * INTERSECT
   * SELECT job FROM "scott".emp WHERE deptno = 20
   * }</pre>
   *
   * <p>Query after conversion:
   * <pre>{@code
   * SELECT job
   * FROM (
   *   SELECT job, 0 AS i FROM "scott".emp WHERE deptno = 10
   *   UNION ALL
   *   SELECT job, 1 AS i FROM "scott".emp WHERE deptno = 20
   * )
   * GROUP BY job
   * HAVING COUNT(*) FILTER (WHERE i = 0) > 0
   *    AND COUNT(*) FILTER (WHERE i = 1) > 0
   * }</pre>
   *
   * <p>Does nothing if the matched intersect is {@code ALL}.
   *
   * @param call rule call whose first operand is the {@link Intersect}
   */
  public void onMatchAggregateOnUnion(RelOptRuleCall call) {
    final Intersect intersect = call.rel(0);
    if (intersect.all) {
      return; // nothing we can do
    }
    final RelBuilder relBuilder = call.builder();
    final int oriFieldCount = intersect.getRowType().getFieldCount();
    final int branchCount = intersect.getInputs().size();

    // For each branch, append a constant marker column "i" holding the branch
    // index, and build the matching COUNT(*) FILTER (WHERE i = <branch>).
    // The filter refers to the marker by position (oriFieldCount); that
    // position is the same in every branch and hence also above the union.
    final List<AggCall> aggCalls = new ArrayList<>(branchCount);
    for (int i = 0; i < branchCount; ++i) {
      relBuilder.push(intersect.getInputs().get(i));
      final List<RexNode> fields = new ArrayList<>(relBuilder.fields());
      fields.add(relBuilder.alias(relBuilder.literal(i), "i"));
      relBuilder.project(fields);
      aggCalls.add(
          relBuilder.countStar(null).filter(
              relBuilder.equals(relBuilder.field(oriFieldCount),
                  relBuilder.literal(i)))
              .as("count_i" + i));
    }

    // Union all the marked branches and group on the original columns. The
    // aggregate's output is (col_0, ..., col_{oriFieldCount - 1},
    // count_i0, ..., count_i{branchCount - 1}).
    relBuilder.union(true, branchCount)
        .aggregate(relBuilder.groupKey(ImmutableBitSet.range(oriFieldCount)), aggCalls);

    // Keep only rows that occur in every branch: count_i{n} > 0 for each n.
    final List<RexNode> filters = new ArrayList<>(branchCount);
    for (int i = 0; i < branchCount; i++) {
      filters.add(
          relBuilder.greaterThan(relBuilder.field(oriFieldCount + i),
              relBuilder.literal(0)));
    }
    relBuilder.filter(filters);

    // Project away the trailing branchCount count_i{n} columns, restoring the
    // intersect's original columns.
    relBuilder.project(skipLast(relBuilder.fields(), branchCount));
    call.transformTo(relBuilder.build());
  }

  /**
   * Variant performing a partial aggregation pushdown.
   *
   * <p>Original query:
   * <pre>{@code
   * SELECT job FROM "scott".emp WHERE deptno = 10
   * INTERSECT
   * SELECT job FROM "scott".emp WHERE deptno = 20
   * }</pre>
   *
   * <p>Query after conversion:
   * <pre>{@code
   * SELECT job
   * FROM (
   *   SELECT job, COUNT(*) AS c
   *   FROM (
   *     SELECT job, COUNT(*) FROM "scott".emp
   *     WHERE deptno = 10 GROUP BY job
   *     UNION ALL
   *     SELECT job, COUNT(*) FROM "scott".emp
   *     WHERE deptno = 20 GROUP BY job)
   *   GROUP BY job)
   * WHERE c = 2
   * }</pre>
   *
   * <p>Does nothing if the matched intersect is {@code ALL}.
   *
   * @param call rule call whose first operand is the {@link Intersect}
   */
  public void onMatchAggregatePushdown(RelOptRuleCall call) {
    final Intersect intersect = call.rel(0);
    if (intersect.all) {
      return; // nothing we can do
    }
    final RelOptCluster cluster = intersect.getCluster();
    final RexBuilder rexBuilder = cluster.getRexBuilder();
    final RelBuilder relBuilder = call.builder();

    // 1st level aggregate: for each branch, create
    // aggregate(col_0, ..., col_n, count(*)). Grouping on all columns makes
    // every branch distinct.
    for (RelNode input : intersect.getInputs()) {
      relBuilder.push(input);
      relBuilder.aggregate(relBuilder.groupKey(relBuilder.fields()),
          relBuilder.countStar(null));
    }

    // create a union above all the branches
    final int branchCount = intersect.getInputs().size();
    relBuilder.union(true, branchCount);
    final RelNode union = relBuilder.peek();

    // 2nd level aggregate: a single aggregate on top of the union that groups
    // by col_0, ..., col_n (every union column except the trailing per-branch
    // counter at index fieldCount - 1) and computes count(*). Because each
    // branch is already distinct, this count is the number of branches in
    // which the row occurs.
    final int fieldCount = union.getRowType().getFieldCount();
    final ImmutableBitSet groupSet =
        ImmutableBitSet.range(fieldCount - 1);
    relBuilder.aggregate(relBuilder.groupKey(groupSet),
        relBuilder.countStar(null));

    // add a filter count(*) = #branches, i.e. the row occurs in every branch
    relBuilder.filter(
        relBuilder.equals(relBuilder.field(fieldCount - 1),
            rexBuilder.makeBigintLiteral(BigDecimal.valueOf(branchCount))));

    // The relation built so far has the intersect's columns plus one trailing
    // count column; project all but that last field so the result matches
    // the intersect's schema.
    relBuilder.project(Util.skipLast(relBuilder.fields()));
    call.transformTo(relBuilder.build());
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    Config DEFAULT = ImmutableIntersectToDistinctRule.Config.of()
        .withOperandFor(LogicalIntersect.class);

    Config NO_AGGREGATE_PUSHDOWN = DEFAULT
        .withDescription("IntersectToDistinctRule(NoAggregatePushDown)")
        .as(Config.class)
        .withAggregatePushdown(false);

    @Override default IntersectToDistinctRule toRule() {
      return new IntersectToDistinctRule(this);
    }

    /** Defines an operand tree for the given classes. */
    default Config withOperandFor(Class<? extends Intersect> intersectClass) {
      return withOperandSupplier(b -> b.operand(intersectClass).anyInputs())
          .as(Config.class);
    }

    /** Whether to apply partial aggregate pushdown; default true. */
    @Value.Default default boolean isAggregatePushdown() {
      return true;
    }

    /** Sets {@link #isAggregatePushdown()}. */
    Config withAggregatePushdown(boolean aggregatePushdown);
  }
}