UnnestDecorrelateRule.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.rules;

import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Correlate;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Uncollect;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;

import org.immutables.value.Value;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static java.util.Objects.requireNonNull;

/**
 * Rule that converts a projected UNNEST expressed with a correlate into a
 * plain {@link Uncollect} applied directly to the left input of the correlate.
 *
 * <p>Original plan:
 *
 * <pre>{@code
 * LogicalProject // outerProject; only uses rightmost columns of correlate
 *   LogicalCorrelate(correlation=[$cor0], joinType=[inner], requiredColumns=[{...}])
 *     LeftSubquery
 *     LogicalProject (optional; innerProject)
 *       Uncollect
 *         LogicalProject(COL=[$cor0.ARRAY])
 *           LogicalValues(tuples=[[{ 0 }]])
 * }</pre>
 *
 * <p>is converted to
 *
 * <pre>{@code
 * LogicalProject
 *   LogicalProject (optional)
 *     Uncollect
 *       LogicalProject
 *         LeftSubquery
 * }</pre>
 *
 * <p>The rule leaves the plan unchanged unless all of the following hold:
 * <ul>
 *   <li>the correlate is an inner join;
 *   <li>the outer project references no column of the left input;
 *   <li>the optional inner project contains no correlation variable;
 *   <li>the project below the {@code Uncollect} has exactly one expression, of
 *       the form {@code $cor.field1.field2...}, over a single-row
 *       {@code LogicalValues}.
 * </ul>
 */
@Value.Enclosing
public class UnnestDecorrelateRule extends RelRule<UnnestDecorrelateRule.Config>
    implements TransformationRule {

  /** Creates an UnnestDecorrelateRule.
   *
   * @param config Rule configuration (operand tree to match)
   */
  protected UnnestDecorrelateRule(UnnestDecorrelateRule.Config config) {
    super(config);
  }

  /** Given an expression and a correlationId, find whether the expression is a
   * sequence of field accesses that starts in the correlationId, i.e., it
   * has the form corId.field1.field2.
   *
   * @param expr   Expression to analyze
   * @param corId  Correlation id to search for
   * @param fieldsAccessed  On successful return, contains the list of fields accessed
   *                        in reverse order, e.g., (field2, field1); it is empty if
   *                        {@code expr} is the bare correlation variable. On an
   *                        unsuccessful return it may hold a partial list.
   * @return  True if {@code expr} has the expected shape, false otherwise.
   */
  private static boolean extractFieldReferences(
      RexNode expr, CorrelationId corId, List<RelDataTypeField> fieldsAccessed) {
    if (expr instanceof RexCorrelVariable) {
      // Compare by value: two CorrelationId instances may denote the same variable.
      return ((RexCorrelVariable) expr).id.equals(corId);
    }
    if (expr instanceof RexFieldAccess) {
      final RexFieldAccess fieldAccess = (RexFieldAccess) expr;
      fieldsAccessed.add(fieldAccess.getField());
      return extractFieldReferences(fieldAccess.getReferenceExpr(), corId, fieldsAccessed);
    }
    return false;
  }

  @Override public void onMatch(RelOptRuleCall call) {
    final Project outerProject = call.rel(0);
    final Correlate cor = call.rel(1);

    // Only an INNER correlate may be rewritten. An outer (LEFT) correlate emits
    // a null-padded row for each left row whose collection is empty, and SEMI
    // and ANTI correlates emit at most one row per left row; the rewritten plan
    // would produce a different number of rows in those cases.
    if (cor.getJoinType().isOuterJoin() || !cor.getJoinType().projectsRight()) {
      return;
    }
    final CorrelationId corId = cor.getCorrelationId();

    // The outer project must not read any column coming from the left input,
    // because those columns are no longer available above the Uncollect.
    final RelNode left = call.rel(2);
    final int leftCount = left.getRowType().getFieldCount();
    final ImmutableBitSet used =
        RelOptUtil.InputFinder.bits(outerProject.getProjects(), null);
    final int firstUsed = used.nextSetBit(0);
    if (firstUsed != -1 && firstUsed < leftCount) {
      return;
    }

    // Operand 3 is either the optional inner project (WITH_PROJECT operand
    // tree) or the Uncollect itself (DEFAULT operand tree).
    int uncollectIndex = 3;
    Project innerProject = null;
    final RelNode third = call.rel(3);
    if (third instanceof Project) {
      innerProject = (Project) third;
      uncollectIndex = 4;
    }
    if (innerProject != null) {
      // The inner project is copied above the new Uncollect; once the correlate
      // is gone, any correlation variable in it would be left dangling.
      for (RexNode e : innerProject.getProjects()) {
        if (RexUtil.containsCorrelation(e)) {
          return;
        }
      }
    }

    final Uncollect uncollect = call.rel(uncollectIndex);
    final Project project = call.rel(uncollectIndex + 1);
    final LogicalValues values = call.rel(uncollectIndex + 2);

    // The rewrite assumes the Values produces exactly one row; with zero or
    // several rows the correlate would yield a different number of rows.
    if (values.getTuples().size() != 1) {
      return;
    }

    final List<RexNode> projects = project.getProjects();
    if (projects.size() != 1) {
      return;
    }

    final RexNode projected = projects.get(0);
    final List<RelDataTypeField> fieldsAccessed = new ArrayList<>();
    if (!extractFieldReferences(projected, corId, fieldsAccessed)
        || fieldsAccessed.isEmpty()) {
      // Not of the form $cor.f1.f2..., or the bare correlation variable,
      // for which there is no field of the left input to project.
      return;
    }

    final RelBuilder builder = call.builder();
    builder.push(left);

    // Fields were collected outermost access first; reverse them so that the
    // access chain is rebuilt starting from a column of the left input.
    Collections.reverse(fieldsAccessed);
    // Last field constructed by builder
    RexNode field = null;
    for (RelDataTypeField accessed : fieldsAccessed) {
      field = field == null
          ? builder.field(accessed.getName())
          : builder.field(field, accessed.getName());
    }
    builder.project(requireNonNull(field, "field"))
        .uncollect(uncollect.getItemAliases(), uncollect.withOrdinality);
    if (innerProject != null) {
      builder.project(innerProject.getProjects());
    }
    // The outer project addressed the right-hand columns at an offset of
    // leftCount within the correlate's row type; now they start at 0.
    final List<RexNode> shifted = RexUtil.shift(outerProject.getProjects(), -leftCount);
    // Keep the original output field names.
    builder.project(shifted, outerProject.getRowType().getFieldNames());
    call.transformTo(builder.build());
  }

  /** Rule configuration. */
  @Value.Immutable
  public interface Config extends RelRule.Config {
    /** Base configuration, without an operand tree. */
    UnnestDecorrelateRule.Config BASE = ImmutableUnnestDecorrelateRule.Config.of();

    /** Matches Project / Correlate(any, Uncollect / Project / LogicalValues),
     * i.e. the shape without a project between the correlate and the Uncollect. */
    RelRule.Config DEFAULT = BASE
        .withOperandSupplier(b0 -> b0.operand(Project.class)
            .oneInput(b1 -> b1.operand(Correlate.class)
                .inputs(b2 -> b2.operand(RelNode.class).anyInputs(),
                    b3 -> b3.operand(Uncollect.class)
                        .oneInput(b4 -> b4.operand(Project.class)
                            .oneInput(b5 -> b5.operand(LogicalValues.class).anyInputs())))));

    /** Matches Project / Correlate(any, Project / Uncollect / Project / LogicalValues),
     * i.e. the shape with a project between the correlate and the Uncollect. */
    RelRule.Config WITH_PROJECT = BASE
        .withOperandSupplier(b0 -> b0.operand(Project.class)
            .oneInput(b1 -> b1.operand(Correlate.class)
                .inputs(b2 -> b2.operand(RelNode.class).anyInputs(),
                    b3 -> b3.operand(Project.class)
                        .oneInput(b4 -> b4.operand(Uncollect.class)
                            .oneInput(b5 -> b5.operand(Project.class)
                                .oneInput(b6 -> b6.operand(LogicalValues.class).anyInputs()))))));

    @Override default UnnestDecorrelateRule toRule() {
      return new UnnestDecorrelateRule(this);
    }
  }
}