MeasureRules.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Sort;
import org.apache.calcite.rel.metadata.BuiltInMetadata;
import org.apache.calcite.rel.metadata.RelMdMeasure;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlInternalOperators;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.MonotonicSupplier;
import org.apache.calcite.util.Util;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;

import static com.google.common.collect.Iterables.getOnlyElement;

/**
 * Collection of planner rules that deal with measures.
 *
 * <p>A typical rule pushes down {@code M2V(measure)}
 * until it reaches a {@code V2M(expression)}.
 *
 * @see org.apache.calcite.sql.fun.SqlInternalOperators#M2V
 * @see org.apache.calcite.sql.fun.SqlInternalOperators#V2M
 */
public abstract class MeasureRules {

  /** Private constructor; the members of this class are accessed
   * statically. */
  private MeasureRules() { }

  /** Returns the rules that this class offers as a set: {@link #AGGREGATE2},
   * {@link #PROJECT} and {@link #PROJECT_SORT}, in that order.
   *
   * <p>{@link #AGGREGATE} and {@link #FILTER_SORT} are not part of the
   * returned collection. */
  public static Iterable<? extends RelOptRule> rules() {
    return ImmutableList.<RelOptRule>builder()
        .add(AGGREGATE2)
        .add(PROJECT)
        .add(PROJECT_SORT)
        .build();
  }

  /** Instance of {@link AggregateMeasureRule} created from
   * {@link AggregateMeasureRuleConfig#DEFAULT}.
   *
   * <p>Matches an {@link Aggregate} that contains at least one
   * {@code AGG_M2V} call and pushes that call down into a {@link Project}
   * below the aggregate. */
  public static final RelOptRule AGGREGATE =
      AggregateMeasureRuleConfig.DEFAULT.toRule();

  /** Configuration for {@link AggregateMeasureRule}.
   *
   * <p>The default operand is an {@link Aggregate}, with any inputs, for
   * which at least one aggregate call uses
   * {@link SqlInternalOperators#AGG_M2V}. */
  @Value.Immutable
  public interface AggregateMeasureRuleConfig extends RelRule.Config {
    AggregateMeasureRuleConfig DEFAULT =
        ImmutableAggregateMeasureRuleConfig.of()
            .withOperandSupplier(operandBuilder ->
                operandBuilder.operand(Aggregate.class)
                    .predicate(aggregate ->
                        aggregate.getAggCallList().stream()
                            .anyMatch(aggCall ->
                                SqlInternalOperators.AGG_M2V
                                    == aggCall.getAggregation()))
                    .anyInputs());

    /** Creates the rule that is driven by this configuration. */
    @Override default AggregateMeasureRule toRule() {
      return new AggregateMeasureRule(this);
    }
  }

  /** Rule that matches an {@link Aggregate} with at least one call to
   * {@link SqlInternalOperators#AGG_M2V} and rewrites each such call as a
   * {@code SINGLE_VALUE} over a {@link SqlInternalOperators#M2X} column that
   * is appended to the input.
   *
   * <p>Converts
   *
   * <pre>{@code
   * Aggregate(a, b, AGG_M2V(c), SUM(d), AGG_M2V(e))
   *   R
   * }</pre>
   *
   * <p>to
   *
   * <pre>{@code
   * Aggregate(a, b, SINGLE_VALUE(x_c), SUM(d), SINGLE_VALUE(x_e))
   *   Project(<all fields of R>,
   *        M2X(c, SAME_PARTITION(a, b)) AS x_c,
   *        M2X(e, SAME_PARTITION(a, b)) AS x_e)
   *     R
   * }</pre>
   *
   * <p>(The names {@code x_c} and {@code x_e} are only illustrative.)
   *
   * <p>We rely on those {@code M2X} calls being pushed down until they merge
   * with {@code V2M} and {@link ProjectMeasureRule} can apply.
   *
   * @see MeasureRules#AGGREGATE
   * @see AggregateMeasureRuleConfig */
  @SuppressWarnings("WeakerAccess")
  public static class AggregateMeasureRule
      extends RelRule<AggregateMeasureRuleConfig>
      implements TransformationRule {
    /** Creates an AggregateMeasureRule. */
    protected AggregateMeasureRule(AggregateMeasureRuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final Aggregate aggregate = call.rel(0);
      final RelBuilder b = call.builder();
      b.push(aggregate.getInput());

      // Aggregate calls are recorded as functions of a builder; they are only
      // materialized after the extra M2X columns have been projected.
      final List<Function<RelBuilder, RelBuilder.AggCall>> deferredAggCalls =
          new ArrayList<>();
      final List<RexNode> m2xProjects = new ArrayList<>();
      aggregate.getAggCallList().forEach(aggCall -> {
        if (aggCall.getAggregation().kind != SqlKind.AGG_M2V) {
          // Not a measure call: carry it over unchanged.
          deferredAggCalls.add(builder -> builder.aggregateCall(aggCall));
          return;
        }
        final int measureOrdinal = getOnlyElement(aggCall.getArgList());
        // Ordinal that the new M2X column will have once it has been appended
        // after the input's fields and the previously added M2X columns.
        final int m2xOrdinal = b.fields().size() + m2xProjects.size();
        final RexNode measureRef = b.field(measureOrdinal);
        final RexNode samePartition =
            b.call(SqlInternalOperators.SAME_PARTITION,
                b.fields(aggregate.getGroupSet()));
        m2xProjects.add(
            b.call(SqlInternalOperators.M2X, measureRef, samePartition));
        deferredAggCalls.add(builder ->
            builder.aggregateCall(SqlStdOperatorTable.SINGLE_VALUE,
                builder.field(m2xOrdinal)));
      });

      b.projectPlus(m2xProjects);
      final RelBuilder.GroupKey groupKey =
          b.groupKey(aggregate.getGroupSet(), aggregate.groupSets);
      final List<RelBuilder.AggCall> aggCalls =
          bind(deferredAggCalls).apply(b);
      b.aggregate(groupKey, aggCalls);
      call.transformTo(b.build());
    }

    /** Turns a list of functions into one function that applies each of them
     * to its argument and returns the results, in order, as an immutable
     * list. It is named after the Monad bind operator. */
    private static <T, E> Function<T, List<E>> bind(List<Function<T, E>> list) {
      return input -> {
        final ImmutableList.Builder<E> results = ImmutableList.builder();
        for (Function<T, E> fn : list) {
          results.add(fn.apply(input));
        }
        return results.build();
      };
    }
  }

  /** Instance of {@link ProjectMeasureRule} created from
   * {@link ProjectMeasureRuleConfig#DEFAULT}.
   *
   * <p>Merges an {@link Aggregate} onto a {@code Project} that contains a
   * {@code M2X} call. */
  // TODO rename field and class
  public static final RelOptRule PROJECT =
      ProjectMeasureRuleConfig.DEFAULT.toRule();

  /** Configuration for {@link ProjectMeasureRule}.
   *
   * <p>The default operand is an {@link Aggregate} all of whose aggregate
   * calls are {@code SINGLE_VALUE}, on top of a {@link Project} that contains
   * at least one {@code M2X} call. {@code M2X} is the expression kind that
   * {@link ProjectMeasureRule} unwraps, and is what the documentation of
   * {@link MeasureRules#PROJECT} promises. */
  @Value.Immutable
  public interface ProjectMeasureRuleConfig extends RelRule.Config {
    ProjectMeasureRuleConfig DEFAULT = ImmutableProjectMeasureRuleConfig.of()
        .withOperandSupplier(b ->
            b.operand(Aggregate.class)
                .predicate(aggregate ->
                    aggregate.getAggCallList().stream().allMatch(c ->
                        c.getAggregation() == SqlStdOperatorTable.SINGLE_VALUE))
                .oneInput(b2 ->
                    b2.operand(Project.class)
                        // Look for M2X (not V2M): the rule only knows how to
                        // convert SINGLE_VALUE over an M2X call.
                        .predicate(RexUtil.find(SqlKind.M2X)::inProject)
                        .anyInputs()));

    /** Creates the rule that is driven by this configuration. */
    @Override default ProjectMeasureRule toRule() {
      return new ProjectMeasureRule(this);
    }
  }

  /** Instance of {@link AggregateMeasure2Rule} created from
   * {@link AggregateMeasure2RuleConfig#DEFAULT}.
   *
   * <p>Matches an {@link Aggregate} that contains at least one
   * {@code AGG_M2V} call and expands each such call by asking the measure
   * for its expression. */
  public static final RelOptRule AGGREGATE2 =
      AggregateMeasure2RuleConfig.DEFAULT.toRule();

  /** Configuration for {@link AggregateMeasure2Rule}.
   *
   * <p>The default operand is an {@link Aggregate}, with any inputs, for
   * which at least one aggregate call uses
   * {@link SqlInternalOperators#AGG_M2V}. */
  @Value.Immutable
  public interface AggregateMeasure2RuleConfig extends RelRule.Config {
    AggregateMeasure2RuleConfig DEFAULT =
        ImmutableAggregateMeasure2RuleConfig.of()
            .withOperandSupplier(operandBuilder ->
                operandBuilder.operand(Aggregate.class)
                    .predicate(aggregate ->
                        aggregate.getAggCallList().stream()
                            .anyMatch(aggCall ->
                                SqlInternalOperators.AGG_M2V
                                    == aggCall.getAggregation()))
                    .anyInputs());

    /** Creates the rule that is driven by this configuration. */
    @Override default AggregateMeasure2Rule toRule() {
      return new AggregateMeasure2Rule(this);
    }
  }

  /** Rule that matches an {@link Aggregate} with at least one call to
   * {@link SqlInternalOperators#AGG_M2V} and expands these calls by
   * asking the measure for its expression.
   *
   * <p>Converts
   *
   * <pre>{@code
   * Aggregate(a, b, AGG_M2V(c), SUM(d), AGG_M2V(e))
   *   R
   * }</pre>
   *
   * <p>to
   *
   * <pre>{@code
   * Project(a, b, RexSubQuery(...), sum_d, RexSubQuery(...))
   *   Aggregate(a, b, AGG_M2M(c), SUM(d) AS sum_d, AGG_M2M(e))
   *     R
   * }</pre>
   *
   * <p>The output field names of the new {@code Project} are those of the
   * original {@code Aggregate}.
   *
   * <p>We will optimize those {@link org.apache.calcite.rex.RexSubQuery}
   * later. For example,
   *
   * <pre>{@code
   * SELECT deptno,
   *     (SELECT AVG(sal)
   *      FROM emp
   *      WHERE deptno = e.deptno)
   * FROM Emp
   * }</pre>
   *
   * <p>will become
   *
   * <pre>{@code
   * SELECT deptno, AVG(sal)
   * FROM emp
   * WHERE deptno = e.deptno
   * }</pre>
   *
   * @see org.apache.calcite.rel.metadata.RelMdMeasure
   * @see MeasureRules#AGGREGATE2
   * @see AggregateMeasure2RuleConfig */
  @SuppressWarnings("WeakerAccess")
  public static class AggregateMeasure2Rule
      extends RelRule<AggregateMeasure2RuleConfig>
      implements TransformationRule {
    /** Creates an AggregateMeasure2Rule. */
    protected AggregateMeasure2Rule(AggregateMeasure2RuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final RelMetadataQuery mq = call.getMetadataQuery();
      final Aggregate aggregate = call.rel(0);
      final RelBuilder b = call.builder();
      b.push(aggregate.getInput());
      final MonotonicSupplier<RexCorrelVariable> holder =
          MonotonicSupplier.empty();
      // Both lists hold deferred expressions: the aggregate calls are
      // evaluated against the input, the projects against the new Aggregate.
      final List<Function<RelBuilder, RelBuilder.AggCall>> aggCallList =
          new ArrayList<>();
      final List<Function<RelBuilder, RexNode>> projects = new ArrayList<>();
      b.variable(holder)
          .let(b2 -> {
            // The projects are evaluated on top of the new Aggregate, whose
            // first fields are the group keys in ordinals
            // 0 .. groupCount - 1. Reference them by that output ordinal,
            // not by their bit position in the group set (which is an
            // ordinal of the Aggregate's input and differs whenever the
            // group set is not a prefix such as {0, 1}).
            final int groupCount = aggregate.getGroupSet().cardinality();
            for (int k = 0; k < groupCount; k++) {
              final int keyOrdinal = k;
              projects.add(b4 -> b4.field(keyOrdinal));
            }
            // Memoize the RelBuilder so we don't create more than one.
            @SuppressWarnings("FunctionalExpressionCanBeFolded")
            final Supplier<RelBuilder> builderSupplier =
                Suppliers.memoize(call::builder)::get;
            final BuiltInMetadata.Measure.Context context =
                RelMdMeasure.Contexts.forAggregate(aggregate, builderSupplier, holder.get());
            aggregate.getAggCallList().forEach(c -> {
              if (c.getAggregation().kind == SqlKind.AGG_M2V) {
                final int arg = getOnlyElement(c.getArgList());
                // Keep a placeholder AGG_M2M call in the Aggregate so that
                // the ordinals of the other aggregate calls are unchanged.
                aggCallList.add(b3 ->
                    b3.aggregateCall(SqlInternalOperators.AGG_M2M,
                        b3.fields(c.getArgList()))
                        .filter(c.filterArg < 0 ? null : b3.field(c.filterArg)));
                // Context that adds this call's FILTER argument, if any, to
                // the filters supplied by the aggregate-level context.
                final BuiltInMetadata.Measure.Context context2 =
                    new RelMdMeasure.DelegatingContext(context) {
                      @Override public List<RexNode> getFilters(RelBuilder b) {
                        final ImmutableList.Builder<RexNode> builder =
                            ImmutableList.builder();
                        builder.addAll(super.getFilters(b));
                        if (c.filterArg >= 0) {
                          builder.add(b.field(c.filterArg));
                        }
                        return builder.build();
                      }
                    };
                // NOTE(review): 'arg' is an ordinal in the original
                // Aggregate's input, whereas b4.peek() is the newly built
                // Aggregate; confirm that this pairing is what
                // RelMetadataQuery.expand expects.
                projects.add(b4 -> mq.expand(b4.peek(), arg, context2));
              } else {
                // Output ordinal of this call in the new Aggregate: after the
                // group keys and the calls added so far.
                final int i =
                    aggregate.getGroupSet().cardinality() + aggCallList.size();
                aggCallList.add(b3 ->
                    b3.aggregateCall(c)
                        .filter(c.filterArg < 0 ? null : b3.field(c.filterArg)));
                projects.add(b4 -> b4.field(i));
              }
            });
            return b2;
          });
      b.aggregate(b.groupKey(aggregate.getGroupSet(), aggregate.groupSets),
          bind(aggCallList).apply(b));
      b.project(bind(projects).apply(b), aggregate.getRowType().getFieldNames(),
          false, ImmutableSet.of(holder.get().id));
      call.transformTo(b.build());
    }

    /** Converts a list of functions into a function that returns a list.
     * It is named after the Monad bind operator. */
    private static <T, E> Function<T, List<E>> bind(List<Function<T, E>> list) {
      return t -> {
        final ImmutableList.Builder<E> builder = ImmutableList.builder();
        list.forEach(f -> builder.add(f.apply(t)));
        return builder.build();
      };
    }
  }

  /** Rule that merges an {@link Aggregate} onto a {@link Project}.
   *
   * <p>Converts
   *
   * <pre>{@code
   * Aggregate(a, b, SINGLE_VALUE(d) AS e)
   *   Project(a, b, M2X(V2M(SUM(c) + 1), SAME_PARTITION(a, b)) AS d)
   *     R
   * }</pre>
   *
   * <p>to
   *
   * <pre>{@code
   * Project(a, b, sum_c + 1 AS e)
   *   Aggregate(a, b, SUM(c) AS sum_c)
   *     R
   * }</pre>
   *
   * <p>Every aggregate call must be a {@code SINGLE_VALUE} whose argument is
   * a project expression of the form {@code M2X(V2M(expr), ...)}; otherwise
   * {@link #onMatch} throws {@link UnsupportedOperationException}.
   *
   * @see ProjectMeasureRuleConfig */
  @SuppressWarnings("WeakerAccess")
  public static class ProjectMeasureRule
      extends RelRule<ProjectMeasureRuleConfig>
      implements TransformationRule {
    /** Creates a ProjectMeasureRule. */
    protected ProjectMeasureRule(ProjectMeasureRuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final Aggregate aggregate = call.rel(0);
      final Project project = call.rel(1);
      final RelBuilder b = call.builder();
      // NOTE(review): the expressions returned by toRex are taken from inside
      // the Project's own expressions but are applied on top of the Project;
      // confirm that their input references are meant to be resolved there.
      b.push(project)
          .aggregateRex(
              b.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()),
              true,
              Util.transform(aggregate.getAggCallList(),
                  aggregateCall -> toRex(aggregateCall, project)));
      call.transformTo(b.build());
    }

    /** Returns the expression wrapped by the {@code M2X(V2M(expr), ...)}
     * project expression that a {@code SINGLE_VALUE} aggregate call refers
     * to.
     *
     * @param aggregateCall Aggregate call; must be {@code SINGLE_VALUE} with
     *     exactly one argument
     * @param project Project whose expressions the call's argument indexes
     * @return the operand {@code expr} of the inner {@code V2M} call
     * @throws UnsupportedOperationException if the call is not
     *     {@code SINGLE_VALUE}, or its argument is not {@code M2X} applied to
     *     {@code V2M} */
    @SuppressWarnings("SwitchStatementWithTooFewBranches")
    private static RexNode toRex(AggregateCall aggregateCall, Project project) {
      switch (aggregateCall.getAggregation().kind) {
      case SINGLE_VALUE:
        final int arg = getOnlyElement(aggregateCall.getArgList());
        final RexNode e = project.getProjects().get(arg);
        switch (e.getKind()) {
        case M2X:
          final RexCall callM2x = (RexCall) e;
          switch (callM2x.operands.get(0).getKind()) {
          case V2M:
            final RexCall callV2m = (RexCall) callM2x.operands.get(0);
            return callV2m.operands.get(0);
          default:
            throw new UnsupportedOperationException(
                "expected V2M as first operand of M2X, but was: " + callM2x);
          }
        default:
          throw new UnsupportedOperationException(
              "expected M2X call as argument of SINGLE_VALUE, but was: " + e);
        }
      default:
        throw new UnsupportedOperationException(
            "expected SINGLE_VALUE aggregate call, but was: " + aggregateCall);
      }
    }
  }

  /** Instance of {@link FilterSortMeasureRule} created from
   * {@link FilterSortMeasureRuleConfig#DEFAULT}.
   *
   * <p>Matches a {@link Filter} on top of a {@link Sort}; intended for
   * filters that contain a {@code M2V} call, which is to be pushed down. */
  public static final RelOptRule FILTER_SORT =
      FilterSortMeasureRuleConfig.DEFAULT
          .as(FilterSortMeasureRuleConfig.class).toRule();

  /** Configuration for {@link FilterSortMeasureRule}.
   *
   * <p>The default operand is a {@link Filter} whose single input is a
   * {@link Sort} with any inputs. No predicate is applied, so the filter is
   * matched whether or not its condition contains a {@code M2V} call. */
  @Value.Immutable
  public interface FilterSortMeasureRuleConfig extends RelRule.Config {
    FilterSortMeasureRuleConfig DEFAULT =
        ImmutableFilterSortMeasureRuleConfig.of()
            .withOperandSupplier(filterOperand ->
                filterOperand.operand(Filter.class)
                    .oneInput(sortOperand ->
                        sortOperand.operand(Sort.class).anyInputs()));

    /** Creates the rule that is driven by this configuration. */
    @Override default FilterSortMeasureRule toRule() {
      return new FilterSortMeasureRule(this);
    }
  }

  /** Rule that matches a {@link Filter} that contains a {@code M2V} call
   * on top of a {@link Sort} and pushes down the {@code M2V} call.
   *
   * <p>NOTE(review): as written, {@link #onMatch} never produces a
   * transformation. It compares the filter's condition with itself and
   * returns when they are equal, so the code after the guard is not reached.
   * The push-down described above looks unimplemented; confirm the intended
   * rewrite before relying on this rule.
   *
   * @see MeasureRules#FILTER_SORT
   * @see FilterSortMeasureRuleConfig */
  @SuppressWarnings("WeakerAccess")
  public static class FilterSortMeasureRule
      extends RelRule<FilterSortMeasureRuleConfig>
      implements TransformationRule {
    /** Creates a FilterSortMeasureRule. */
    protected FilterSortMeasureRule(FilterSortMeasureRuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final Filter filter = call.rel(0);
      // 'condition' is not rewritten before the check below, so the
      // comparison is between the filter's condition and itself.
      final RexNode condition = filter.getCondition();
      if (condition.equals(filter.getCondition())) {
        // Nothing changed; do not register an equivalent expression.
        return;
      }
      // Would rebuild the same filter on the same input, if reached.
      final RelBuilder relBuilder =
          relBuilderFactory.create(filter.getCluster(), null);
      relBuilder.push(filter.getInput())
          .filter(condition);
      call.transformTo(relBuilder.build());
    }
  }

  /** Instance of {@link ProjectSortMeasureRule} created from
   * {@link ProjectSortMeasureRuleConfig#DEFAULT}.
   *
   * <p>Matches a {@link Project} that contains a {@code M2V} call and sits
   * on top of a {@link Sort}, and pushes the {@code M2V} call below the
   * sort. */
  public static final RelOptRule PROJECT_SORT =
      ProjectSortMeasureRuleConfig.DEFAULT
          .as(ProjectSortMeasureRuleConfig.class).toRule();

  /** Rule that matches a {@link Project} that contains an {@code M2V} call
   * on top of a {@link Sort} and pushes down the {@code M2V} call.
   *
   * <p>Each distinct {@code M2V} call is evaluated in a new {@link Project}
   * below the {@code Sort}, appended after all fields of the sort's input,
   * and the original project expressions refer to it by input reference.
   *
   * @see MeasureRules#PROJECT_SORT */
  @SuppressWarnings("WeakerAccess")
  public static class ProjectSortMeasureRule
      extends RelRule<ProjectSortMeasureRuleConfig>
      implements TransformationRule {
    /** Creates a ProjectSortMeasureRule. */
    protected ProjectSortMeasureRule(ProjectSortMeasureRuleConfig config) {
      super(config);
    }

    @Override public void onMatch(RelOptRuleCall call) {
      final Project project = call.rel(0);
      final Sort sort = call.rel(1);
      final RelBuilder relBuilder = call.builder();

      // Given (R has fields $0, $1, $2)
      //   Project [$0, 1 + M2V($2)]  (a)
      //     Sort $1 desc
      //       R
      // transform to
      //   Project [$0, 1 + $3]  (b)
      //     Sort $1 desc
      //       Project [$0, $1, $2, M2V($2)]  (c)
      //         R
      //
      // projects is [$0, 1 + M2V($2)] (see a)
      // newProjects is [$0, 1 + $3] (see b)
      // map.keys() is [M2V($2)] (the expressions appended in c)

      // The pulled-down M2V calls are appended after the fields of the
      // Sort's input, so their ordinals start at that input's field count
      // (not at the number of expressions in the matched Project).
      final int inputFieldCount =
          sort.getInput().getRowType().getFieldCount();
      final List<RexNode> projects = project.getAliasedProjects(relBuilder);
      // Insertion-ordered, so the n-th distinct M2V call becomes field
      // inputFieldCount + n of the new bottom Project.
      final Map<RexCall, RexInputRef> map = new LinkedHashMap<>();
      final List<RexNode> newProjects =
          new RexShuttle() {
            @Override public RexNode visitCall(RexCall rexCall) {
              if (rexCall.getKind() == SqlKind.M2V) {
                return map.computeIfAbsent(rexCall, c ->
                    relBuilder.getRexBuilder().makeInputRef(c.getType(),
                        inputFieldCount + map.size()));
              }
              return super.visitCall(rexCall);
            }
          }.apply(projects);

      relBuilder.push(sort.getInput())
          .projectPlus(map.keySet())
          .sortLimit(sort.offset == null ? 0 : RexLiteral.intValue(sort.offset),
              sort.fetch == null ? -1 : RexLiteral.intValue(sort.fetch),
              sort.getSortExps())
          .project(newProjects);
      call.transformTo(relBuilder.build());
    }
  }

  /** Configuration for {@link ProjectSortMeasureRule}.
   *
   * <p>The default operand is a {@link Project} accepted by
   * {@code RexUtil.M2V_FINDER}, whose single input is a {@link Sort} with
   * any inputs. */
  @Value.Immutable
  public interface ProjectSortMeasureRuleConfig extends RelRule.Config {
    ProjectSortMeasureRuleConfig DEFAULT =
        ImmutableProjectSortMeasureRuleConfig.of()
            .withOperandSupplier(projectOperand ->
                projectOperand.operand(Project.class)
                    .predicate(RexUtil.M2V_FINDER::inProject)
                    .oneInput(sortOperand ->
                        sortOperand.operand(Sort.class).anyInputs()));

    /** Creates the rule that is driven by this configuration. */
    @Override default ProjectSortMeasureRule toRule() {
      return new ProjectSortMeasureRule(this);
    }
  }

}