Match.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.core;

import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelTraitSet;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.RelWriter;
import org.apache.calcite.rel.SingleRel;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexPatternFieldRef;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.fun.SqlBitOpAggFunction;
import org.apache.calcite.sql.fun.SqlMinMaxAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.util.ImmutableBitSet;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.ImmutableSortedSet;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeMap;
import java.util.TreeSet;

import static com.google.common.base.Preconditions.checkArgument;

import static java.util.Objects.requireNonNull;

/**
 * Relational expression that represents a MATCH_RECOGNIZE node.
 *
 * <p>Each output row has the columns defined in the measure statements.
 */
public abstract class Match extends SingleRel {
  //~ Instance fields ---------------------------------------------
  /** Key under which aggregate calls that reference no pattern variable are
   * registered in {@link #aggregateCallsPreVar}. */
  private static final String STAR = "*";
  /** Measure definitions, copied from the constructor argument. */
  protected final ImmutableMap<String, RexNode> measures;
  /** Regular expression that defines pattern variables. */
  protected final RexNode pattern;
  /** Whether it is a strict start pattern. */
  protected final boolean strictStart;
  /** Whether it is a strict end pattern. */
  protected final boolean strictEnd;
  /** Whether all rows per match are output (false means one row per match). */
  protected final boolean allRows;
  /** After match definitions. */
  protected final RexNode after;
  /** Pattern definitions, copied from the constructor argument; never empty. */
  protected final ImmutableMap<String, RexNode> patternDefinitions;
  /** Aggregate calls found in the pattern definitions and the measures. */
  protected final Set<RexMRAggCall> aggregateCalls;
  /** Aggregate calls grouped by the pattern variable they reference; calls
   * that reference no pattern variable are grouped under {@code "*"}.
   *
   * <p>NOTE: the name reads "PreVar" although the content is "per variable";
   * it is left unchanged because the field is visible to subclasses. */
  protected final Map<String, SortedSet<RexMRAggCall>> aggregateCallsPreVar;
  /** Subsets of pattern variables. */
  protected final ImmutableMap<String, SortedSet<String>> subsets;
  /** Partition by columns. */
  protected final ImmutableBitSet partitionKeys;
  /** Order by columns. */
  protected final RelCollation orderKeys;
  /** Interval definition, null if WITHIN clause is not defined. */
  protected final @Nullable RexNode interval;

  //~ Constructors -----------------------------------------------

  /**
   * Creates a Match.
   *
   * <p>Besides storing its arguments, the constructor scans the pattern
   * definitions and then the measures for aggregate calls, and records them
   * both as a flat set and grouped by pattern variable.
   *
   * @param cluster Cluster
   * @param traitSet Trait set
   * @param input Input relational expression
   * @param rowType Row type
   * @param pattern Regular expression that defines pattern variables
   * @param strictStart Whether it is a strict start pattern
   * @param strictEnd Whether it is a strict end pattern
   * @param patternDefinitions Pattern definitions; must not be empty
   * @param measures Measure definitions
   * @param after After match definitions
   * @param subsets Subsets of pattern variables
   * @param allRows Whether all rows per match (false means one row per match)
   * @param partitionKeys Partition by columns
   * @param orderKeys Order by columns
   * @param interval Interval definition, null if WITHIN clause is not defined
   *
   * @throws NullPointerException if {@code rowType}, {@code pattern},
   *     {@code after}, {@code partitionKeys} or {@code orderKeys} is null
   * @throws IllegalArgumentException if {@code patternDefinitions} is empty
   */
  protected Match(RelOptCluster cluster, RelTraitSet traitSet, RelNode input,
      RelDataType rowType, RexNode pattern,
      boolean strictStart, boolean strictEnd,
      Map<String, RexNode> patternDefinitions, Map<String, RexNode> measures,
      RexNode after, Map<String, ? extends SortedSet<String>> subsets,
      boolean allRows, ImmutableBitSet partitionKeys, RelCollation orderKeys,
      @Nullable RexNode interval) {
    super(cluster, traitSet, input);

    // Validated or copied arguments, in the order the checks must fire.
    this.rowType = requireNonNull(rowType, "rowType");
    this.pattern = requireNonNull(pattern, "pattern");
    checkArgument(!patternDefinitions.isEmpty());
    this.patternDefinitions = ImmutableMap.copyOf(patternDefinitions);
    this.measures = ImmutableMap.copyOf(measures);
    this.after = requireNonNull(after, "after");
    this.subsets = copyMap(subsets);
    this.partitionKeys = requireNonNull(partitionKeys, "partitionKeys");
    this.orderKeys = requireNonNull(orderKeys, "orderKeys");

    // Plain flags and the optional interval are stored without validation.
    this.strictStart = strictStart;
    this.strictEnd = strictEnd;
    this.allRows = allRows;
    this.interval = interval;

    // Collect aggregate calls. Pattern definitions are visited before
    // measures because each new call's ordinal is the number of calls
    // collected so far.
    final AggregateFinder finder = new AggregateFinder();
    for (RexNode definition : this.patternDefinitions.values()) {
      if (!(definition instanceof RexCall)) {
        continue;
      }
      finder.go((RexCall) definition);
    }
    for (RexNode measure : this.measures.values()) {
      if (!(measure instanceof RexCall)) {
        continue;
      }
      finder.go((RexCall) measure);
    }

    this.aggregateCalls = ImmutableSortedSet.copyOf(finder.aggregateCalls);
    this.aggregateCallsPreVar = copyMap(finder.aggregateCallsPerVar);
  }

  /**
   * Creates an immutable, naturally ordered map whose values are immutable
   * sorted-set copies of the values of the given map.
   *
   * @param map Map from key to sorted set
   * @param <K> Key type
   * @param <V> Element type of the sorted sets
   * @return Immutable copy of {@code map}
   */
  private static <K extends Comparable<K>, V>
      ImmutableSortedMap<K, SortedSet<V>>
      copyMap(Map<K, ? extends SortedSet<V>> map) {
    final ImmutableSortedMap.Builder<K, SortedSet<V>> builder =
        ImmutableSortedMap.naturalOrder();
    map.forEach((key, values) ->
        builder.put(key, ImmutableSortedSet.copyOf(values)));
    return builder.build();
  }

  //~ Methods --------------------------------------------------

  /** Returns the measure definitions. */
  public ImmutableMap<String, RexNode> getMeasures() {
    return this.measures;
  }

  /** Returns the after match definitions. */
  public RexNode getAfter() {
    return this.after;
  }

  /** Returns the regular expression that defines pattern variables. */
  public RexNode getPattern() {
    return this.pattern;
  }

  /** Returns whether it is a strict start pattern. */
  public boolean isStrictStart() {
    return this.strictStart;
  }

  /** Returns whether it is a strict end pattern. */
  public boolean isStrictEnd() {
    return this.strictEnd;
  }

  /** Returns whether all rows per match are output (false means one row per
   * match). */
  public boolean isAllRows() {
    return this.allRows;
  }

  /** Returns the pattern definitions. */
  public ImmutableMap<String, RexNode> getPatternDefinitions() {
    return this.patternDefinitions;
  }

  /** Returns the subsets of pattern variables. */
  public ImmutableMap<String, SortedSet<String>> getSubsets() {
    return this.subsets;
  }

  /** Returns the partition by columns. */
  public ImmutableBitSet getPartitionKeys() {
    return this.partitionKeys;
  }

  /** Returns the order by columns. */
  public RelCollation getOrderKeys() {
    return this.orderKeys;
  }

  /** Returns the interval definition, or null if the WITHIN clause is not
   * defined. */
  public @Nullable RexNode getInterval() {
    return this.interval;
  }

  /**
   * Adds the attributes of this node to the given writer, after the terms
   * contributed by the superclass.
   *
   * <p>The {@code interval} item is only written when an interval is present.
   *
   * @param pw Writer to describe this node to
   * @return Writer returned by the last item call
   */
  @Override public RelWriter explainTerms(RelWriter pw) {
    RelWriter writer = super.explainTerms(pw);
    writer = writer.item("partition", getPartitionKeys().asList());
    writer = writer.item("order", getOrderKeys());
    writer = writer.item("outputFields", getRowType().getFieldNames());
    writer = writer.item("allRows", isAllRows());
    writer = writer.item("after", getAfter());
    writer = writer.item("pattern", getPattern());
    writer = writer.item("isStrictStarts", isStrictStart());
    writer = writer.item("isStrictEnds", isStrictEnd());
    writer = writer.itemIf("interval", getInterval(), getInterval() != null);
    writer = writer.item("subsets", getSubsets().values().asList());
    writer =
        writer.item("patternDefinitions",
            getPatternDefinitions().values().asList());
    return writer.item("inputFields", getInput().getRowType().getFieldNames());
  }

  /**
   * Finds aggregate functions in operands.
   *
   * <p>Every aggregate call found is recorded twice: in
   * {@link #aggregateCalls}, and in {@link #aggregateCallsPerVar} under each
   * pattern variable its operands reference (or under {@code "*"} if they
   * reference none). Both collections are tree sets of {@link RexMRAggCall},
   * so a call that compares equal to one already present is not added again.
   */
  private static class AggregateFinder extends RexVisitorImpl<Void> {
    /** All aggregate calls found so far. */
    final NavigableSet<RexMRAggCall> aggregateCalls = new TreeSet<>();
    /** Aggregate calls found so far, grouped by pattern variable. */
    final Map<String, NavigableSet<RexMRAggCall>> aggregateCallsPerVar =
        new TreeMap<>();

    AggregateFinder() {
      super(true);
    }

    /**
     * Records {@code call} if its kind is one of the supported aggregates;
     * otherwise visits its operands. The operands of a recognized aggregate
     * are only scanned for pattern variables, not for further aggregates.
     *
     * @param call Call to inspect
     * @return Always null
     */
    @Override public Void visitCall(RexCall call) {
      final SqlAggFunction aggFunction;
      switch (call.getKind()) {
      case SUM:
        aggFunction = new SqlSumAggFunction(call.getType());
        break;
      case SUM0:
        aggFunction = new SqlSumEmptyIsZeroAggFunction();
        break;
      case MAX:
      case MIN:
        aggFunction = new SqlMinMaxAggFunction(call.getKind());
        break;
      case COUNT:
        aggFunction = SqlStdOperatorTable.COUNT;
        break;
      case ANY_VALUE:
        aggFunction = SqlStdOperatorTable.ANY_VALUE;
        break;
      case BIT_AND:
      case BIT_OR:
      case BIT_XOR:
        aggFunction = new SqlBitOpAggFunction(call.getKind());
        break;
      default:
        // Not an aggregate: keep looking inside the operands.
        visitEach(call.operands);
        return null;
      }

      // The ordinal is the number of distinct aggregate calls seen so far.
      final RexMRAggCall aggCall =
          new RexMRAggCall(aggFunction, call.getType(), call.getOperands(),
              aggregateCalls.size());
      aggregateCalls.add(aggCall);

      final Set<String> patternVars =
          new PatternVarFinder().go(call.getOperands());
      if (patternVars.isEmpty()) {
        patternVars.add(STAR);
      }
      for (String alpha : patternVars) {
        // TreeSet.add rejects an element that compares equal to an existing
        // one, and RexMRAggCall's compareTo and equals both compare
        // toString(), so no explicit duplicate scan is needed.
        aggregateCallsPerVar
            .computeIfAbsent(alpha, k -> new TreeSet<>())
            .add(aggCall);
      }
      return null;
    }

    /** Starts the search at the given call. */
    public void go(RexCall call) {
      call.accept(this);
    }
  }

  /**
   * Collects the names of the pattern variables referenced by the operands
   * of an aggregate call.
   *
   * <p>Both {@code go} methods return this visitor's own mutable set, not a
   * copy.
   */
  private static class PatternVarFinder extends RexVisitorImpl<Void> {
    /** Pattern variable names seen so far. */
    final Set<String> foundVars = new HashSet<>();

    PatternVarFinder() {
      super(true);
    }

    /** Visits one expression and returns the pattern variables seen so far. */
    public Set<String> go(RexNode rex) {
      rex.accept(this);
      return foundVars;
    }

    /** Visits each expression in the list and returns the pattern variables
     * seen so far. */
    public Set<String> go(List<RexNode> rexNodeList) {
      visitEach(rexNodeList);
      return foundVars;
    }

    /** Visits the operands of the call. */
    @Override public Void visitCall(RexCall call) {
      visitEach(call.operands);
      return null;
    }

    /** Records the pattern variable that the field reference belongs to. */
    @Override public Void visitPatternFieldRef(RexPatternFieldRef fieldRef) {
      final String alpha = fieldRef.getAlpha();
      foundVars.add(alpha);
      return null;
    }
  }

  /**
   * Aggregate calls in match recognize.
   *
   * <p>Ordering, equality and hash code are all derived from
   * {@link #toString()}, so they are consistent with one another;
   * {@link #ordinal} is not compared separately.
   */
  public static final class RexMRAggCall extends RexCall
      implements Comparable<RexMRAggCall> {
    /** Ordinal given to this call when it was created. */
    public final int ordinal;

    RexMRAggCall(SqlAggFunction aggFun,
        RelDataType type,
        List<RexNode> operands,
        int ordinal) {
      super(type, aggFun, operands);
      this.ordinal = ordinal;
      digest = toString(); // can compute here because class is final
    }

    /** Orders calls by their string form. */
    @Override public int compareTo(RexMRAggCall o) {
      final String mine = toString();
      final String theirs = o.toString();
      return mine.compareTo(theirs);
    }

    /** Two calls are equal if both are {@code RexMRAggCall}s with the same
     * string form. */
    @Override public boolean equals(@Nullable Object obj) {
      if (obj == this) {
        return true;
      }
      if (!(obj instanceof RexMRAggCall)) {
        return false;
      }
      return toString().equals(obj.toString());
    }

    /** Hash code of the string form. */
    @Override public int hashCode() {
      final String text = toString();
      return text.hashCode();
    }
  }
}