NativeElementsHelper.java
/*
* Copyright 2017-2024 original authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.micronaut.inject.utils;
import io.micronaut.core.annotation.Internal;
import io.micronaut.core.annotation.NonNull;
import io.micronaut.core.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Helper for native class and method elements: walks type hierarchies and computes, for each
 * method, the methods it overrides.
 *
 * <p>Subclasses adapt a concrete element model by implementing the abstract accessors below.
 * Results of {@link #findOverriddenMethods(Object, Object)} are cached per instance in
 * unsynchronized {@link HashMap}/{@link HashSet} fields. An instance should therefore not be
 * shared across threads without external synchronization.</p>
 *
 * @param <C> The class native element type
 * @param <M> The method native element type
 * @author Denis Stepanov
 * @since 4.3.0
 */
@Internal
public abstract class NativeElementsHelper<C, M> {

    /** Cache keys of the classes whose hierarchy has already been analysed. */
    private final Set<Object> processedClasses = new HashSet<>();
    /** Overridden methods per (class, method); only methods with at least one override are stored. */
    private final Map<MethodCacheKey, Collection<M>> overridesCache = new HashMap<>();

    /**
     * Check if one method overrides another.
     *
     * @param m1 The override method
     * @param m2 The overridden method
     * @param owner The class owner of the methods
     * @return true if overridden
     */
    protected abstract boolean overrides(M m1, M m2, C owner);

    /**
     * Gets the method name. Methods are only tested with {@code overrides} against methods
     * that have an equal name.
     *
     * @param element The element
     * @return The name
     */
    @NonNull
    protected abstract String getMethodName(M element);

    /**
     * Extracts the super class.
     *
     * @param classNode The class
     * @return The super class, or {@code null} if there is none
     */
    @Nullable
    protected abstract C getSuperClass(C classNode);

    /**
     * Extracts the interfaces of the class.
     *
     * @param classNode The class
     * @return The interfaces
     */
    @NonNull
    protected abstract Collection<C> getInterfaces(C classNode);

    /**
     * Extracts the methods of the class.
     *
     * <p>Supertypes are walked separately. This presumably should return only the methods
     * the class declares itself.</p>
     *
     * @param classNode The class
     * @return The methods
     */
    @NonNull
    protected abstract List<M> getMethods(C classNode);

    /**
     * Checks if the class needs to be excluded.
     *
     * @param classNode The class
     * @return true if to exclude
     */
    protected abstract boolean excludeClass(C classNode);

    /**
     * Is interface.
     *
     * @param classNode The class node
     * @return true if interface
     */
    protected abstract boolean isInterface(C classNode);

    /**
     * Get the key identifying a class in the internal caches.
     *
     * @param classElement The class element
     * @return the cache key; by default the element itself
     */
    protected Object getClassCacheKey(C classElement) {
        return classElement;
    }

    /**
     * Get the key identifying a method in the internal caches.
     *
     * @param methodElement The method element
     * @return the cache key; by default the element itself
     */
    protected Object getMethodCacheKey(M methodElement) {
        return methodElement;
    }

    /**
     * Populate with the class hierarchy.
     *
     * <p>The traversal is post-order: the interfaces (recursively) come first, then the super
     * class (recursively), then the element itself. An excluded class is not added, but its
     * supertypes are still visited. No de-duplication is performed, so an interface reachable
     * through several paths is added once per path.</p>
     *
     * @param element The element
     * @param hierarchy The hierarchy to append to
     */
    public final void populateTypeHierarchy(C element, List<C> hierarchy) {
        for (C anInterface : getInterfaces(element)) {
            populateTypeHierarchy(anInterface, hierarchy);
        }
        C superClass = getSuperClass(element);
        if (superClass != null) {
            populateTypeHierarchy(superClass, hierarchy);
        }
        if (!excludeClass(element)) {
            hierarchy.add(element);
        }
    }

    /**
     * Find overridden methods.
     *
     * <p>The first call for a given class walks its whole hierarchy once. It caches the
     * overridden methods of every method found. Later calls for the same class are answered
     * from the cache.</p>
     *
     * @param classNode The class of the method
     * @param methodElement The method
     * @return the overridden methods as an unmodifiable collection (empty if none)
     */
    public final Collection<M> findOverriddenMethods(C classNode, M methodElement) {
        Object classCacheKey = getClassCacheKey(classNode);
        MethodCacheKey methodCacheKey = new MethodCacheKey(
            classCacheKey,
            getMethodCacheKey(methodElement)
        );
        Collection<M> overriddenMethods = overridesCache.get(methodCacheKey);
        if (overriddenMethods != null) {
            return overriddenMethods;
        }
        if (processedClasses.contains(classCacheKey)) {
            // Hierarchy already analysed and no overrides were recorded for this method
            return List.of();
        }
        for (MethodElement<M> method : getAllElements(classNode)) {
            if (method.overridden.isEmpty()) {
                continue;
            }
            // Store an immutable snapshot (iteration order preserved) so that callers
            // cannot mutate the cached value through the returned collection.
            overridesCache.put(
                new MethodCacheKey(
                    classCacheKey,
                    getMethodCacheKey(method.methodElement)
                ),
                List.copyOf(method.overridden)
            );
        }
        processedClasses.add(classCacheKey);
        return overridesCache.getOrDefault(methodCacheKey, List.of());
    }

    /**
     * Collects the effective methods of the class hierarchy, each paired with the methods it
     * overrides.
     *
     * @param classNode The class or interface to analyse
     * @return the collected method entries
     */
    private List<MethodElement<M>> getAllElements(C classNode) {
        // LinkedList: entries are removed through the iterator while scanning in reduce()
        List<MethodElement<M>> elements = new LinkedList<>();
        List<MethodElement<M>> cache = new ArrayList<>(20);
        if (isInterface(classNode)) {
            processInterfaceHierarchy(classNode, classNode, cache, elements, true);
        } else {
            processClassHierarchy(classNode, classNode, cache, elements, true);
        }
        return elements;
    }

    /**
     * Processes a class.
     *
     * <p>The super class chain is processed first, then the class's own methods, then its
     * interfaces. An excluded class stops the traversal, including its supertypes.</p>
     *
     * @param owner The class whose hierarchy is being analysed
     * @param classNode The class currently processed
     * @param cache Scratch list reused by {@code reduce}
     * @param collectedMethods The methods collected so far; modified in place
     * @param includeAbstract Forwarded to interface processing
     */
    private void processClassHierarchy(C owner,
                                       C classNode,
                                       List<MethodElement<M>> cache,
                                       Collection<MethodElement<M>> collectedMethods,
                                       boolean includeAbstract) {
        if (excludeClass(classNode)) {
            return;
        }
        C superClass = getSuperClass(classNode);
        if (superClass != null) {
            processClassHierarchy(owner, superClass, cache, collectedMethods, includeAbstract);
        }
        reduce(owner, collectedMethods, getMethods(classNode), cache, false, false);
        for (C anInterface : getInterfaces(classNode)) {
            processInterfaceHierarchy(owner, anInterface, cache, collectedMethods, includeAbstract);
        }
    }

    /**
     * Processes an interface.
     *
     * <p>Super-interfaces are processed first, then the interface's own methods. An excluded
     * interface stops the traversal, including its super-interfaces.</p>
     *
     * @param owner The class whose hierarchy is being analysed
     * @param classNode The interface currently processed
     * @param cache Scratch list reused by {@code reduce}
     * @param collectedMethods The methods collected so far; modified in place
     * @param includeAbstract Whether an interface method may replace a collected method it overrides
     */
    private void processInterfaceHierarchy(C owner,
                                           C classNode,
                                           List<MethodElement<M>> cache,
                                           Collection<MethodElement<M>> collectedMethods,
                                           boolean includeAbstract) {
        if (excludeClass(classNode)) {
            return;
        }
        for (C anInterface : getInterfaces(classNode)) {
            processInterfaceHierarchy(owner, anInterface, cache, collectedMethods, includeAbstract);
        }
        reduce(owner, collectedMethods, getMethods(classNode), cache, true, includeAbstract);
    }

    /**
     * Merges the methods declared by one type into the methods collected so far.
     *
     * <p>Each new method is compared only against collected methods with an equal name.
     * Scanning stops at the first collected entry for which one of these rules applies:</p>
     * <ul>
     *     <li>class type: if the new method overrides the collected one, the collected entry is
     *     replaced by the new method. The new method inherits the entry's overridden set, plus
     *     the replaced method;</li>
     *     <li>interface type: the identical element is skipped. If the collected method overrides
     *     the new one, the new one is only recorded as overridden. Otherwise, when
     *     {@code includesAbstract} is set and the new method overrides the collected one, the
     *     entry is replaced as for classes.</li>
     * </ul>
     * <p>A new method matching no rule is added with an empty overridden set. New entries are
     * appended only after all {@code newMethodElements} were examined, so methods of the same
     * type are never compared with each other.</p>
     *
     * @param owner The class whose hierarchy is being analysed, passed to {@code overrides}
     * @param collectedMethods The methods collected so far; modified in place
     * @param newMethodElements The methods of the type being processed
     * @param cache Scratch list, cleared and reused on every call
     * @param isInterface Whether the type being processed is an interface
     * @param includesAbstract Whether an interface method may replace a collected method it overrides
     */
    private void reduce(C owner,
                        Collection<MethodElement<M>> collectedMethods,
                        List<M> newMethodElements,
                        List<MethodElement<M>> cache,
                        boolean isInterface,
                        boolean includesAbstract) {
        cache.clear(); // Reusing this collection for all the calls
        classElements:
        for (M newElement : newMethodElements) {
            String newElementName = getMethodName(newElement);
            for (Iterator<MethodElement<M>> iterator = collectedMethods.iterator(); iterator.hasNext(); ) {
                MethodElement<M> existingEntry = iterator.next();
                M existingElement = existingEntry.methodElement;
                if (!getMethodName(existingElement).equals(newElementName)) {
                    continue;
                }
                LinkedHashSet<M> overridden = existingEntry.overridden;
                if (isInterface) {
                    if (existingElement == newElement) {
                        continue classElements;
                    }
                    if (overrides(existingElement, newElement, owner)) {
                        overridden.add(newElement);
                        continue classElements;
                    } else if (includesAbstract && overrides(newElement, existingElement, owner)) {
                        iterator.remove();
                        overridden.add(existingElement);
                        cache.add(new MethodElement<>(newElement, overridden));
                        continue classElements;
                    }
                } else if (overrides(newElement, existingElement, owner)) {
                    iterator.remove();
                    overridden.add(existingElement);
                    cache.add(new MethodElement<>(newElement, overridden));
                    continue classElements;
                }
            }
            cache.add(new MethodElement<>(newElement, new LinkedHashSet<>()));
        }
        collectedMethods.addAll(cache);
    }

    /**
     * The method element.
     *
     * @param methodElement The element
     * @param overridden The overridden collection
     * @param <N> The native method element type
     */
    public record MethodElement<N>(N methodElement, LinkedHashSet<N> overridden) {
    }

    /**
     * Key of {@code overridesCache}: the class cache key paired with the method cache key.
     *
     * @param classKey The class cache key
     * @param methodKey The method cache key
     */
    private record MethodCacheKey(Object classKey, Object methodKey) {
    }
}