InjectorImpl.java
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.maven.di.impl;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Type;
import java.net.URL;
import java.util.AbstractList;
import java.util.AbstractMap;
import java.util.AbstractSet;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.maven.api.annotations.Nonnull;
import org.apache.maven.api.di.Provides;
import org.apache.maven.api.di.Qualifier;
import org.apache.maven.api.di.Singleton;
import org.apache.maven.api.di.Typed;
import org.apache.maven.di.Injector;
import org.apache.maven.di.Key;
import org.apache.maven.di.Scope;
import static org.apache.maven.di.impl.Binding.getPriorityComparator;
public class InjectorImpl implements Injector {
private final Map<Key<?>, Set<Binding<?>>> bindings = new HashMap<>();
private final Map<Class<? extends Annotation>, Supplier<Scope>> scopes = new HashMap<>();
private final Set<String> loadedUrls = new HashSet<>();
private final ThreadLocal<Set<Key<?>>> resolutionStack = new ThreadLocal<>();
public InjectorImpl() {
bindScope(Singleton.class, new SingletonScope());
}
@Nonnull
@Override
public <T> T getInstance(@Nonnull Class<T> key) {
return getInstance(Key.of(key));
}
@Nonnull
@Override
public <T> T getInstance(@Nonnull Key<T> key) {
return getCompiledBinding(new Dependency<>(key, false)).get();
}
@SuppressWarnings("unchecked")
@Override
public <T> void injectInstance(@Nonnull T instance) {
ReflectionUtils.generateInjectingInitializer(Key.of((Class<T>) instance.getClass()))
.compile(this::getCompiledBinding)
.accept(instance);
}
@Nonnull
@Override
public Injector discover(@Nonnull ClassLoader classLoader) {
try {
Enumeration<URL> enumeration = classLoader.getResources("META-INF/maven/org.apache.maven.api.di.Inject");
while (enumeration.hasMoreElements()) {
URL url = enumeration.nextElement();
if (loadedUrls.add(url.toExternalForm())) {
try (InputStream is = url.openStream();
BufferedReader reader =
new BufferedReader(new InputStreamReader(Objects.requireNonNull(is)))) {
for (String line :
reader.lines().filter(l -> !l.startsWith("#")).toList()) {
Class<?> clazz = classLoader.loadClass(line);
bindImplicit(clazz);
}
}
}
}
} catch (Exception e) {
throw new DIException("Error while discovering DI classes from classLoader", e);
}
return this;
}
@Nonnull
@Override
public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Scope scope) {
return bindScope(scopeAnnotation, () -> scope);
}
@Nonnull
@Override
public Injector bindScope(@Nonnull Class<? extends Annotation> scopeAnnotation, @Nonnull Supplier<Scope> scope) {
if (scopes.put(scopeAnnotation, scope) != null) {
throw new DIException(
"Cannot rebind scope annotation class to a different implementation: " + scopeAnnotation);
}
return this;
}
@Nonnull
@Override
public <U> Injector bindInstance(@Nonnull Class<U> clazz, @Nonnull U instance) {
Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
Binding<U> binding = Binding.toInstance(instance);
return doBind(key, binding);
}
@Override
public <U> Injector bindSupplier(@Nonnull Class<U> clazz, @Nonnull Supplier<U> supplier) {
Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
Binding<U> binding = Binding.toSupplier(supplier);
return doBind(key, binding);
}
@Nonnull
@Override
public Injector bindImplicit(@Nonnull Class<?> clazz) {
Key<?> key = Key.of(clazz, ReflectionUtils.qualifierOf(clazz));
if (clazz.isInterface()) {
bindings.computeIfAbsent(key, $ -> new HashSet<>());
if (key.getQualifier() != null) {
bindings.computeIfAbsent(Key.ofType(clazz), $ -> new HashSet<>());
}
} else if (!Modifier.isAbstract(clazz.getModifiers())) {
Binding<?> binding = ReflectionUtils.generateImplicitBinding(key);
doBind(key, binding);
}
return this;
}
private final LinkedHashSet<Key<?>> current = new LinkedHashSet<>();
private Injector doBind(Key<?> key, Binding<?> binding) {
if (!current.add(key)) {
current.add(key);
throw new DIException("Circular references: " + current);
}
try {
doBindImplicit(key, binding);
Class<?> cls = key.getRawType().getSuperclass();
while (cls != Object.class && cls != null) {
doBindImplicit(Key.of(cls, key.getQualifier()), binding);
if (key.getQualifier() != null) {
bind(Key.ofType(cls), binding);
}
cls = cls.getSuperclass();
}
return this;
} finally {
current.remove(key);
}
}
protected <U> Injector bind(Key<U> key, Binding<U> b) {
Set<Binding<?>> bindingSet = bindings.computeIfAbsent(key, $ -> new HashSet<>());
bindingSet.add(b);
return this;
}
@SuppressWarnings({"unchecked", "rawtypes"})
protected <T> Set<Binding<T>> getBindings(Key<T> key) {
return (Set) bindings.get(key);
}
protected Set<Key<?>> getBoundKeys() {
return bindings.keySet();
}
public Map<Key<?>, Set<Binding<?>>> getBindings() {
return bindings;
}
public <T> Set<Binding<T>> getAllBindings(Class<T> clazz) {
return getBindings(Key.of(clazz));
}
public <Q> Supplier<Q> getCompiledBinding(Dependency<Q> dep) {
Key<Q> key = dep.key();
Supplier<Q> originalSupplier = doGetCompiledBinding(dep);
return () -> {
checkCyclicDependency(key);
try {
return originalSupplier.get();
} finally {
removeFromResolutionStack(key);
}
};
}
public <Q> Supplier<Q> doGetCompiledBinding(Dependency<Q> dep) {
Key<Q> key = dep.key();
Set<Binding<Q>> res = getBindings(key);
if (res != null && !res.isEmpty()) {
List<Binding<Q>> bindingList = new ArrayList<>(res);
bindingList.sort(getPriorityComparator());
Binding<Q> binding = bindingList.get(0);
return compile(binding);
}
if (key.getRawType() == List.class) {
Set<Binding<Object>> res2 = getBindings(key.getTypeParameter(0));
if (res2 != null) {
// Sort bindings by priority (highest first) for deterministic ordering
List<Binding<Object>> sortedBindings = new ArrayList<>(res2);
sortedBindings.sort(getPriorityComparator());
List<Supplier<Object>> list =
sortedBindings.stream().map(this::compile).collect(Collectors.toList());
//noinspection unchecked
return () -> (Q) list(list, Supplier::get);
}
}
if (key.getRawType() == Map.class) {
Key<?> k = key.getTypeParameter(0);
Key<Object> v = key.getTypeParameter(1);
Set<Binding<Object>> res2 = getBindings(v);
if (k.getRawType() == String.class && res2 != null) {
Map<String, Supplier<Object>> map = res2.stream()
.filter(b -> b.getOriginalKey() == null
|| b.getOriginalKey().getQualifier() == null
|| b.getOriginalKey().getQualifier() instanceof String)
.collect(Collectors.toMap(
b -> (String)
(b.getOriginalKey() != null
? b.getOriginalKey().getQualifier()
: null),
this::compile));
//noinspection unchecked
return () -> (Q) map(map, Supplier::get);
}
}
if (dep.optional()) {
return () -> null;
}
throw new DIException("No binding to construct an instance for key "
+ key.getDisplayString() + ". Existing bindings:\n"
+ getBoundKeys().stream()
.map(Key::toString)
.map(String::trim)
.sorted()
.distinct()
.collect(Collectors.joining("\n - ", " - ", "")));
}
@SuppressWarnings("unchecked")
protected <Q> Supplier<Q> compile(Binding<Q> binding) {
Supplier<Q> compiled = binding.compile(this::getCompiledBinding);
if (binding.getScope() != null) {
Scope scope = scopes.entrySet().stream()
.filter(e -> e.getKey().isInstance(binding.getScope()))
.findFirst()
.map(Map.Entry::getValue)
.orElseThrow(() -> new DIException("Scope not bound for annotation "
+ binding.getScope().annotationType()))
.get();
compiled = scope.scope((Key<Q>) binding.getOriginalKey(), compiled);
}
return compiled;
}
protected void doBindImplicit(Key<?> key, Binding<?> binding) {
if (binding != null) {
// For non-explicit bindings, also bind all their base classes and interfaces according to the @Type
Object qualifier = key.getQualifier();
Class<?> type = key.getRawType();
Set<Class<?>> types = getBoundTypes(type.getAnnotation(Typed.class), type);
for (Type t : Types.getAllSuperTypes(type)) {
if (types == null || types.contains(Types.getRawType(t))) {
bind(Key.ofType(t, qualifier), binding);
if (qualifier != null) {
bind(Key.ofType(t), binding);
}
}
}
}
// Bind inner classes
for (Class<?> inner : key.getRawType().getDeclaredClasses()) {
boolean hasQualifier = Stream.of(inner.getAnnotations())
.anyMatch(ann -> ann.annotationType().isAnnotationPresent(Qualifier.class));
if (hasQualifier) {
bindImplicit(inner);
}
}
// Bind inner providers
for (Method method : key.getRawType().getDeclaredMethods()) {
if (method.isAnnotationPresent(Provides.class)) {
if (method.getTypeParameters().length != 0) {
throw new DIException("Parameterized method are not supported " + method);
}
Object qualifier = ReflectionUtils.qualifierOf(method);
Annotation scope = ReflectionUtils.scopeOf(method);
Type returnType = method.getGenericReturnType();
Set<Class<?>> types = getBoundTypes(method.getAnnotation(Typed.class), Types.getRawType(returnType));
Binding<Object> bind = ReflectionUtils.bindingFromMethod(method).scope(scope);
for (Type t : Types.getAllSuperTypes(returnType)) {
if (types == null || types.contains(Types.getRawType(t))) {
bind(Key.ofType(t, qualifier), bind);
if (qualifier != null) {
bind(Key.ofType(t), bind);
}
}
}
}
}
}
private static Set<Class<?>> getBoundTypes(Typed typed, Class<?> clazz) {
if (typed != null) {
Class<?>[] typesArray = typed.value();
if (typesArray == null || typesArray.length == 0) {
Set<Class<?>> types = new HashSet<>(Arrays.asList(clazz.getInterfaces()));
types.add(Object.class);
return types;
} else {
return new HashSet<>(Arrays.asList(typesArray));
}
} else {
return null;
}
}
protected <K, V, T> Map<K, V> map(Map<K, T> map, Function<T, V> mapper) {
return new WrappingMap<>(map, mapper);
}
private static class WrappingMap<K, V, T> extends AbstractMap<K, V> {
private final Map<K, T> delegate;
private final Function<T, V> mapper;
WrappingMap(Map<K, T> delegate, Function<T, V> mapper) {
this.delegate = delegate;
this.mapper = mapper;
}
@Override
public Set<Entry<K, V>> entrySet() {
return new AbstractSet<>() {
@Override
public Iterator<Entry<K, V>> iterator() {
Iterator<Entry<K, T>> it = delegate.entrySet().iterator();
return new Iterator<>() {
@Override
public boolean hasNext() {
return it.hasNext();
}
@Override
public Entry<K, V> next() {
Entry<K, T> n = it.next();
return new SimpleImmutableEntry<>(n.getKey(), mapper.apply(n.getValue()));
}
};
}
@Override
public int size() {
return delegate.size();
}
};
}
}
protected <Q, T> List<Q> list(List<T> bindingList, Function<T, Q> mapper) {
return new WrappingList<>(bindingList, mapper);
}
private static class WrappingList<Q, T> extends AbstractList<Q> {
private final List<T> delegate;
private final Function<T, Q> mapper;
WrappingList(List<T> delegate, Function<T, Q> mapper) {
this.delegate = delegate;
this.mapper = mapper;
}
@Override
public Q get(int index) {
return mapper.apply(delegate.get(index));
}
@Override
public int size() {
return delegate.size();
}
}
private void checkCyclicDependency(Key<?> key) {
Set<Key<?>> stack = resolutionStack.get();
if (stack == null) {
stack = new LinkedHashSet<>();
resolutionStack.set(stack);
}
if (!stack.add(key)) {
throw new DIException("Cyclic dependency detected: "
+ stack.stream().map(Key::getDisplayString).collect(Collectors.joining(" -> "))
+ " -> "
+ key.getDisplayString());
}
}
private void removeFromResolutionStack(Key<?> key) {
Set<Key<?>> stack = resolutionStack.get();
if (stack != null) {
stack.remove(key);
if (stack.isEmpty()) {
resolutionStack.remove();
}
}
}
private static class SingletonScope implements Scope {
Map<Key<?>, java.util.function.Supplier<?>> cache = new ConcurrentHashMap<>();
@Nonnull
@SuppressWarnings("unchecked")
@Override
public <T> java.util.function.Supplier<T> scope(
@Nonnull Key<T> key, @Nonnull java.util.function.Supplier<T> unscoped) {
return (java.util.function.Supplier<T>)
cache.computeIfAbsent(key, k -> new java.util.function.Supplier<T>() {
volatile T instance;
@Override
public T get() {
if (instance == null) {
synchronized (this) {
if (instance == null) {
instance = unscoped.get();
}
}
}
return instance;
}
});
}
}
/**
* Release all internal state so this Injector can be GC���d
* (and so that subsequent tests start from a clean slate).
* @since 4.1
*/
public void dispose() {
// First, clear any singleton���scope caches
scopes.values().stream()
.map(Supplier::get)
.filter(scope -> scope instanceof SingletonScope)
.map(scope -> (SingletonScope) scope)
.forEach(singleton -> singleton.cache.clear());
// Now clear everything else
bindings.clear();
scopes.clear();
loadedUrls.clear();
resolutionStack.remove();
}
}