DispatchGenerator.java

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to you under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.calcite.rel.metadata.janino;

import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.metadata.MetadataHandler;
import org.apache.calcite.rel.metadata.RelMetadataQuery;

import com.google.common.collect.ImmutableSet;

import org.checkerframework.checker.nullness.qual.Nullable;

import java.lang.reflect.Method;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.apache.calcite.linq4j.Nullness.castNonNull;
import static org.apache.calcite.rel.metadata.janino.CodeGeneratorUtil.argList;
import static org.apache.calcite.rel.metadata.janino.CodeGeneratorUtil.paramList;

/**
 * Generates the metadata dispatch to handlers.
 */
class DispatchGenerator {
  private final Map<MetadataHandler<?>, String> metadataHandlerToName;

  DispatchGenerator(Map<MetadataHandler<?>, String> metadataHandlerToName) {
    this.metadataHandlerToName = metadataHandlerToName;
  }

  void dispatchMethod(StringBuilder buff, Method method,
      Collection<? extends MetadataHandler<?>> metadataHandlers) {
    Map<MetadataHandler<?>, Set<Class<? extends RelNode>>> handlersToClasses =
        metadataHandlers.stream()
            .distinct()
            .collect(
                Collectors.toMap(
                    Function.identity(),
                    mh -> methodAndInstanceToImplementingClass(method, mh)));

    Set<Class<? extends RelNode>> delegateClassSet = handlersToClasses.values().stream()
        .flatMap(Set::stream)
        .collect(Collectors.toSet());
    List<Class<? extends RelNode>> delegateClassList = topologicalSort(delegateClassSet);
    buff
        .append("  private ")
        .append(method.getReturnType().getName())
        .append(" ")
        .append(method.getName())
        .append("_(\n")
        .append("      ")
        .append(RelNode.class.getName())
        .append(" r,\n")
        .append("      ")
        .append(RelMetadataQuery.class.getName())
        .append(" mq");
    paramList(buff, method)
        .append(") {\n");
    if (delegateClassList.isEmpty()) {
      throwUnknown(buff.append("    "), method)
          .append("  }\n");
    } else {
      buff
          .append(
              delegateClassList.stream()
                  .map(clazz ->
                      ifInstanceThenDispatch(method,
                          metadataHandlers, handlersToClasses, clazz))
                  .collect(
                      Collectors.joining("    } else if ",
                          "    if ", "    } else {\n")));
      throwUnknown(buff.append("      "), method)
          .append("    }\n")
          .append("  }\n");
    }
  }

  private StringBuilder ifInstanceThenDispatch(Method method,
      Collection<? extends MetadataHandler<?>> metadataHandlers,
      Map<MetadataHandler<?>, Set<Class<? extends RelNode>>> handlersToClasses,
      Class<? extends RelNode> clazz) {
    String handlerName = findProvider(metadataHandlers, handlersToClasses, clazz);
    StringBuilder buff = new StringBuilder()
        .append("(r instanceof ").append(clazz.getName()).append(") {\n")
        .append("      return ");
    dispatchedCall(buff, handlerName, method, clazz);

    return buff;
  }

  private String findProvider(Collection<? extends MetadataHandler<?>> metadataHandlers,
      Map<MetadataHandler<?>, Set<Class<? extends RelNode>>> handlerToClasses,
      Class<? extends RelNode> clazz) {
    for (MetadataHandler<?> mh : metadataHandlers) {
      if (handlerToClasses.getOrDefault(mh, ImmutableSet.of()).contains(clazz)) {
        return castNonNull(this.metadataHandlerToName.get(mh));
      }
    }
    throw new RuntimeException();
  }

  private static StringBuilder throwUnknown(StringBuilder buff, Method method) {
    return buff
        .append("      throw new ")
        .append(IllegalArgumentException.class.getName())
        .append("(\"No handler for method [").append(method)
        .append("] applied to argument of type [\" + r.getClass() + ")
        .append("\"]; we recommend you create a catch-all (RelNode) handler\"")
        .append(");\n");
  }

  private static void dispatchedCall(StringBuilder buff, String handlerName, Method method,
      Class<? extends RelNode> clazz) {
    buff.append(handlerName).append(".").append(method.getName())
        .append("((").append(clazz.getName()).append(") r, mq");
    argList(buff, method);
    buff.append(");\n");
  }

  private static Set<Class<? extends RelNode>> methodAndInstanceToImplementingClass(
      Method method, MetadataHandler<?> handler) {
    Set<Class<? extends RelNode>> set = new HashSet<>();
    for (Method m : handler.getClass().getMethods()) {
      Class<? extends RelNode> aClass = toRelClass(method, m);
      if (aClass != null) {
        set.add(aClass);
      }
    }
    return set;
  }

  private static @Nullable Class<? extends RelNode> toRelClass(Method superMethod,
      Method candidate) {
    if (!superMethod.getName().equals(candidate.getName())) {
      return null;
    } else if (superMethod.getParameterCount() != candidate.getParameterCount()) {
      return null;
    } else {
      Class<?>[] cpt = candidate.getParameterTypes();
      Class<?>[] smpt = superMethod.getParameterTypes();
      if (!RelNode.class.isAssignableFrom(cpt[0])) {
        return null;
      } else if (!RelMetadataQuery.class.equals(cpt[1])) {
        return null;
      }
      for (int i = 2; i < smpt.length; i++) {
        if (cpt[i] != smpt[i]) {
          return null;
        }
      }
      return (Class<? extends RelNode>) cpt[0];
    }
  }

  private static List<Class<? extends RelNode>> topologicalSort(
      Collection<Class<? extends RelNode>> list) {
    List<Class<? extends RelNode>> l = new ArrayList<>();
    ArrayDeque<Class<? extends RelNode>> s = list.stream()
        .sorted(Comparator.comparing(Class::getName))
        .collect(Collectors.toCollection(ArrayDeque::new));

    while (!s.isEmpty()) {
      Class<? extends RelNode> n = s.remove();

      boolean found = false;
      for (Class<? extends RelNode> other : s) {
        if (n.isAssignableFrom(other)) {
          found = true;
          break;
        }
      }
      if (found) {
        s.add(n);
      } else {
        l.add(n);
      }
    }
    return l;
  }
}