PartitionAwareFunctionWrapper.java

/*
 * Copyright 2020-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.stream.function;

import java.util.function.Function;
import java.util.function.Supplier;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper;
import org.springframework.cloud.stream.binder.BinderHeaders;
import org.springframework.cloud.stream.binder.PartitionHandler;
import org.springframework.cloud.stream.binder.ProducerProperties;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.expression.spel.support.StandardEvaluationContext;
import org.springframework.integration.expression.ExpressionUtils;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.util.ObjectUtils;

/**
 * This class is effectively a wrapper which is aware of the stream related partition information
 * for outgoing messages. It has only one responsibility and that is to modify the result message
 * with 'scst_partition' header if necessary.
 *
 * @author Oleg Zhurakousky
 * @author Soby Chacko
 * @author Byungjun You
 */
class PartitionAwareFunctionWrapper implements Function<Object, Object>, Supplier<Object> {

	protected final Log logger = LogFactory.getLog(PartitionAwareFunctionWrapper.class);

	/**
	 * The wrapped function. Enhancer management only happens when this is a
	 * {@link FunctionInvocationWrapper}; other functions are invoked as-is.
	 */
	@SuppressWarnings("rawtypes")
	protected final Function function;

	/**
	 * Enricher that adds the partition header to outgoing messages, or {@code null}
	 * when the producer is not partitioned.
	 */
	private final Function<Object, Object> outputMessageEnricher;

	private boolean messageEnricherEnabled = true;

	/**
	 * Create a wrapper around the given function.
	 * @param function the function to wrap
	 * @param context application context whose bean factory backs partition evaluation;
	 * only dereferenced when {@code producerProperties} is partitioned
	 * @param producerProperties producer properties (may be {@code null}); a partition
	 * enricher is created only when they are non-null and partitioned
	 */
	PartitionAwareFunctionWrapper(Function<?, ?> function, ConfigurableApplicationContext context, ProducerProperties producerProperties) {
		this.function = function;

		if (producerProperties != null && producerProperties.isPartitioned()) {
			StandardEvaluationContext evaluationContext = ExpressionUtils.createStandardEvaluationContext(context.getBeanFactory());
			PartitionHandler partitionHandler = new PartitionHandler(evaluationContext, producerProperties, context.getBeanFactory());

			this.outputMessageEnricher = output -> {
				// MessageBuilder.withPayload rejects null payloads, so pass null through untouched.
				if (output == null) {
					return null;
				}
				// Arrays (other than byte[] payloads) and Iterables are returned as-is without a partition header.
				if ((ObjectUtils.isArray(output) && !(output instanceof byte[])) || output instanceof Iterable) {
					return output;
				}
				Message<?> message = (output instanceof Message<?> existing)
						? existing
						: MessageBuilder.withPayload(output).build();
				return toMessageWithPartitionHeader(message, partitionHandler);
			};
		}
		else {
			this.outputMessageEnricher = null;
		}
	}

	/**
	 * Copy the given message, adding the partition header computed by the handler.
	 * @param message the message to route
	 * @param partitionHandler handler that determines the target partition
	 * @return a new message carrying {@link BinderHeaders#PARTITION_HEADER}
	 */
	@SuppressWarnings({ "unchecked", "rawtypes" })
	private Message<?> toMessageWithPartitionHeader(Message message, PartitionHandler partitionHandler) {
		int partitionId = partitionHandler.determinePartition(message);
		return MessageBuilder
			.fromMessage(message)
			.setHeader(BinderHeaders.PARTITION_HEADER, partitionId).build();
	}

	/**
	 * Invoke the wrapped function.
	 * <p>When the enricher is enabled, the partition enhancer is installed on the wrapped
	 * {@link FunctionInvocationWrapper} before invocation. Afterwards the enhancer is cleared,
	 * unless the input type is a publisher or the result message already carries the
	 * partition header. Functions that are not a {@link FunctionInvocationWrapper} are
	 * simply invoked; previously they failed with a {@link ClassCastException}.
	 * @param input the function input
	 * @return the function result
	 */
	@SuppressWarnings("unchecked")
	@Override
	public Object apply(Object input) {
		if (this.messageEnricherEnabled) {
			this.setEnhancerIfNecessary();
		}
		Object result = this.function.apply(input);
		boolean messageContainsPartitionHeader = result instanceof Message<?> resultMessage
				&& resultMessage.getHeaders().containsKey(BinderHeaders.PARTITION_HEADER);
		if (this.function instanceof FunctionInvocationWrapper functionInvocationWrapper
				&& !functionInvocationWrapper.isInputTypePublisher()
				&& !messageContainsPartitionHeader) {
			functionInvocationWrapper.setEnhancer(null);
		}
		return result;
	}

	/**
	 * Invoke the wrapped function as a supplier, installing the partition enhancer
	 * first when the enricher is enabled.
	 * @return the supplied value
	 * @throws IllegalStateException if the wrapped function is not a {@link FunctionInvocationWrapper}
	 */
	@Override
	public Object get() {
		if (this.function instanceof FunctionInvocationWrapper functionInvocationWrapper) {
			if (this.messageEnricherEnabled) {
				this.setEnhancerIfNecessary();
			}
			return functionInvocationWrapper.get();
		}
		throw new IllegalStateException("Call to get() is not allowed since this function is not a Supplier.");
	}

	/**
	 * Install the partition enricher (possibly {@code null}) as the enhancer of the wrapped
	 * function when it is a {@link FunctionInvocationWrapper}; otherwise do nothing.
	 */
	private void setEnhancerIfNecessary() {
		if (this.function instanceof FunctionInvocationWrapper functionInvocationWrapper) {
			functionInvocationWrapper.setEnhancer(this.outputMessageEnricher);
		}
	}

	/**
	 * Enable or disable installing the partition enricher before each invocation.
	 * @param messageEnricherEnabled {@code true} to install the enricher (the default)
	 */
	public void setMessageEnricherEnabled(boolean messageEnricherEnabled) {
		this.messageEnricherEnabled = messageEnricherEnabled;
	}
}