TraceFunctionAroundWrapper.java

/*
 * Copyright 2013-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.sleuth.instrument.messaging;

import java.lang.reflect.Type;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.cloud.context.scope.refresh.RefreshScopeRefreshedEvent;
import org.springframework.cloud.function.context.catalog.FunctionAroundWrapper;
import org.springframework.cloud.function.context.catalog.FunctionTypeUtils;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry;
import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.instrument.reactor.ReactorSleuth;
import org.springframework.cloud.sleuth.propagation.Propagator;
import org.springframework.context.ApplicationListener;
import org.springframework.core.env.Environment;
import org.springframework.messaging.Message;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;

/**
 * Trace representation of a {@link FunctionAroundWrapper}.
 *
 * @author Marcin Grzejszczak
 * @author Oleg Zhurakousky
 * @author Tim Ysewyn
 * @since 3.0.0
 */
public class TraceFunctionAroundWrapper extends FunctionAroundWrapper
		implements ApplicationListener<RefreshScopeRefreshedEvent> {

	private static final Log log = LogFactory.getLog(TraceFunctionAroundWrapper.class);

	private final Environment environment;

	private final Tracer tracer;

	private final Propagator propagator;

	private final Propagator.Setter<MessageHeaderAccessor> injector;

	private final Propagator.Getter<MessageHeaderAccessor> extractor;

	private final TraceMessageHandler traceMessageHandler;

	private final List<FunctionMessageSpanCustomizer> customizers;

	final Map<String, String> functionToDestinationCache = new ConcurrentHashMap<>();

	public TraceFunctionAroundWrapper(Environment environment, Tracer tracer, Propagator propagator,
			Propagator.Setter<MessageHeaderAccessor> injector, Propagator.Getter<MessageHeaderAccessor> extractor) {
		this(environment, tracer, propagator, injector, extractor, Collections.emptyList());
	}

	public TraceFunctionAroundWrapper(Environment environment, Tracer tracer, Propagator propagator,
			Propagator.Setter<MessageHeaderAccessor> injector, Propagator.Getter<MessageHeaderAccessor> extractor,
			List<FunctionMessageSpanCustomizer> customizers) {
		this.environment = environment;
		this.tracer = tracer;
		this.propagator = propagator;
		this.injector = injector;
		this.extractor = extractor;
		this.customizers = customizers;
		this.traceMessageHandler = TraceMessageHandler.forNonSpringIntegration(this.tracer, this.propagator,
				this.injector, this.extractor, this.customizers);
	}

	@Override
	protected Object doApply(Object message, SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction) {
		if (FunctionTypeUtils.isCollectionOfMessage(targetFunction.getOutputType())) {
			return targetFunction.apply(message); // no instrumentation
		}
		else if (targetFunction.isInputTypePublisher() || targetFunction.isOutputTypePublisher()) {
			if (message != null && !(message instanceof Publisher)) {
				logDebugAboutMessageTypes(message);
				return targetFunction.apply(message); // no instrumentation
			}
			return reactorStream((Publisher) message, targetFunction);
		}
		else if (message != null && !(message instanceof Message)) {
			logDebugAboutMessageTypes(message);
			return targetFunction.apply(message); // no instrumentation
		}
		return nonReactorStream((Message<byte[]>) message, targetFunction);
	}

	private void logDebugAboutMessageTypes(Object message) {
		if (log.isDebugEnabled()) {
			String messageClass = message.getClass().getName();
			log.debug("We only support tracing for Message types. You need to wrap your function type [" + messageClass
					+ "] into [Message<" + messageClass + ">]");
		}
	}

	private Object reactorStream(Publisher messageStream,
			SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction) {
		if (messageStream == null && targetFunction.isSupplier()) { // Supplier
			return reactorStreamSupplier(messageStream, targetFunction);
		}
		Type itemType = FunctionTypeUtils.getGenericType(targetFunction.getInputType());
		Class<?> itemTypeClass = FunctionTypeUtils.getRawType(itemType);
		if (!itemTypeClass.equals(Message.class)) {
			if (log.isDebugEnabled()) {
				log.debug("Target function [" + targetFunction.getFunctionDefinition() + "] has raw input type ["
						+ itemType + "] and should be [" + Message.class + "]. Will not wrap it.");
			}
			return targetFunction.apply(messageStream);
		}
		Publisher<Message> messagePublisher = messageStream;
		if (FunctionTypeUtils.isMono(targetFunction.getInputType())) {
			return reactorMonoStream(targetFunction, messagePublisher);
		}
		return reactorFluxStream(targetFunction, messagePublisher);
	}

	private Object reactorMonoStream(SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction,
			Publisher<Message> messagePublisher) {
		if (log.isDebugEnabled()) {
			log.debug("Will instrument a stream Mono function");
		}
		Mono<Message> mono = Mono.from(messagePublisher)
				// ensure there are no previous spans
				.doOnNext(m -> tracer.withSpan(null))
				.map(msg -> this.traceMessageHandler.wrapInputMessage(msg,
						inputDestination(targetFunction.getFunctionDefinition())))
				.flatMap(msg -> Mono.deferContextual(ctx -> {
					MessageAndSpansAndScope messageAndSpansAndScope = ctx.get(MessageAndSpansAndScope.class);
					messageAndSpansAndScope.messageAndSpans = msg;
					messageAndSpansAndScope.span = msg.childSpan;
					setNameAndTag(targetFunction, msg.childSpan);
					messageAndSpansAndScope.scope = tracer.withSpan(msg.childSpan);
					return Mono.just(msg.msg);
				}));
		if (targetFunction.isConsumer()) {
			return targetFunction.apply(reactorStreamConsumer(mono));
		}
		final Publisher<Message> function = ((Publisher<Message>) targetFunction.apply(mono));
		if (function instanceof Mono) {
			return messageMono(targetFunction, (Mono<Message>) function);
		}
		return messageFlux(targetFunction, (Flux<Message>) function);
	}

	private Mono<Message> messageMono(SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction,
			Mono<Message> function) {
		return Mono.deferContextual(contextView -> {
			MessageAndSpansAndScope msg = contextView.get(MessageAndSpansAndScope.class);
			return function.doOnNext(message -> {
				msg.end();
				msg.handle();
			}).map(msgResult -> {
				MessageAndSpan messageAndSpan = traceMessageHandler.wrapOutputMessage(msgResult,
						msg.messageAndSpans.parentSpan, outputDestination(targetFunction.getFunctionDefinition()));
				traceMessageHandler.afterMessageHandled(messageAndSpan.span, null);
				return messageAndSpan.msg;
			})
					// TODO: Fix me when this is resolved in Reactor
					// .doOnSubscribe(__ -> scope.close())
					.doOnError(msg::error).doFinally(signalType -> {
						if (!msg.isHandled()) {
							msg.end();
						}
					});
		}).contextWrite(contextView -> contextView.put(MessageAndSpansAndScope.class, new MessageAndSpansAndScope()));
	}

	private Object reactorFluxStream(SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction,
			Publisher<Message> messagePublisher) {
		if (log.isDebugEnabled()) {
			log.debug("Will instrument a stream Flux function");
		}
		Flux<Message> flux = Flux.from(messagePublisher)
				// ensure there are no previous spans
				.doOnNext(m -> tracer.withSpan(null))
				.map(msg -> this.traceMessageHandler.wrapInputMessage(msg,
						inputDestination(targetFunction.getFunctionDefinition())))
				.flatMap(msg -> Flux.deferContextual(ctx -> {
					MessageAndSpansAndScope messageAndSpansAndScope = ctx.get(MessageAndSpansAndScope.class);
					messageAndSpansAndScope.messageAndSpans = msg;
					messageAndSpansAndScope.span = msg.childSpan;
					setNameAndTag(targetFunction, msg.childSpan);
					messageAndSpansAndScope.scope = tracer.withSpan(msg.childSpan);
					return Mono.just(msg.msg);
				}));
		if (targetFunction.isConsumer()) {
			return targetFunction.apply(reactorStreamConsumer(flux));
		}
		final Publisher<Message> function = ((Publisher<Message>) targetFunction.apply(flux));
		if (function instanceof Mono) {
			return messageMono(targetFunction, (Mono<Message>) function);
		}
		return messageFlux(targetFunction, (Flux<Message>) function);
	}

	private Flux<Message> messageFlux(SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction,
			Flux<Message> function) {
		return Flux.deferContextual(contextView -> {
			MessageAndSpansAndScope msg = contextView.get(MessageAndSpansAndScope.class);
			return function.doOnNext(message -> {
				msg.end();
				msg.handle();
			}).map(msgResult -> {
				MessageAndSpan messageAndSpan = traceMessageHandler.wrapOutputMessage(msgResult,
						msg.messageAndSpans.parentSpan, outputDestination(targetFunction.getFunctionDefinition()));
				traceMessageHandler.afterMessageHandled(messageAndSpan.span, null);
				return messageAndSpan.msg;
			})
					// TODO: Fix me when this is resolved in Reactor
					// .doOnSubscribe(__ -> scope.close())
					.doOnError(msg::error).doFinally(signalType -> {
						if (!msg.isHandled()) {
							msg.end();
						}
					});
		}).contextWrite(contextView -> contextView.put(MessageAndSpansAndScope.class, new MessageAndSpansAndScope()));
	}

	private Object reactorStreamConsumer(Object result) {
		if (result instanceof Mono) {
			return Mono.deferContextual(contextView -> {
				MessageAndSpansAndScope msg = contextView.get(MessageAndSpansAndScope.class);
				return ((Mono<Message>) result)
						// TODO: Fix me when this is resolved in Reactor
						// .doOnSubscribe(__ -> scope.close())
						.doOnError(msg::error).doFinally(signalType -> {
							msg.end();
						});
			}).contextWrite(
					contextView -> contextView.put(MessageAndSpansAndScope.class, new MessageAndSpansAndScope()));
		}
		return Flux.deferContextual(contextView -> {
			MessageAndSpansAndScope msg = contextView.get(MessageAndSpansAndScope.class);
			return ((Flux<Message>) result)
					// TODO: Fix me when this is resolved in Reactor
					// .doOnSubscribe(__ -> scope.close())
					.doOnError(msg::error).doFinally(signalType -> {
						msg.end();
					});
		}).contextWrite(contextView -> contextView.put(MessageAndSpansAndScope.class, new MessageAndSpansAndScope()));
	}

	private Object reactorStreamSupplier(Publisher<?> message,
			SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction) {
		Publisher<?> publisher = (Publisher<?>) targetFunction.get();
		if (publisher instanceof Mono) {
			if (log.isDebugEnabled()) {
				log.debug("Will instrument a stream Mono supplier");
			}
			Mono mono = (Mono) publisher;
			publisher = ReactorSleuth.tracedMono(tracer, tracer.currentTraceContext(),
					targetFunction.getFunctionDefinition(), () -> mono, (msg, s) -> {
						customizedInputMessageSpan(s, msg instanceof Message ? (Message) msg : null);
					}).map(object -> toMessage(object))
					.map(object -> this.getMessageAndSpans((Message) object, targetFunction.getFunctionDefinition(),
							setNameAndTag(targetFunction, tracer.currentSpan())))
					.doOnNext(wrappedOutputMessage -> customizedOutputMessageSpan(
							((MessageAndSpan) wrappedOutputMessage).span, ((MessageAndSpan) wrappedOutputMessage).msg))
					.doOnNext(wrappedOutputMessage -> traceMessageHandler
							.afterMessageHandled(((MessageAndSpan) wrappedOutputMessage).span, null))
					.map(wrappedOutputMessage -> ((MessageAndSpan) wrappedOutputMessage).msg);
		}
		else {
			if (log.isDebugEnabled()) {
				log.debug("Will instrument a stream Flux supplier");
			}
			Flux flux = (Flux) publisher;
			publisher = ReactorSleuth.tracedFlux(tracer, tracer.currentTraceContext(),
					targetFunction.getFunctionDefinition(), () -> flux, (msg, s) -> {
						customizedInputMessageSpan(s, msg instanceof Message ? (Message) msg : null);
					}).map(object -> toMessage(object))
					.map(object -> this.getMessageAndSpans((Message) object, targetFunction.getFunctionDefinition(),
							setNameAndTag(targetFunction, tracer.currentSpan())))
					.doOnNext(wrappedOutputMessage -> customizedOutputMessageSpan(
							((MessageAndSpan) wrappedOutputMessage).span, ((MessageAndSpan) wrappedOutputMessage).msg))
					.doOnNext(wrappedOutputMessage -> traceMessageHandler
							.afterMessageHandled(((MessageAndSpan) wrappedOutputMessage).span, null))
					.map(wrappedOutputMessage -> ((MessageAndSpan) wrappedOutputMessage).msg);
		}
		return publisher;
	}

	private Span setNameAndTag(SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction, Span span) {
		return span.name(targetFunction.getFunctionDefinition()).tag(SleuthMessagingSpan.Tags.FUNCTION_NAME.getKey(),
				targetFunction.getFunctionDefinition());
	}

	private Object nonReactorStream(Message<byte[]> message,
			SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction) {
		MessageAndSpans invocationMessage = null;
		Span span;
		if (message == null && targetFunction.isSupplier()) { // Supplier
			if (log.isDebugEnabled()) {
				log.debug("Creating a span for a supplier");
			}
			span = setNameAndTag(targetFunction, this.tracer.nextSpan());
			customizedInputMessageSpan(span, null);
		}
		else {
			if (log.isDebugEnabled()) {
				log.debug("Will retrieve the tracing headers from the message");
			}
			invocationMessage = this.traceMessageHandler.wrapInputMessage(message,
					inputDestination(targetFunction.getFunctionDefinition()));
			if (log.isDebugEnabled()) {
				log.debug("Wrapped input msg " + invocationMessage);
			}
			span = setNameAndTag(targetFunction, invocationMessage.childSpan);
		}
		Object result;
		Throwable throwable = null;
		try (Tracer.SpanInScope ws = this.tracer.withSpan(span.start())) {
			result = invocationMessage == null ? targetFunction.get() : targetFunction.apply(invocationMessage.msg);
		}
		catch (Exception e) {
			throwable = e;
			throw e;
		}
		finally {
			this.traceMessageHandler.afterMessageHandled(span, throwable);
		}
		if (result == null) {
			if (log.isDebugEnabled()) {
				log.debug("Returned message is null - we have a consumer");
			}
			return null;
		}
		Message<?> msgResult = toMessage(result);
		MessageAndSpan wrappedOutputMessage;
		if (log.isDebugEnabled()) {
			log.debug("Will instrument the output message");
		}
		if (invocationMessage != null) {
			wrappedOutputMessage = this.traceMessageHandler.wrapOutputMessage(msgResult, invocationMessage.parentSpan,
					outputDestination(targetFunction.getFunctionDefinition()));
		}
		else {
			wrappedOutputMessage = this.getMessageAndSpans(msgResult, targetFunction.getFunctionDefinition(), span);
		}
		if (log.isDebugEnabled()) {
			log.debug("Wrapped output msg " + wrappedOutputMessage);
		}
		traceMessageHandler.afterMessageHandled(wrappedOutputMessage.span, null);
		return wrappedOutputMessage.msg;
	}

	MessageAndSpan getMessageAndSpans(Message<?> resultMessage, String name, Span spanFromMessage) {
		return traceMessageHandler.wrapOutputMessage(resultMessage, spanFromMessage, outputDestination(name));
	}

	private void customizedInputMessageSpan(Span spanToCustomize, Message<?> msg) {
		this.customizers.forEach(cust -> cust.customizeInputMessageSpan(spanToCustomize, msg));
	}

	private void customizedOutputMessageSpan(Span spanToCustomize, Message<?> msg) {
		this.customizers.forEach(cust -> cust.customizeOutputMessageSpan(spanToCustomize, msg));
	}

	private Message<?> toMessage(Object result) {
		if (!(result instanceof Message)) {
			return MessageBuilder.withPayload(result).build();
		}
		return (Message<?>) result;
	}

	String inputDestination(String functionDefinition) {
		return this.functionToDestinationCache.computeIfAbsent(functionDefinition, s -> {
			String bindingMappingProperty = "spring.cloud.stream.function.bindings." + s + "-in-0";
			String bindingProperty = this.environment.containsProperty(bindingMappingProperty)
					? this.environment.getProperty(bindingMappingProperty) : s + "-in-0";
			return this.environment.getProperty("spring.cloud.stream.bindings." + bindingProperty + ".destination", s);
		});
	}

	String outputDestination(String functionDefinition) {
		return this.functionToDestinationCache.computeIfAbsent(functionDefinition, s -> {
			String bindingMappingProperty = "spring.cloud.stream.function.bindings." + s + "-out-0";
			String bindingProperty = this.environment.containsProperty(bindingMappingProperty)
					? this.environment.getProperty(bindingMappingProperty) : s + "-out-0";
			return this.environment.getProperty("spring.cloud.stream.bindings." + bindingProperty + ".destination", s);
		});
	}

	@Override
	public void onApplicationEvent(RefreshScopeRefreshedEvent event) {
		if (log.isDebugEnabled()) {
			log.debug("Context refreshed, will reset the cache");
		}
		this.functionToDestinationCache.clear();
	}

	static class MessageAndSpansAndScope {

		MessageAndSpans messageAndSpans;

		Span span;

		Tracer.SpanInScope scope;

		boolean handled;

		void error(Throwable throwable) {
			if (this.span != null) {
				this.span.error(throwable);
			}
		}

		void handle() {
			this.handled = true;
		}

		boolean isHandled() {
			return this.handled;
		}

		void end() {
			if (this.span != null) {
				this.span.end();
			}
			if (this.scope != null) {
				this.scope.close();
			}
		}

	}

}