TracingChannelInterceptor.java

/*
 * Copyright 2013-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.sleuth.instrument.messaging;

import java.util.Set;
import java.util.function.Function;

import org.springframework.aop.support.AopUtils;
import org.springframework.beans.BeansException;
import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.SpanAndScope;
import org.springframework.cloud.sleuth.ThreadLocalSpan;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.propagation.Propagator;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.core.log.LogAccessor;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.MessageHandler;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.support.ErrorMessage;
import org.springframework.messaging.support.ExecutorChannelInterceptor;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.ClassUtils;
import org.springframework.util.StringUtils;

/**
 * This starts and propagates a {@link Span.Kind#PRODUCER} span for each message sent
 * (via native headers). It also extracts or creates a {@link Span.Kind#CONSUMER} span
 * for each message received. This span is injected onto each message so it becomes the
 * parent when a handler later calls {@link MessageHandler#handleMessage(Message)}.
 *
 * <p>
 * Spans put in scope by the "before" callbacks ({@link #preSend},
 * {@link #postReceive}, {@link #beforeHandle}) are stored in a {@link ThreadLocalSpan}
 * and finished by the matching completion callback ({@link #afterSendCompletion},
 * {@link #afterReceiveCompletion}, {@link #afterMessageHandled}).
 *
 * @author Marcin Grzejszczak
 * @author Artem Bilan
 * @since 3.0.0
 */
public final class TracingChannelInterceptor implements ExecutorChannelInterceptor, ApplicationContextAware {

	/**
	 * Name of the class in Spring Cloud Stream that is a direct channel.
	 */
	public static final String STREAM_DIRECT_CHANNEL = "org.springframework.cloud.stream.messaging.DirectWithAttributesChannel";

	private static final LogAccessor log = new LogAccessor(TracingChannelInterceptor.class);

	/**
	 * Using the literal "broker" until we come up with a better solution.
	 *
	 * <p>
	 * If the message originated from a binder (consumer binding), there will be different
	 * headers present (e.g. "KafkaHeaders.RECEIVED_TOPIC" vs.
	 * "AmqpHeaders.CONSUMER_QUEUE"), unless the application removes them before sending.
	 * These don't represent the broker, rather a queue, and in any case the heuristics
	 * are not great. At least we might be able to tell if this is rabbit or not (ex how
	 * spring-rabbit works). We need to think this through before making an api, possibly
	 * experimenting.
	 *
	 * <p>
	 * If the app is outbound only (producer), there's no indication of what type the
	 * destination broker is. This may hint at a non-manual solution being overwriting the
	 * remoteServiceName later, similar to how servlet instrumentation lazily sets
	 * "http.route".
	 */
	private static final String REMOTE_SERVICE_NAME = "broker";

	private static final boolean hasDirectChannelClass = ClassUtils
			.isPresent("org.springframework.integration.channel.DirectChannel", null);

	private static final boolean hasBinderTypeRegistry = ClassUtils
			.isPresent("org.springframework.cloud.stream.binder.BinderTypeRegistry", null);

	// special case of a Stream
	private static final Class<?> directWithAttributesChannelClass = ClassUtils.isPresent(STREAM_DIRECT_CHANNEL, null)
			? ClassUtils.resolveClassName(STREAM_DIRECT_CHANNEL, null) : null;

	// Holds the spans (and their scopes) put in scope by this interceptor until the
	// corresponding completion callback finishes them.
	private final ThreadLocalSpan threadLocalSpan;

	private final Tracer tracer;

	private final Propagator.Setter<MessageHeaderAccessor> injector;

	private final Propagator.Getter<MessageHeaderAccessor> extractor;

	private final MessageSpanCustomizer messageSpanCustomizer;

	private final Propagator propagator;

	private final Function<String, String> remoteServiceNameMapper;

	private ApplicationContext applicationContext;

	/**
	 * Creates a new tracing interceptor.
	 * @param tracer tracer used to create spans and to back the thread-local span holder
	 * @param propagator propagator used to extract and inject the trace context from / into
	 * message headers
	 * @param setter setter the propagator uses to write headers
	 * @param getter getter the propagator uses to read headers
	 * @param remoteServiceNameMapper maps a header name or binder name to a remote service
	 * name; results without text are ignored
	 * @param messageSpanCustomizer hook to customize the send, receive and handle spans
	 */
	public TracingChannelInterceptor(Tracer tracer, Propagator propagator,
			Propagator.Setter<MessageHeaderAccessor> setter, Propagator.Getter<MessageHeaderAccessor> getter,
			Function<String, String> remoteServiceNameMapper, MessageSpanCustomizer messageSpanCustomizer) {
		this.tracer = tracer;
		this.propagator = propagator;
		this.injector = setter;
		this.extractor = getter;
		this.remoteServiceNameMapper = remoteServiceNameMapper;
		this.messageSpanCustomizer = messageSpanCustomizer;
		this.threadLocalSpan = new ThreadLocalSpan(tracer);
	}

	/**
	 * Stores the context; it is used to look up the Spring Cloud Stream binder registry
	 * when resolving remote service names.
	 */
	@Override
	public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
		this.applicationContext = applicationContext;
	}

	/**
	 * Starts and propagates {@link Span.Kind#PRODUCER} span for each message sent.
	 */
	@Override
	public Message<?> preSend(Message<?> message, MessageChannel channel) {
		Message<?> retrievedMessage = getMessage(message);
		log.debug(() -> "Received a message in pre-send " + retrievedMessage);
		MessageHeaderAccessor headers = mutableHeaderAccessor(retrievedMessage);
		Span.Builder spanBuilder = this.propagator.extract(headers, this.extractor);
		// the incoming context is now held by the builder; drop the raw headers so only the
		// context injected below remains on the message
		MessageHeaderPropagatorSetter.removeAnyTraceHeaders(headers, this.propagator.fields());
		spanBuilder = spanBuilder.kind(Span.Kind.PRODUCER);
		spanBuilder = this.messageSpanCustomizer.customizeSend(spanBuilder, message, channel)
				.remoteServiceName(toRemoteServiceName(headers, this.remoteServiceNameMapper, this.applicationContext));
		Span span = spanBuilder.start();
		log.debug(() -> "Extracted result from headers " + span);
		setSpanInScope(span);
		this.propagator.inject(span.context(), headers, this.injector);
		log.debug(() -> "Created a new span in pre send " + span);
		Message<?> outputMessage = outputMessage(message, retrievedMessage, headers);
		if (isDirectChannel(channel)) {
			// ExecutorChannelInterceptor handle callbacks are driven here for direct channels
			beforeHandle(outputMessage, channel, null);
		}
		return outputMessage;
	}

	/**
	 * Puts the span in scope via the thread-local span holder.
	 */
	private void setSpanInScope(Span span) {
		this.threadLocalSpan.set(span);
		log.debug(() -> "Put span in scope " + span);
	}

	/**
	 * Resolves the remote service name. The first header name, then the first Spring
	 * Cloud Stream binder name, that the mapper turns into a non-blank value wins.
	 * Otherwise falls back to {@link #REMOTE_SERVICE_NAME}.
	 */
	private static String toRemoteServiceName(MessageHeaderAccessor headers,
			Function<String, String> remoteServiceNameMapper, ApplicationContext applicationContext) {

		for (String key : headers.getMessageHeaders().keySet()) {
			String remoteServiceName = remoteServiceNameMapper.apply(key);
			if (StringUtils.hasText(remoteServiceName)) {
				return remoteServiceName;
			}
		}

		if (hasBinderTypeRegistry && applicationContext != null) {
			// The class being on the classpath does not guarantee a registry bean exists;
			// getBean(...) would throw and break every send/receive in that case.
			org.springframework.cloud.stream.binder.BinderTypeRegistry typeRegistry = applicationContext
					.getBeanProvider(org.springframework.cloud.stream.binder.BinderTypeRegistry.class)
					.getIfAvailable();
			if (typeRegistry != null) {
				Set<String> binderNames = typeRegistry.getAll().keySet();
				for (String binderName : binderNames) {
					String remoteServiceName = remoteServiceNameMapper.apply(binderName);
					if (StringUtils.hasText(remoteServiceName)) {
						return remoteServiceName;
					}
				}
			}
		}
		return REMOTE_SERVICE_NAME;
	}

	/**
	 * Builds the message returned from {@link #preSend}. For an {@link ErrorMessage} only
	 * the propagation headers are copied over and the original message is preserved.
	 * Otherwise all headers are copied and the payload of the retrieved (possibly
	 * unwrapped) message is used.
	 */
	private Message<?> outputMessage(Message<?> originalMessage, Message<?> retrievedMessage,
			MessageHeaderAccessor additionalHeaders) {
		MessageHeaderAccessor headers = mutableHeaderAccessor(originalMessage);
		if (originalMessage instanceof ErrorMessage) {
			ErrorMessage errorMessage = (ErrorMessage) originalMessage;
			headers.copyHeaders(MessageHeaderPropagatorSetter.propagationHeaders(additionalHeaders.getMessageHeaders(),
					this.propagator.fields()));
			return new ErrorMessage(errorMessage.getPayload(), isWebSockets(headers) ? headers.getMessageHeaders()
					: new MessageHeaders(headers.getMessageHeaders()), errorMessage.getOriginalMessage());
		}
		headers.copyHeaders(additionalHeaders.getMessageHeaders());
		return new GenericMessage<>(retrievedMessage.getPayload(),
				isWebSockets(headers) ? headers.getMessageHeaders() : new MessageHeaders(headers.getMessageHeaders()));
	}

	/**
	 * Whether the headers carry the STOMP / simple-messaging markers used by websockets.
	 */
	private static boolean isWebSockets(MessageHeaderAccessor headerAccessor) {
		return headerAccessor.getMessageHeaders().containsKey("stompCommand")
				|| headerAccessor.getMessageHeaders().containsKey("simpMessageType");
	}

	/**
	 * Whether the channel is a Spring Integration {@code DirectChannel} but not Spring
	 * Cloud Stream's {@code DirectWithAttributesChannel}. For such channels
	 * {@link #preSend} and {@link #afterSendCompletion} invoke {@link #beforeHandle} and
	 * {@link #afterMessageHandled} themselves.
	 */
	private static boolean isDirectChannel(MessageChannel channel) {
		Class<?> targetClass = AopUtils.getTargetClass(channel);
		return (directWithAttributesChannelClass == null
				|| !directWithAttributesChannelClass.isAssignableFrom(targetClass)) && hasDirectChannelClass
				&& org.springframework.integration.channel.DirectChannel.class.isAssignableFrom(targetClass);
	}

	/**
	 * Finishes the handle span (for direct channels) and then the producer span started
	 * in {@link #preSend}, tagging them with the error if one occurred.
	 */
	@Override
	public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
		if (isDirectChannel(channel)) {
			afterMessageHandled(message, channel, null, ex);
		}
		log.debug(() -> "Will finish the current span after completion " + this.tracer.currentSpan());
		finishSpan(ex);
	}

	/**
	 * This starts a consumer span as a child of the incoming message or the current trace
	 * context, placing it in scope until the receive completes.
	 */
	@Override
	public Message<?> postReceive(Message<?> message, MessageChannel channel) {
		MessageHeaderAccessor headers = mutableHeaderAccessor(message);
		log.debug(() -> "Received a message in post-receive " + message);
		Span span = consumerSpanReceive(message, channel, headers);
		setSpanInScope(span);
		log.debug(() -> "Created a new span that will be injected in the headers " + span);
		this.propagator.inject(span.context(), headers, this.injector);
		log.debug(() -> "Created a new span in post receive " + span);
		headers.setImmutable();
		if (message instanceof ErrorMessage) {
			ErrorMessage errorMessage = (ErrorMessage) message;
			return new ErrorMessage(errorMessage.getPayload(), headers.getMessageHeaders(),
					errorMessage.getOriginalMessage());
		}
		return new GenericMessage<>(message.getPayload(), headers.getMessageHeaders());
	}

	/**
	 * Starts the consumer span directly from the extracted context. No intermediate span
	 * is started, so nothing is left unfinished. Trace headers are removed from
	 * {@code headers} so the caller can inject the new context.
	 */
	private Span consumerSpanReceive(Message<?> message, MessageChannel channel, MessageHeaderAccessor headers) {
		Span.Builder extracted = this.propagator.extract(headers, this.extractor);
		log.debug(() -> "Extracted result from headers " + extracted);
		MessageHeaderPropagatorSetter.removeAnyTraceHeaders(headers, this.propagator.fields());
		Span.Builder builder = this.messageSpanCustomizer.customizeReceive(extracted.kind(Span.Kind.CONSUMER),
				message, channel);
		builder = builder.remoteServiceName(
				toRemoteServiceName(headers, this.remoteServiceNameMapper, this.applicationContext));
		return builder.start();
	}

	/**
	 * Finishes the consumer span started in {@link #postReceive}.
	 */
	@Override
	public void afterReceiveCompletion(Message<?> message, MessageChannel channel, Exception ex) {
		log.debug(() -> "Will finish the current span after receive completion " + this.tracer.currentSpan());
		finishSpan(ex);
	}

	/**
	 * This starts a consumer span as a child of the incoming message or the current trace
	 * context. It then creates a span for the handler, placing it in scope.
	 */
	@Override
	public Message<?> beforeHandle(Message<?> message, MessageChannel channel, MessageHandler handler) {
		MessageHeaderAccessor headers = mutableHeaderAccessor(message);
		log.debug(() -> "Received a message in before handle " + message);
		Span consumerSpan = consumerSpan(message, channel, headers);
		// create and scope a span for the message processor
		Span handle = this.messageSpanCustomizer
				.customizeHandle(this.tracer.nextSpan(consumerSpan), message, channel).start();
		log.debug(() -> "Created consumer span " + handle);
		setSpanInScope(handle);
		// remove any trace headers, but don't re-inject as we are synchronously processing
		// the message and can rely on scoping to access this span later.
		MessageHeaderPropagatorSetter.removeAnyTraceHeaders(headers, this.propagator.fields());
		log.debug(() -> "Created a new span in before handle " + handle);
		headers.setImmutable();
		if (message instanceof ErrorMessage) {
			ErrorMessage errorMessage = (ErrorMessage) message;
			// keep the original message, as postReceive does
			return new ErrorMessage(errorMessage.getPayload(), headers.getMessageHeaders(),
					errorMessage.getOriginalMessage());
		}
		return new GenericMessage<>(message.getPayload(), headers.getMessageHeaders());
	}

	/**
	 * Starts and immediately finishes a consumer span built from the extracted context.
	 * The returned (already ended) span serves as the parent of the handle span.
	 */
	private Span consumerSpan(Message<?> message, MessageChannel channel, MessageHeaderAccessor headers) {
		Span.Builder extracted = this.propagator.extract(headers, this.extractor);
		log.debug(() -> "Extracted result from headers - will finish it immediately " + extracted);
		// Start and finish a consumer span as we will immediately process it. Builder
		// results are chained (not discarded) and start() is called exactly once.
		Span.Builder consumerSpanBuilder = this.messageSpanCustomizer.customizeHandle(
				extracted.kind(Span.Kind.CONSUMER).remoteServiceName(REMOTE_SERVICE_NAME), message, channel);
		Span consumerSpan = consumerSpanBuilder.start();
		consumerSpan.end();
		return consumerSpan;
	}

	/**
	 * Finishes the handle span started in {@link #beforeHandle}.
	 */
	@Override
	public void afterMessageHandled(Message<?> message, MessageChannel channel, MessageHandler handler, Exception ex) {
		log.debug(() -> "Will finish the current span after message handled " + this.tracer.currentSpan());
		finishSpan(ex);
	}

	/**
	 * Removes the current span from the thread-local holder, tags it with the error (if
	 * any), ends it, and closes its scope. Noop spans only have their scope closed.
	 * @param error the failure to record, or {@code null}
	 */
	void finishSpan(Exception error) {
		SpanAndScope spanAndScope = getSpanFromThreadLocal();
		if (spanAndScope == null) {
			return;
		}
		Span span = spanAndScope.getSpan();
		Tracer.SpanInScope scope = spanAndScope.getScope();
		try {
			if (span.isNoop()) {
				log.debug(() -> "Span " + span + " is noop - will stop the scope");
				return;
			}
			if (error != null) { // an error occurred, adding error to span
				String message = error.getMessage();
				span.tag("error", message != null ? message : error.getClass().getSimpleName());
			}
			log.debug(() -> "Will finish the span and its corresponding scope " + span);
			span.end();
		}
		finally {
			// always close the scope, even if tagging or ending the span fails
			scope.close();
		}
	}

	/**
	 * Returns the current span and scope from the thread-local holder (may be
	 * {@code null}) and removes them from it.
	 */
	private SpanAndScope getSpanFromThreadLocal() {
		SpanAndScope span = this.threadLocalSpan.get();
		log.debug(() -> "Took span [" + span + "] from thread local");
		this.threadLocalSpan.remove();
		return span;
	}

	/**
	 * Returns the message's own accessor if it is still mutable; otherwise a mutable copy
	 * that is left mutable after {@code getMessageHeaders()} is called.
	 */
	private static MessageHeaderAccessor mutableHeaderAccessor(Message<?> message) {
		MessageHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
		if (accessor != null && accessor.isMutable()) {
			return accessor;
		}
		MessageHeaderAccessor headers = MessageHeaderAccessor.getMutableAccessor(message);
		headers.setLeaveMutable(true);
		return headers;
	}

	/**
	 * Unwraps the failed message from a {@link MessagingException} payload when present;
	 * otherwise returns the message unchanged.
	 */
	private static Message<?> getMessage(Message<?> message) {
		Object payload = message.getPayload();
		if (payload instanceof MessagingException) {
			MessagingException e = (MessagingException) payload;
			Message<?> failedMessage = e.getFailedMessage();
			return failedMessage != null ? failedMessage : message;
		}
		return message;
	}

}