TracingRequesterRSocketProxy.java
/*
* Copyright 2013-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.cloud.sleuth.instrument.rsocket;
import java.util.HashSet;
import java.util.Iterator;
import java.util.function.Function;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.frame.FrameType;
import io.rsocket.metadata.CompositeMetadataCodec;
import io.rsocket.metadata.RoutingMetadata;
import io.rsocket.metadata.TracingMetadataCodec;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.util.RSocketProxy;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;
import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.TraceContext;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.docs.AssertingSpanBuilder;
import org.springframework.cloud.sleuth.internal.EncodingUtils;
import org.springframework.cloud.sleuth.propagation.Propagator;
/**
* Tracing representation of a {@link RSocketProxy} for the requester.
*
* @author Marcin Grzejszczak
* @author Oleh Dokuka
* @since 3.1.0
*/
public class TracingRequesterRSocketProxy extends RSocketProxy {

	private static final Log log = LogFactory.getLog(TracingRequesterRSocketProxy.class);

	private final Propagator propagator;

	private final Propagator.Setter<CompositeByteBuf> setter;

	private final Tracer tracer;

	private final boolean isZipkinPropagationEnabled;

	/**
	 * Creates a tracing proxy around the given requester.
	 * @param source RSocket that the interactions are delegated to
	 * @param propagator propagator used to inject the trace context into the outgoing
	 * metadata
	 * @param setter setter handed to the propagator to write into the composite metadata
	 * @param tracer tracer used to build the requester spans
	 * @param isZipkinPropagationEnabled when {@code true} the
	 * {@link WellKnownMimeType#MESSAGE_RSOCKET_TRACING_ZIPKIN} metadata entry is added in
	 * addition to whatever the propagator injects
	 */
	public TracingRequesterRSocketProxy(RSocket source, Propagator propagator,
			Propagator.Setter<CompositeByteBuf> setter, Tracer tracer, boolean isZipkinPropagationEnabled) {
		super(source);
		this.propagator = propagator;
		this.setter = setter;
		this.tracer = tracer;
		this.isZipkinPropagationEnabled = isZipkinPropagationEnabled;
	}

	/**
	 * Calls {@link Tracer#withSpan} with {@code null}; invoked at the start of every
	 * interaction, before the reactive pipeline is assembled.
	 */
	private void clearThreadLocal() {
		this.tracer.withSpan(null);
	}

	@Override
	public Mono<Void> fireAndForget(Payload payload) {
		clearThreadLocal();
		return setSpan(super::fireAndForget, payload, FrameType.REQUEST_FNF);
	}

	@Override
	public Mono<Payload> requestResponse(Payload payload) {
		clearThreadLocal();
		return setSpan(super::requestResponse, payload, FrameType.REQUEST_RESPONSE);
	}

	/**
	 * Wraps a single-response interaction in a producer span. The span is started on
	 * subscription, the trace context is injected into the payload metadata, and the span
	 * is ended when the returned {@link Mono} terminates or is cancelled.
	 * <p>
	 * The span is named {@code "<frame type> <route>"}. When the payload carries no
	 * routing metadata the span is named after the frame type only and the route tag is
	 * omitted, instead of failing the request.
	 * @param input the delegate interaction, applied to the payload carrying the tracing
	 * metadata
	 * @param payload original payload
	 * @param frameType frame type of the interaction; used for the span name and the
	 * request type tag
	 * @param <T> type of the response
	 * @return the traced interaction
	 */
	<T> Mono<T> setSpan(Function<Payload, Mono<T>> input, Payload payload, FrameType frameType) {
		return Mono.deferContextual(contextView -> {
			String requestType = frameType.name();
			String route = extractRoute(payload);
			AssertingSpanBuilder builder = AssertingSpanBuilder
					.of(SleuthRSocketSpan.RSOCKET_REQUESTER_SPAN,
							spanBuilder(contextView).kind(Span.Kind.PRODUCER))
					.name(route != null ? requestType + " " + route : requestType);
			if (route != null) {
				builder = builder.tag(SleuthRSocketSpan.Tags.ROUTE, route);
			}
			Span span = builder.tag(SleuthRSocketSpan.Tags.REQUEST_TYPE, requestType).start();
			if (log.isDebugEnabled()) {
				log.debug("Extracted result from context or thread local " + span);
			}
			Payload tracedPayload = withTracingMetadata(payload, span.context());
			return input.apply(tracedPayload).doOnError(span::error).doFinally(signalType -> span.end());
		});
	}

	/**
	 * Adds the {@link WellKnownMimeType#MESSAGE_RSOCKET_TRACING_ZIPKIN} entry, encoded
	 * with {@link TracingMetadataCodec}, to the given composite metadata.
	 * @param metadata composite metadata to append the entry to
	 * @param traceContext trace context whose ids and sampling decision are encoded
	 */
	void injectDefaultZipkinRSocketHeaders(CompositeByteBuf metadata, TraceContext traceContext) {
		Boolean sampled = traceContext.sampled();
		TracingMetadataCodec.Flags flags;
		if (sampled == null) {
			flags = TracingMetadataCodec.Flags.UNDECIDED;
		}
		else {
			flags = sampled ? TracingMetadataCodec.Flags.SAMPLE : TracingMetadataCodec.Flags.NOT_SAMPLE;
		}
		long[] traceIds = EncodingUtils.fromString(traceContext.traceId());
		long[] spanId = EncodingUtils.fromString(traceContext.spanId());
		// NOTE(review): if parentId() can be null (e.g. for a span without a parent),
		// confirm that EncodingUtils.fromString tolerates it.
		long[] parentSpanId = EncodingUtils.fromString(traceContext.parentId());
		// Two longs mean the trace id did not fit into 64 bits.
		boolean isTraceId128Bit = traceIds.length == 2;
		final ByteBufAllocator allocator = metadata.alloc();
		ByteBuf encoded = isTraceId128Bit
				? TracingMetadataCodec.encode128(allocator, traceIds[0], traceIds[1], spanId[0], parentSpanId[0],
						flags)
				: TracingMetadataCodec.encode64(allocator, traceIds[0], spanId[0], parentSpanId[0], flags);
		CompositeMetadataCodec.encodeAndAddMetadata(metadata, allocator,
				WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN, encoded);
	}

	/**
	 * Creates a span builder whose parent is taken from the Reactor context when it holds
	 * a {@link TraceContext}, otherwise from the tracer's current span, if any.
	 * @param contextView Reactor context of the subscriber
	 * @return span builder, with a parent set when one could be found
	 */
	Span.Builder spanBuilder(ContextView contextView) {
		Span.Builder spanBuilder = this.tracer.spanBuilder();
		if (contextView.hasKey(TraceContext.class)) {
			return spanBuilder.setParent(contextView.get(TraceContext.class));
		}
		Span currentSpan = this.tracer.currentSpan();
		if (currentSpan != null) {
			return spanBuilder.setParent(currentSpan.context());
		}
		return spanBuilder;
	}

	@Override
	public Flux<Payload> requestStream(Payload payload) {
		clearThreadLocal();
		return Flux.deferContextual(contextView -> setSpan(super::requestStream, payload, contextView));
	}

	@Override
	public Flux<Payload> requestChannel(Publisher<Payload> inbound) {
		clearThreadLocal();
		return Flux.from(inbound).switchOnFirst((firstSignal, flux) -> {
			final Payload firstPayload = firstSignal.get();
			if (firstPayload != null) {
				// The first payload is swapped for the one carrying the tracing metadata;
				// the remaining payloads are passed through untouched.
				return setSpan(p -> super.requestChannel(flux.skip(1).startWith(p)), firstPayload,
						firstSignal.getContextView());
			}
			// The first signal carried no payload, so there is nothing to instrument.
			return flux;
		});
	}

	/**
	 * Wraps a multi-response interaction in a producer span named after the route. The
	 * trace context is injected into the payload metadata (including the Zipkin entry
	 * when enabled, exactly as for the single-response interactions) and the span is
	 * ended when the returned {@link Flux} terminates or is cancelled.
	 * <p>
	 * When the payload carries no routing metadata the span is started without an
	 * explicit name instead of failing the request.
	 * @param input the delegate interaction, applied to the payload carrying the tracing
	 * metadata
	 * @param payload original (first) payload
	 * @param contextView Reactor context used to look up the parent trace context
	 * @param <T> unused; kept so that the signature stays unchanged
	 * @return the traced interaction
	 */
	<T> Flux<Payload> setSpan(Function<Payload, Flux<Payload>> input, Payload payload, ContextView contextView) {
		String route = extractRoute(payload);
		AssertingSpanBuilder builder = AssertingSpanBuilder.of(SleuthRSocketSpan.RSOCKET_REQUESTER_SPAN,
				spanBuilder(contextView).kind(Span.Kind.PRODUCER));
		if (route != null) {
			builder = builder.name(route);
		}
		Span span = builder.start();
		if (log.isDebugEnabled()) {
			log.debug("Extracted result from context or thread local " + span);
		}
		Payload tracedPayload = withTracingMetadata(payload, span.context());
		return input.apply(tracedPayload).doOnError(span::error).doFinally(signalType -> span.end());
	}

	/**
	 * Reads the first route from the payload's routing metadata.
	 * @param payload payload to inspect
	 * @return the first route, or {@code null} when the payload has no routing metadata
	 * entry or the entry holds no route
	 */
	private static String extractRoute(Payload payload) {
		ByteBuf extracted = CompositeMetadataUtils.extract(payload.sliceMetadata(),
				WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString());
		if (extracted == null) {
			return null;
		}
		Iterator<String> routes = new RoutingMetadata(extracted).iterator();
		return routes.hasNext() ? routes.next() : null;
	}

	/**
	 * Produces the payload to send: tracing entries matching the propagator's fields are
	 * cleaned from the original payload, then the given trace context is injected.
	 * @param payload original payload
	 * @param traceContext context of the requester span
	 * @return payload carrying the tracing metadata of the requester span
	 */
	private Payload withTracingMetadata(Payload payload, TraceContext traceContext) {
		Payload newPayload = PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(this.propagator.fields()));
		CompositeByteBuf metadata = (CompositeByteBuf) newPayload.metadata();
		if (this.isZipkinPropagationEnabled) {
			injectDefaultZipkinRSocketHeaders(metadata, traceContext);
		}
		this.propagator.inject(traceContext, metadata, this.setter);
		return newPayload;
	}

}