ExtensionProviders.java
/**
* Copyright (c) 2018, RTE (http://www.rte-france.com)
* This Source Code Form is subject to the terms of the Mozilla Public
* License, v. 2.0. If a copy of the MPL was not distributed with this
* file, You can obtain one at http://mozilla.org/MPL/2.0/.
* SPDX-License-Identifier: MPL-2.0
*/
package com.powsybl.commons.extensions;
import com.powsybl.commons.PowsyblException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Registry of {@link ExtensionProvider} services indexed by extension name.
 *
 * <p>Providers are obtained once, at construction time, from an {@link ExtensionProvidersLoader}.
 * They can optionally be restricted to a category name and to a set of extension names. When the
 * requested provider class is {@link ExtensionSerDe}, every provider is additionally registered
 * under its alternative serialization names.
 *
 * @param <T> the type of provider held by this registry
 *
 * @author Mathieu Bague {@literal <mathieu.bague at rte-france.com>}
 */
public final class ExtensionProviders<T extends ExtensionProvider> {

    private static final Logger LOGGER = LoggerFactory.getLogger(ExtensionProviders.class);

    /** Providers indexed by extension name (real names and, for SerDes, alternative names). */
    private final Map<String, T> providers;

    /**
     * Creates a registry of all providers of the given class, whatever their category.
     *
     * @param clazz the provider class to look up, must not be {@code null}
     * @param <T> the provider type
     * @return a new registry
     */
    public static <T extends ExtensionProvider> ExtensionProviders<T> createProvider(Class<T> clazz) {
        return new ExtensionProviders<>(clazz);
    }

    /**
     * Creates a registry of the providers of the given class belonging to the given category.
     *
     * @param clazz the provider class to look up, must not be {@code null}
     * @param categoryName the category the providers must belong to, must not be {@code null}
     * @param <T> the provider type
     * @return a new registry
     */
    public static <T extends ExtensionProvider> ExtensionProviders<T> createProvider(Class<T> clazz, String categoryName) {
        return new ExtensionProviders<>(clazz, categoryName);
    }

    /**
     * Creates a registry of the providers of the given class and category, using a custom loader.
     *
     * @param clazz the provider class to look up, must not be {@code null}
     * @param categoryName the category the providers must belong to, must not be {@code null}
     * @param loader the loader supplying the candidate providers, must not be {@code null}
     * @param <T> the provider type
     * @return a new registry
     */
    public static <T extends ExtensionProvider> ExtensionProviders<T> createProvider(Class<T> clazz, String categoryName, ExtensionProvidersLoader loader) {
        return new ExtensionProviders<>(clazz, categoryName, null, loader);
    }

    /**
     * Creates a registry of the providers of the given class and category, restricted to some extensions.
     *
     * @param clazz the provider class to look up, must not be {@code null}
     * @param categoryName the category the providers must belong to, must not be {@code null}
     * @param extensionNames the names of the extensions to keep, or {@code null} to keep all of them
     * @param <T> the provider type
     * @return a new registry
     */
    public static <T extends ExtensionProvider> ExtensionProviders<T> createProvider(Class<T> clazz, String categoryName, Set<String> extensionNames) {
        return new ExtensionProviders<>(clazz, categoryName, extensionNames);
    }

    private ExtensionProviders(Class<T> clazz) {
        Objects.requireNonNull(clazz, "clazz");
        // No category filter and no extension name filter.
        providers = loadProviders(clazz, null, null, new DefaultExtensionProvidersLoader());
    }

    private ExtensionProviders(Class<T> clazz, String categoryName) {
        this(clazz, categoryName, null);
    }

    private ExtensionProviders(Class<T> clazz, String categoryName, Set<String> extensionNames) {
        this(clazz, categoryName, extensionNames, new DefaultExtensionProvidersLoader());
    }

    private ExtensionProviders(Class<T> clazz, String categoryName, Set<String> extensionNames, ExtensionProvidersLoader loader) {
        Objects.requireNonNull(clazz, "clazz");
        Objects.requireNonNull(categoryName, "categoryName");
        // Fail fast with an explicit message instead of a late, anonymous NPE inside loadProviders.
        Objects.requireNonNull(loader, "loader");
        providers = loadProviders(clazz, categoryName, extensionNames, loader);
    }

    /**
     * Builds the name-to-provider map from the services supplied by the loader.
     *
     * @param clazz the provider class to look up
     * @param categoryName the category to keep, or {@code null} to keep every category
     * @param extensionNames the extension names to keep, or {@code null} to keep all of them
     * @param loader the loader supplying the candidate providers
     * @return the providers indexed by name
     * @throws IllegalStateException if two providers share the same real extension name
     */
    private Map<String, T> loadProviders(Class<T> clazz, String categoryName, Set<String> extensionNames, ExtensionProvidersLoader loader) {
        final Map<String, T> providersMap = new HashMap<>();

        Stream<T> servicesStream = loader.getServicesStream(clazz);
        if (categoryName != null) {
            servicesStream = servicesStream.filter(s -> s.getCategoryName().equals(categoryName));
        }
        Set<T> services = servicesStream.collect(Collectors.toSet());

        // First pass: real extension names. They are all registered before any alternative name
        // so that an alternative name can never take the slot of a real one.
        services.forEach(service -> addService(providersMap, service, extensionNames));

        if (clazz.equals(ExtensionSerDe.class)) {
            // Second pass: add the alternative serialization names for extension SerDes
            services.forEach(service -> ((ExtensionSerDe<?, ?>) service).getSerializationNames().stream()
                    .filter(name -> !service.getExtensionName().equals(name))
                    .forEach(name -> addServiceForAlternativeName(providersMap, service, name, extensionNames)));
        }
        return providersMap;
    }

    /**
     * Registers a provider under its real extension name, if that name is to be imported.
     *
     * @param providersMap the map being populated
     * @param service the provider to register
     * @param extensionsToImport the extension names to keep, or {@code null} to keep all of them
     * @throws IllegalStateException if a provider is already registered under the same name
     */
    private void addService(Map<String, T> providersMap, T service, Set<String> extensionsToImport) {
        String name = service.getExtensionName();
        if (extensionsToImport == null || extensionsToImport.contains(name)) {
            if (providersMap.containsKey(name)) {
                // There should not be two extensions of the same real name
                throw new IllegalStateException("Several providers were found for extension '" + name + "'");
            } else {
                providersMap.put(name, service);
            }
        }
    }

    /**
     * Registers a provider under one of its alternative names, unless that name is already taken.
     *
     * <p>The registration happens when there is no name filter, or when the filter contains either
     * the real extension name of the provider or the alternative name itself.
     *
     * @param providersMap the map being populated
     * @param service the provider to register
     * @param alternativeName the alternative name to register the provider under
     * @param extensionsToImport the extension names to keep, or {@code null} to keep all of them
     */
    private void addServiceForAlternativeName(Map<String, T> providersMap, T service, String alternativeName, Set<String> extensionsToImport) {
        if (extensionsToImport == null || extensionsToImport.contains(service.getExtensionName()) || extensionsToImport.contains(alternativeName)) {
            T previousService = providersMap.get(alternativeName);
            if (previousService != null) {
                // For alternative names, a duplicate does not replace the previous provider, to avoid replacing
                // providers mapped to their real extension names by providers mapped to an alternative name.
                LOGGER.warn("Alternative extension name {} for extension {} is already used for real extension {} - Skipping",
                        alternativeName, service.getExtensionName(), previousService.getExtensionName());
            } else {
                providersMap.put(alternativeName, service);
            }
        }
    }

    /**
     * Looks up the provider registered under the given name.
     *
     * @param name the extension name (real or alternative)
     * @return the matching provider, or {@code null} if none is registered under that name
     */
    public T findProvider(String name) {
        return providers.get(name);
    }

    /**
     * Looks up the provider registered under the given name, failing if there is none.
     *
     * @param name the extension name (real or alternative)
     * @return the matching provider, never {@code null}
     * @throws PowsyblException if no provider is registered under that name
     */
    public T findProviderOrThrowException(String name) {
        T provider = findProvider(name);
        if (provider == null) {
            throw new PowsyblException("Provider not found for extension " + name);
        }
        return provider;
    }

    /**
     * Returns the registered providers, each of them only once.
     *
     * @return a new collection of the distinct providers
     */
    public Collection<T> getProviders() {
        // A provider may be mapped to several names (alternative names), hence the distinct().
        return providers.values().stream().distinct().collect(Collectors.toList());
    }

    /**
     * Adds the given extensions to an extendable, using for each of them the extension class
     * declared by the provider registered under the extension's name.
     *
     * @param extendable the object receiving the extensions, must not be {@code null}
     * @param extensions the extensions to add, must not be {@code null}
     * @param <E> the type of the extendable
     * @throws PowsyblException if no provider is registered for the name of one of the extensions
     */
    public <E> void addExtensions(Extendable<E> extendable, Collection<Extension<E>> extensions) {
        Objects.requireNonNull(extendable, "extendable");
        Objects.requireNonNull(extensions, "extensions");
        // Use the throwing lookup: an unknown extension must be reported explicitly rather than
        // surface as a NullPointerException on the missing provider.
        extensions.forEach(e -> extendable.addExtension(findProviderOrThrowException(e.getName()).getExtensionClass(), e));
    }
}