KafkaStreamsFunctionBeanPostProcessor.java

/*
 * Copyright 2019-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.stream.binder.kafka.streams.function;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.kafka.streams.kstream.GlobalKTable;
import org.apache.kafka.streams.kstream.KStream;
import org.apache.kafka.streams.kstream.KTable;

import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.beans.factory.annotation.AnnotatedBeanDefinition;
import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.beans.factory.support.BeanDefinitionRegistry;
import org.springframework.beans.factory.support.RootBeanDefinition;
import org.springframework.cloud.stream.binder.kafka.streams.KafkaStreamsBinderUtils;
import org.springframework.cloud.stream.function.StreamFunctionProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.support.GenericApplicationContext;
import org.springframework.core.ResolvableType;
import org.springframework.core.type.StandardMethodMetadata;
import org.springframework.lang.NonNull;
import org.springframework.util.ClassUtils;

/**
 * Discovers the functional beans ({@link Function}, {@link BiFunction}, {@link Consumer}
 * and {@link BiConsumer}) available in the bean factory, keeps the ones whose first
 * generic (or, for component-style beans, first {@code apply}/{@code accept} parameter)
 * is a Kafka Streams type ({@link KStream}, {@link KTable} or {@link GlobalKTable}), and
 * registers one {@link KafkaStreamsBindableProxyFactory} bean per function unit that has
 * to be bound.
 *
 * <p>The discovered type and method information is exposed through
 * {@link #getResolvableTypes()} and {@link #getMethods()}.
 *
 * @author Soby Chacko
 * @author James Forward
 *
 * @since 2.2.0
 */
public class KafkaStreamsFunctionBeanPostProcessor implements InitializingBean, BeanFactoryAware, ApplicationContextAware {

	private static final Log LOG = LogFactory.getLog(KafkaStreamsFunctionBeanPostProcessor.class);

	/** Bean names that are never considered as Kafka Streams function candidates. */
	private static final Set<String> EXCLUDE_FUNCTIONS = Set.of("functionRouter", "sendToDlqAndContinue");

	/**
	 * Bean names to skip. The pattern is applied with {@code find} semantics (see
	 * {@link Pattern#asPredicate()}), so any name containing {@code _registration} is skipped.
	 */
	private static final Pattern REGISTRATION_BEAN_NAME = Pattern.compile(".*_registration");

	private ConfigurableListableBeanFactory beanFactory;

	private ConfigurableApplicationContext applicationContext;

	/** {@code true} when exactly one functional bean candidate was found in the bean factory. */
	private boolean onlySingleFunction;

	private final StreamFunctionProperties streamFunctionProperties;

	/** Resolved types of the selected Kafka Streams functions, sorted by function name. */
	private final Map<String, ResolvableType> resolvableTypeMap = new TreeMap<>();

	/** {@code apply}/{@code accept} methods of the selected component-style beans, sorted by name. */
	private final Map<String, Method> methods = new TreeMap<>();

	/** Kafka Streams types collected when more than one functional bean candidate exists. */
	private final Map<String, ResolvableType> kafkaStreamsOnlyResolvableTypes = new HashMap<>();

	/** Component-bean methods collected when more than one functional bean candidate exists. */
	private final Map<String, Method> kafkaStreamsOnlyMethods = new HashMap<>();

	public KafkaStreamsFunctionBeanPostProcessor(StreamFunctionProperties streamFunctionProperties) {
		this.streamFunctionProperties = streamFunctionProperties;
	}

	/**
	 * Return the resolved types of the selected Kafka Streams functions keyed by function name.
	 * @return the live (mutable) internal map; it is populated by {@link #afterPropertiesSet()}
	 */
	public Map<String, ResolvableType> getResolvableTypes() {
		return this.resolvableTypeMap;
	}

	/**
	 * Return the functional methods of the selected component-style beans keyed by bean name.
	 * @return the live (mutable) internal map; it is populated by {@link #afterPropertiesSet()}
	 */
	public Map<String, Method> getMethods() {
		return this.methods;
	}

	/**
	 * Collect the Kafka Streams functions from the bean factory and register a
	 * {@link KafkaStreamsBindableProxyFactory} for every function unit to bind.
	 *
	 * <p>When no function definition is configured, every selected function gets its own
	 * proxy factory. Otherwise one proxy factory is registered per function unit of the
	 * definition, provided that the unit (or each part of a {@code |}-composed unit) is a
	 * known Kafka Streams function.
	 *
	 * @throws IllegalStateException if several Kafka Streams functions are found but no
	 * function definition is set
	 */
	@Override
	public void afterPropertiesSet() {
		String[] functionNames = this.beanFactory.getBeanNamesForType(Function.class);
		String[] biFunctionNames = this.beanFactory.getBeanNamesForType(BiFunction.class);
		String[] consumerNames = this.beanFactory.getBeanNamesForType(Consumer.class);
		String[] biConsumerNames = this.beanFactory.getBeanNamesForType(BiConsumer.class);

		final List<String> candidateNames = Stream.of(functionNames, consumerNames, biFunctionNames, biConsumerNames)
				.flatMap(Arrays::stream)
				.filter(name -> !EXCLUDE_FUNCTIONS.contains(name))
				.filter(REGISTRATION_BEAN_NAME.asPredicate().negate())
				.collect(Collectors.toList());

		// Must be known before the extraction below, which decides where to store its findings based on it.
		this.onlySingleFunction = candidateNames.size() == 1;
		candidateNames.forEach(this::extractResolvableTypes);

		// Promote the findings gathered in the multiple-candidates case into the public maps.
		this.kafkaStreamsOnlyResolvableTypes.forEach(this::addResolvableTypeInfo);
		this.kafkaStreamsOnlyMethods.forEach(this::addMethodInfo);

		final String[] functionUnits =
				KafkaStreamsBinderUtils.deriveFunctionUnits(this.streamFunctionProperties.getDefinition());

		final Set<String> kafkaStreamsMethodNames = new HashSet<>(this.kafkaStreamsOnlyResolvableTypes.keySet());
		kafkaStreamsMethodNames.addAll(this.resolvableTypeMap.keySet());

		if (functionUnits.length == 0) {
			this.resolvableTypeMap.forEach((name, type) ->
					registerKafkaStreamsProxyFactory(name, new ResolvableType[] {type}));
			return;
		}

		for (String functionUnit : functionUnits) {
			if (functionUnit.contains("|")) {
				final String[] composedFunctions = functionUnit.split("\\|");
				// A composed unit is only bound when every part of it is a Kafka Streams function.
				if (Arrays.stream(composedFunctions).allMatch(kafkaStreamsMethodNames::contains)) {
					ResolvableType[] resolvableTypes = Arrays.stream(composedFunctions)
							.map(this.resolvableTypeMap::get)
							.toArray(ResolvableType[]::new);
					// The derived name is the plain concatenation of the composed function names.
					registerKafkaStreamsProxyFactory(String.join("", composedFunctions), resolvableTypes);
				}
			}
			else if (kafkaStreamsMethodNames.contains(functionUnit)) {
				// Ensure that the function unit is a Kafka Streams function
				registerKafkaStreamsProxyFactory(functionUnit,
						new ResolvableType[] {this.resolvableTypeMap.get(functionUnit)});
			}
		}
	}

	/**
	 * Create a {@link KafkaStreamsBindableProxyFactory} for the given function and register
	 * it in the application context under the name
	 * {@code kafkaStreamsBindableProxyFactory-<functionName>}.
	 * @param functionName the (possibly derived) function name
	 * @param resolvableTypes the resolved types of the function, one entry per composed part
	 */
	private void registerKafkaStreamsProxyFactory(String functionName, ResolvableType[] resolvableTypes) {
		// Only present for component-style beans; null otherwise.
		Method method = this.methods.get(functionName);
		KafkaStreamsBindableProxyFactory proxyFactory =
			new KafkaStreamsBindableProxyFactory(resolvableTypes, functionName, method, this.streamFunctionProperties);
		// The supplier always hands out the instance created above.
		((GenericApplicationContext) this.applicationContext).registerBean("kafkaStreamsBindableProxyFactory-" + functionName,
			KafkaStreamsBindableProxyFactory.class, () -> proxyFactory);
	}

	/**
	 * Resolve the functional type of the bean with the given name and, if its first input
	 * is a Kafka Streams type, remember it.
	 *
	 * <p>Any exception raised while inspecting the bean is logged and otherwise ignored, so
	 * that one problematic bean does not prevent the remaining candidates from being mapped.
	 * @param key the bean name
	 */
	private void extractResolvableTypes(String key) {
		BeanDefinition beanDefinition = this.beanFactory.getBeanDefinition(key);
		ResolvableType resolvableType = null;
		Class<?> rawClass = null;
		try {
			if (beanDefinition instanceof AnnotatedBeanDefinition annotatedBeanDefinition) {
				if (annotatedBeanDefinition.getFactoryMethodMetadata() instanceof StandardMethodMetadata factoryMethodMetadata) {
					// Factory method is available reflectively: use its return type directly.
					resolvableType = ResolvableType.forMethodReturnType(factoryMethodMetadata.getIntrospectedMethod());
					rawClass = resolvableType.getGeneric(0).getRawClass();
				}
				else {
					final Class<?> classObj = ClassUtils.resolveClassName(
							annotatedBeanDefinition.getMetadata().getClassName(),
							ClassUtils.getDefaultClassLoader());
					Method[] candidateMethods = classObj.getDeclaredMethods();
					Optional<Method> functionalBeanMethod = KafkaStreamsBinderUtils.findMethodWithName(key, candidateMethods);
					if (functionalBeanMethod.isEmpty()) {
						candidateMethods = classObj.getMethods(); // check the inherited methods
						functionalBeanMethod = KafkaStreamsBinderUtils.findMethodWithName(key, candidateMethods);
					}
					if (functionalBeanMethod.isEmpty()) {
						// The bean name may differ from the name of the method that creates it.
						functionalBeanMethod = KafkaStreamsBinderUtils.findMethodWithName(
								beanDefinition.getFactoryMethodName(), candidateMethods);
					}

					if (functionalBeanMethod.isPresent()) {
						resolvableType = ResolvableType.forMethodReturnType(functionalBeanMethod.get(), classObj);
						rawClass = resolvableType.getGeneric(0).getRawClass();
					}
					else {
						// No factory method found: look for an apply/accept method taking a Kafka Streams type.
						Optional<Method> componentBeanMethod = Arrays.stream(candidateMethods)
							.filter(m -> ("apply".equals(m.getName()) || "accept".equals(m.getName()))
								&& isKafkaStreamsTypeFound(m))
							.findFirst();
						if (componentBeanMethod.isPresent()) {
							Method method = componentBeanMethod.get();
							resolvableType = ResolvableType.forMethodParameter(method, 0);
							rawClass = resolvableType.getRawClass();
							saveMethodInfoForComponentBeans(key, method);
						}
					}
				}
			}
			else {
				resolvableType = beanDefinition.getResolvableType();
				rawClass = resolvableType.getGeneric(0).getRawClass();
			}

			if (rawClass == KStream.class || rawClass == KTable.class || rawClass == GlobalKTable.class) {
				saveTypeInformation(key, resolvableType);
			}
		}
		catch (Exception e) {
			LOG.error("Function activation issues while mapping the function: " + key, e);
		}
	}

	/**
	 * Remember the functional method of a component-style bean, either directly in the
	 * public map (single candidate) or in the intermediate Kafka Streams only map.
	 * @param key the bean name
	 * @param method the {@code apply}/{@code accept} method of the bean
	 */
	private void saveMethodInfoForComponentBeans(String key, Method method) {
		if (this.onlySingleFunction) {
			this.methods.put(key, method);
		}
		else {
			this.kafkaStreamsOnlyMethods.put(key, method);
		}
	}

	/**
	 * Remember the resolved type of a Kafka Streams function, either directly in the
	 * public map (single candidate) or in the intermediate Kafka Streams only map.
	 * @param key the bean name
	 * @param resolvableType the resolved type of the function
	 */
	private void saveTypeInformation(String key, ResolvableType resolvableType) {
		if (this.onlySingleFunction) {
			this.resolvableTypeMap.put(key, resolvableType);
		}
		else {
			this.kafkaStreamsOnlyResolvableTypes.put(key, resolvableType);
		}
	}

	/**
	 * Publish the resolved type of a Kafka Streams function if that function is selected.
	 * @param key the bean name
	 * @param resolvableType the resolved type of the function
	 * @throws IllegalStateException if several Kafka Streams functions exist but no
	 * function definition is set
	 */
	private void addResolvableTypeInfo(String key, ResolvableType resolvableType) {
		if (isSelectedForBinding(key, this.kafkaStreamsOnlyResolvableTypes.size())) {
			this.resolvableTypeMap.put(key, resolvableType);
		}
	}

	/**
	 * Publish the functional method of a component-style bean if that bean is selected.
	 * @param key the bean name
	 * @param method the {@code apply}/{@code accept} method of the bean
	 * @throws IllegalStateException if several component-style Kafka Streams beans exist
	 * but no function definition is set
	 */
	private void addMethodInfo(String key, Method method) {
		if (isSelectedForBinding(key, this.kafkaStreamsOnlyMethods.size())) {
			this.methods.put(key, method);
		}
	}

	/**
	 * Decide whether the function with the given name has to be bound.
	 * @param key the bean name
	 * @param candidateCount the number of Kafka Streams candidates of the same category
	 * @return {@code true} if the function is the only candidate or is mentioned in the
	 * function definition
	 * @throws IllegalStateException if there are several candidates but no function
	 * definition is set
	 */
	private boolean isSelectedForBinding(String key, int candidateCount) {
		if (candidateCount == 1) {
			return true;
		}
		final String definition = this.streamFunctionProperties.getDefinition();
		if (definition == null) {
			throw new IllegalStateException("Multiple functions found, but function definition property is not set.");
		}
		// NOTE(review): this is a substring match, so a function named "process" is also selected
		// by a definition of "processOrders" - confirm whether an exact function-name match is intended.
		return definition.contains(key);
	}

	/**
	 * Check whether the first parameter of the given method is a Kafka Streams type.
	 * @param method the method to inspect
	 * @return {@code true} if the first parameter is assignable to {@link KStream},
	 * {@link KTable} or {@link GlobalKTable}; {@code false} otherwise, including when the
	 * method declares no parameters
	 */
	private boolean isKafkaStreamsTypeFound(Method method) {
		// A no-arg apply()/accept() overload cannot be the functional method; without this guard
		// the index access below would throw and abort the mapping of the whole bean.
		if (method.getParameterCount() == 0) {
			return false;
		}
		final Class<?> firstParameterType = method.getParameterTypes()[0];
		return KStream.class.isAssignableFrom(firstParameterType) ||
				KTable.class.isAssignableFrom(firstParameterType) ||
				GlobalKTable.class.isAssignableFrom(firstParameterType);
	}

	@Override
	public void setBeanFactory(@NonNull BeanFactory beanFactory) throws BeansException {
		this.beanFactory = (ConfigurableListableBeanFactory) beanFactory;
	}

	@Override
	public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
		this.applicationContext = (ConfigurableApplicationContext) applicationContext;
	}
}