TraceMessageHandler.java
/*
* Copyright 2013-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.sleuth.instrument.messaging;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.internal.SpanNameUtil;
import org.springframework.cloud.sleuth.propagation.Propagator;
import org.springframework.core.ResolvableType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.support.ErrorMessage;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.StringUtils;
// TODO: Duplicates a lot from TraceChannelInterceptor, need to figure out how to merge the two
/**
 * Wraps input and output messages with tracing headers and manages the spans that
 * correspond to receiving, handling and sending those messages.
 */
class TraceMessageHandler {
private static final Log log = LogFactory.getLog(TraceMessageHandler.class);
/**
 * Using the literal "broker" until we come up with a better solution.
 *
 * <p>
 * If the message originated from a binder (consumer binding), there will be different
 * headers present (e.g. "KafkaHeaders.RECEIVED_TOPIC" vs.
 * "AmqpHeaders.CONSUMER_QUEUE"), unless the application removes them before sending.
 * These don't represent the broker, rather a queue, and in any case the heuristics
 * are not great. At least we might be able to tell if this is rabbit or not (e.g. how
 * spring-rabbit works). We need to think this through before making an API, possibly
 * experimenting.
 *
 * <p>
 * If the app is outbound only (producer), there's no indication of what type the
 * destination broker is. This may hint that a non-manual solution would be to
 * overwrite the remoteServiceName later, similar to how servlet instrumentation
 * lazily sets "http.route".
 */
private static final String REMOTE_SERVICE_NAME = "broker";
/** Message header key under which the handler's parent span is stored. */
private static final String TRACE_HANDLER_PARENT_SPAN = "traceHandlerParentSpan";
/** Tracer this handler was created with. */
final Tracer tracer;
/** Extracts the tracing context from incoming headers and injects it into outgoing ones. */
private final Propagator propagator;
/** Setter handed to the propagator when injecting the context into message headers. */
private final Propagator.Setter<MessageHeaderAccessor> injector;
/** Getter handed to the propagator when extracting the context from message headers. */
private final Propagator.Getter<MessageHeaderAccessor> extractor;
/** Applied to the consumer span of an input message to obtain the child span. */
private final Function<Span, Span> preSendFunction;
/** Invoked with the input message headers, the consumer span and the child span. */
private final TriConsumer<MessageHeaderAccessor, Span, Span> preSendMessageManipulator;
/** Applied to the parent span of an output message to obtain the producer span builder. */
private final Function<Span, Span.Builder> outputMessageSpanFunction;
/** Customizers applied to the input, function and output message spans. */
private final List<FunctionMessageSpanCustomizer> customizers;
/**
 * Creates a handler from its collaborators.
 * @param tracer - tracer this handler works with
 * @param propagator - extracts and injects the tracing context from / into headers
 * @param injector - setter used when injecting the context into message headers
 * @param extractor - getter used when extracting the context from message headers
 * @param preSendFunction - turns the consumer span of an input message into the child
 * span
 * @param preSendMessageManipulator - receives the input message headers together with
 * the consumer span and the child span
 * @param outputMessageSpanFunction - turns the parent span of an output message into
 * the builder of the producer span
 * @param customizers - customizers applied to the created spans
 */
TraceMessageHandler(Tracer tracer, Propagator propagator, Propagator.Setter<MessageHeaderAccessor> injector,
        Propagator.Getter<MessageHeaderAccessor> extractor, Function<Span, Span> preSendFunction,
        TriConsumer<MessageHeaderAccessor, Span, Span> preSendMessageManipulator,
        Function<Span, Span.Builder> outputMessageSpanFunction, List<FunctionMessageSpanCustomizer> customizers) {
    // Tracing infrastructure.
    this.tracer = tracer;
    this.propagator = propagator;
    // Header access used by the propagator.
    this.extractor = extractor;
    this.injector = injector;
    // TODO: Abstractions to reuse in TraceChannelInterceptors?
    this.customizers = customizers;
    this.outputMessageSpanFunction = outputMessageSpanFunction;
    this.preSendMessageManipulator = preSendMessageManipulator;
    this.preSendFunction = preSendFunction;
}
/**
 * Creates a handler whose child span is a started span named "function" and whose
 * input message headers carry both the parent and the child span.
 * @param tracer - tracer used to create the child span and the output span builder
 * @param propagator - extracts and injects the tracing context from / into headers
 * @param injector - setter used when injecting the context into message headers
 * @param extractor - getter used when extracting the context from message headers
 * @param customizers - customizers applied to the created spans
 * @return a configured handler
 */
static TraceMessageHandler forNonSpringIntegration(Tracer tracer, Propagator propagator,
        Propagator.Setter<MessageHeaderAccessor> injector, Propagator.Getter<MessageHeaderAccessor> extractor,
        List<FunctionMessageSpanCustomizer> customizers) {
    // The child span is the next span after the consumer span, started right away.
    Function<Span, Span> functionSpanFactory = consumerSpan -> {
        Span next = tracer.nextSpan(consumerSpan);
        return SleuthMessagingSpan.MESSAGING_SPAN.wrap(next).name("function").start();
    };
    // Both spans travel with the message as headers so they can be looked up later.
    TriConsumer<MessageHeaderAccessor, Span, Span> spanHeaderWriter = (accessor, parent, child) -> {
        accessor.setHeader(TRACE_HANDLER_PARENT_SPAN, parent);
        accessor.setHeader(Span.class.getName(), child);
    };
    // The output span builder is parented to the context of the given span.
    Function<Span, Span.Builder> outputBuilderFactory = parent -> tracer.spanBuilder()
            .setParent(parent.context());
    return new TraceMessageHandler(tracer, propagator, injector, extractor, functionSpanFactory,
            spanHeaderWriter, outputBuilderFactory, customizers);
}
/**
 * Creates a handler from beans looked up in the given bean factory.
 * @param beanFactory - source of the tracer, propagator, header setter / getter and
 * span customizers
 * @return a configured handler
 */
@SuppressWarnings("unchecked")
static TraceMessageHandler forNonSpringIntegration(BeanFactory beanFactory) {
    // Lookup order is kept: setter, getter, tracer, propagator, customizers.
    Propagator.Setter<MessageHeaderAccessor> injector = firstBeanOrException(beanFactory, Propagator.Setter.class);
    Propagator.Getter<MessageHeaderAccessor> extractor = firstBeanOrException(beanFactory, Propagator.Getter.class);
    Tracer tracer = beanFactory.getBean(Tracer.class);
    Propagator propagator = beanFactory.getBean(Propagator.class);
    List<FunctionMessageSpanCustomizer> spanCustomizers = customizers(beanFactory);
    return forNonSpringIntegration(tracer, propagator, injector, extractor, spanCustomizers);
}
/**
 * Returns the first available bean of the given type parameterized with
 * {@link MessageHeaderAccessor}.
 * @param beanFactory - bean factory to look the bean up in
 * @param clazz - raw type of the bean, e.g. {@code Propagator.Setter} or
 * {@code Propagator.Getter}
 * @param <T> type of the bean
 * @return the first non-null candidate
 * @throws NoSuchBeanDefinitionException if no such bean is available; the exception
 * carries the actually requested type
 */
private static <T> T firstBeanOrException(BeanFactory beanFactory, Class<T> clazz) {
    ResolvableType beanType = ResolvableType.forClassWithGenerics(clazz, MessageHeaderAccessor.class);
    ObjectProvider<T> provider = beanFactory.getBeanProvider(beanType);
    // Iterate once instead of creating a second iterator just to fetch the element.
    for (T candidate : provider) {
        if (candidate != null) {
            return candidate;
        }
    }
    // Report the requested type: the String constructor would interpret its argument
    // as a bean name, and the old text always blamed Propagator.Setter.
    throw new NoSuchBeanDefinitionException(beanType);
}
/**
 * Collects all {@link FunctionMessageSpanCustomizer} beans into a new mutable list.
 * @param beanFactory - bean factory to look the customizers up in
 * @return the customizers in provider iteration order, possibly empty
 */
private static List<FunctionMessageSpanCustomizer> customizers(BeanFactory beanFactory) {
    List<FunctionMessageSpanCustomizer> found = new ArrayList<>();
    beanFactory.getBeanProvider(FunctionMessageSpanCustomizer.class).forEach(found::add);
    return found;
}
/**
 * Wraps the given input message with tracing headers and returns the corresponding
 * spans.
 * @param message - message to wrap
 * @param destinationName - destination from which the message was received
 * @return a holder with the wrapped message, the (already ended) consumer span and
 * the child span
 */
MessageAndSpans wrapInputMessage(Message<?> message, String destinationName) {
    MessageHeaderAccessor headers = mutableHeaderAccessor(message);
    // Continue the trace found in the incoming headers, if any.
    Span.Builder consumerSpanBuilder = SleuthMessagingSpan.MESSAGING_SPAN
            .wrap(this.propagator.extract(headers, this.extractor));
    Span consumerSpan = consumerSpan(destinationName, consumerSpanBuilder, message);
    if (log.isDebugEnabled()) {
        log.debug("Built a consumer span " + consumerSpan);
    }
    Span childSpan = this.preSendFunction.apply(consumerSpan);
    // Drop the incoming tracing headers before the manipulator writes its own.
    clearTracingHeaders(headers);
    this.preSendMessageManipulator.accept(headers, consumerSpan, childSpan);
    this.customizers.forEach(customizer -> customizer.customizeFunctionSpan(childSpan, message));
    if (message instanceof ErrorMessage) {
        ErrorMessage errorMessage = (ErrorMessage) message;
        // Keep the original message, as outputMessage(...) does, so it is not lost
        // when the error message is rebuilt with the new headers.
        return new MessageAndSpans(new ErrorMessage(errorMessage.getPayload(), headers.getMessageHeaders(),
                errorMessage.getOriginalMessage()), consumerSpan, childSpan);
    }
    // Only the non-error path marks the headers immutable.
    headers.setImmutable();
    return new MessageAndSpans(new GenericMessage<>(message.getPayload(), headers.getMessageHeaders()),
            consumerSpan, childSpan);
}
/**
 * Builds, starts, customizes and immediately ends the consumer span of an input
 * message.
 * @param destinationName - destination from which the message was received
 * @param consumerSpanBuilder - builder to create the span from
 * @param message - received message, handed to the customizers
 * @return the ended consumer span
 */
private Span consumerSpan(String destinationName, Span.Builder consumerSpanBuilder, Message<?> message) {
    consumerSpanBuilder.kind(Span.Kind.CONSUMER).name("handle");
    addTags(consumerSpanBuilder, destinationName);
    consumerSpanBuilder.remoteServiceName(REMOTE_SERVICE_NAME);
    // this is the consumer part of the producer->consumer mechanism
    Span started = consumerSpanBuilder.start();
    for (FunctionMessageSpanCustomizer customizer : this.customizers) {
        customizer.customizeInputMessageSpan(started, message);
    }
    // we're ending this immediately just to have a properly nested graph
    started.end();
    return started;
}
/**
 * Resolves a span for the given message: first from the header keyed by the
 * {@link Span} class name, then from the parent span header, and finally by starting
 * a span from the context extracted from the message headers.
 * @param message - message to resolve the span for
 * @return the resolved span
 */
Span spanFromMessage(Message<?> message) {
    MessageHeaderAccessor accessor = mutableHeaderAccessor(message);
    // Header keys are checked in priority order.
    for (String headerName : Arrays.asList(Span.class.getName(), TRACE_HANDLER_PARENT_SPAN)) {
        Span stored = span(accessor, headerName);
        if (stored != null) {
            return stored;
        }
    }
    return this.propagator.extract(accessor, this.extractor).start();
}
/**
 * Tags the span builder with the shortened destination name, if there is one.
 * @param result - span builder to tag
 * @param destinationName - destination name; ignored when it has no text
 */
private void addTags(Span.Builder result, String destinationName) {
    if (!StringUtils.hasText(destinationName)) {
        return;
    }
    String channel = SpanNameUtil.shorten(destinationName);
    SleuthMessagingSpan.MESSAGING_SPAN.wrap(result).tag(SleuthMessagingSpan.Tags.CHANNEL, channel);
}
/**
 * Finishes the given span once the message has been received and processed, or sent.
 * @param span - span that corresponds to the given operation
 * @param ex - an optional exception that occurred while processing / sending
 */
void afterMessageHandled(Span span, Throwable ex) {
    boolean debugEnabled = log.isDebugEnabled();
    if (debugEnabled) {
        log.debug("Will finish the current span after message handled " + span);
    }
    finishSpan(span, ex);
}
/**
 * Returns the span stored on the message under the parent span header.
 * @param message - message to read the header from
 * @return the stored span, or {@code null} if the header is absent
 */
Span parentSpan(Message<?> message) {
    return span(mutableHeaderAccessor(message), TRACE_HANDLER_PARENT_SPAN);
}
/**
 * Returns the span stored on the message under the header keyed by the {@link Span}
 * class name.
 * @param message - message to read the header from
 * @return the stored span, or {@code null} if the header is absent
 */
Span consumerSpan(Message<?> message) {
    return span(mutableHeaderAccessor(message), Span.class.getName());
}
/**
 * Reads a {@link Span} valued header.
 * @param headerAccessor - accessor of the message headers
 * @param key - header name
 * @return the header value as a span
 */
private Span span(MessageHeaderAccessor headerAccessor, String key) {
    MessageHeaders messageHeaders = headerAccessor.getMessageHeaders();
    return messageHeaders.get(key, Span.class);
}
/**
 * Wraps the given output message with tracing headers and returns a corresponding
 * span.
 * @param message - message to wrap
 * @param parentSpan - span handed to the output message span function to obtain the
 * producer span builder
 * @param destinationName - destination to which the message should be sent
 * @return a holder with the wrapped message and the producer span
 */
MessageAndSpan wrapOutputMessage(Message<?> message, Span parentSpan, String destinationName) {
    // For a MessagingException payload the headers of the failed message are used.
    Message<?> retrievedMessage = getMessage(message);
    MessageHeaderAccessor headers = mutableHeaderAccessor(retrievedMessage);
    Span.Builder spanBuilder = this.outputMessageSpanFunction.apply(parentSpan);
    clearTracingHeaders(headers);
    Span producerSpan = createProducerSpan(headers, spanBuilder, destinationName, message);
    this.propagator.inject(producerSpan.context(), headers, this.injector);
    if (log.isDebugEnabled()) {
        // Log the span that was actually created rather than its builder.
        log.debug("Created a new span output message " + producerSpan);
    }
    return new MessageAndSpan(outputMessage(message, retrievedMessage, headers), producerSpan);
}
/**
 * Builds, starts and customizes the producer span of an output message.
 * @param headers - headers used to pick the remote service name
 * @param spanBuilder - builder to create the span from
 * @param destinationName - destination to which the message should be sent
 * @param message - output message, handed to the customizers
 * @return the started producer span
 */
private Span createProducerSpan(MessageHeaderAccessor headers, Span.Builder spanBuilder, String destinationName,
        Message<?> message) {
    spanBuilder.kind(Span.Kind.PRODUCER).name("send").remoteServiceName(toRemoteServiceName(headers));
    // Tags have to be put on the builder before start(), as the consumer span does;
    // tagging the builder after the span has been started does not reach that span.
    addTags(spanBuilder, destinationName);
    Span span = spanBuilder.start();
    this.customizers.forEach(customizer -> customizer.customizeOutputMessageSpan(span, message));
    return span;
}
/**
 * Picks the remote service name from the header names: the first header whose name
 * starts with "kafka_" or "amqp_" decides; otherwise the default name is used.
 * @param headers - accessor of the message headers
 * @return "kafka", "rabbitmq" or the default remote service name
 */
private String toRemoteServiceName(MessageHeaderAccessor headers) {
    for (String headerName : headers.getMessageHeaders().keySet()) {
        if (headerName.startsWith("kafka_")) {
            return "kafka";
        }
        if (headerName.startsWith("amqp_")) {
            return "rabbitmq";
        }
    }
    return REMOTE_SERVICE_NAME;
}
/**
 * Creates the outgoing message: the headers of the original message without the
 * technical span headers, plus the tracing headers.
 * @param originalMessage - message that was asked to be sent
 * @param retrievedMessage - message whose payload is used for non-error messages
 * @param additionalHeaders - headers carrying the injected tracing context
 * @return an {@link ErrorMessage} for error messages, otherwise a
 * {@link GenericMessage}
 */
private Message<?> outputMessage(Message<?> originalMessage, Message<?> retrievedMessage,
        MessageHeaderAccessor additionalHeaders) {
    MessageHeaderAccessor accessor = mutableHeaderAccessor(originalMessage);
    clearTechnicalTracingHeaders(accessor);
    boolean errorMessage = originalMessage instanceof ErrorMessage;
    if (errorMessage) {
        // Error messages only receive the propagation fields.
        accessor.copyHeaders(MessageHeaderPropagatorSetter.propagationHeaders(additionalHeaders.getMessageHeaders(),
                this.propagator.fields()));
    }
    else {
        accessor.copyHeaders(additionalHeaders.getMessageHeaders());
    }
    // WebSocket messages keep the accessor's own headers instance, others get a copy.
    MessageHeaders resultHeaders;
    if (isWebSockets(accessor)) {
        resultHeaders = accessor.getMessageHeaders();
    }
    else {
        resultHeaders = new MessageHeaders(accessor.getMessageHeaders());
    }
    if (errorMessage) {
        ErrorMessage original = (ErrorMessage) originalMessage;
        return new ErrorMessage(original.getPayload(), resultHeaders, original.getOriginalMessage());
    }
    return new GenericMessage<>(retrievedMessage.getPayload(), resultHeaders);
}
/**
 * Tells whether the headers contain a "stompCommand" or a "simpMessageType" entry.
 * @param headerAccessor - accessor of the message headers
 * @return {@code true} if either header is present
 */
private boolean isWebSockets(MessageHeaderAccessor headerAccessor) {
    MessageHeaders messageHeaders = headerAccessor.getMessageHeaders();
    if (messageHeaders.containsKey("stompCommand")) {
        return true;
    }
    return messageHeaders.containsKey("simpMessageType");
}
/**
 * Unwraps the failed message from a {@link MessagingException} payload.
 * @param message - message to inspect
 * @return the failed message if the payload is a {@link MessagingException} that
 * carries one, otherwise the given message
 */
private Message<?> getMessage(Message<?> message) {
    Object payload = message.getPayload();
    if (!(payload instanceof MessagingException)) {
        return message;
    }
    Message<?> failed = ((MessagingException) payload).getFailedMessage();
    if (failed == null) {
        return message;
    }
    return failed;
}
/**
 * Returns a mutable header accessor for the message: the message's own accessor when
 * it has one that is still mutable, otherwise a new one that is left mutable.
 * @param message - message to access the headers of
 * @return a mutable header accessor
 */
private MessageHeaderAccessor mutableHeaderAccessor(Message<?> message) {
    MessageHeaderAccessor existing = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
    boolean reusable = existing != null && existing.isMutable();
    if (!reusable) {
        MessageHeaderAccessor fresh = MessageHeaderAccessor.getMutableAccessor(message);
        fresh.setLeaveMutable(true);
        return fresh;
    }
    return existing;
}
/**
 * Removes the propagation fields and both span headers from the given headers.
 * @param headers - accessor of the headers to clean
 */
private void clearTracingHeaders(MessageHeaderAccessor headers) {
    List<String> tracingKeys = new ArrayList<>(this.propagator.fields());
    tracingKeys.addAll(Arrays.asList(Span.class.getName(), TRACE_HANDLER_PARENT_SPAN));
    MessageHeaderPropagatorSetter.removeAnyTraceHeaders(headers, tracingKeys);
}
/**
 * Removes only the two span headers, leaving the propagation fields untouched.
 * @param headers - accessor of the headers to clean
 */
private void clearTechnicalTracingHeaders(MessageHeaderAccessor headers) {
    List<String> technicalKeys = Arrays.asList(Span.class.getName(), TRACE_HANDLER_PARENT_SPAN);
    MessageHeaderPropagatorSetter.removeAnyTraceHeaders(headers, technicalKeys);
}
/**
 * Ends the span, tagging it with "error" first when an error is given. Does nothing
 * for a {@code null} or no-op span.
 * @param span - span to end, may be {@code null}
 * @param error - optional error; its message, or its simple class name when it has no
 * message, becomes the tag value
 */
private void finishSpan(Span span, Throwable error) {
    boolean nothingToFinish = span == null || span.isNoop();
    if (nothingToFinish) {
        return;
    }
    if (error != null) { // an error occurred, adding error to span
        String description = error.getMessage();
        // TODO: Go with span.error(...)
        span.tag("error", description != null ? description : error.getClass().getSimpleName());
    }
    span.end();
}
}
/**
 * Holder of a message and the span that corresponds to it.
 */
class MessageAndSpan {

    /** The message. */
    final Message msg;

    /** The span that corresponds to the message. */
    final Span span;

    MessageAndSpan(Message msg, Span span) {
        this.span = span;
        this.msg = msg;
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("MessageAndSpan{");
        text.append("msg=").append(this.msg);
        text.append(", span=").append(this.span);
        return text.append('}').toString();
    }

}
/**
 * Holder of a message together with its parent span and child span.
 */
class MessageAndSpans {

    /** The message. */
    final Message msg;

    /** The parent span. */
    final Span parentSpan;

    /** The child span. */
    final Span childSpan;

    MessageAndSpans(Message msg, Span parentSpan, Span childSpan) {
        this.childSpan = childSpan;
        this.parentSpan = parentSpan;
        this.msg = msg;
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder("MessageAndSpans{");
        text.append("msg=").append(this.msg);
        text.append(", parentSpan=").append(this.parentSpan);
        text.append(", childSpan=").append(this.childSpan);
        return text.append('}').toString();
    }

}
/**
 * Operation that accepts three input arguments and returns no result.
 *
 * @param <K> type of the first argument
 * @param <V> type of the second argument
 * @param <S> type of the third argument
 */
@FunctionalInterface
interface TriConsumer<K, V, S> {

    /**
     * Performs the operation given the specified arguments.
     * @param k the first input argument
     * @param v the second input argument
     * @param s the third input argument
     */
    void accept(K k, V v, S s);

}