TracingResponderRSocketProxy.java

/*
 * Copyright 2013-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.sleuth.instrument.rsocket;

import java.util.HashSet;
import java.util.Iterator;

import io.netty.buffer.ByteBuf;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.frame.FrameType;
import io.rsocket.metadata.RoutingMetadata;
import io.rsocket.metadata.TracingMetadata;
import io.rsocket.metadata.TracingMetadataCodec;
import io.rsocket.metadata.WellKnownMimeType;
import io.rsocket.util.RSocketProxy;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.ThreadLocalSpan;
import org.springframework.cloud.sleuth.TraceContext;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.docs.AssertingSpanBuilder;
import org.springframework.cloud.sleuth.instrument.reactor.ReactorSleuth;
import org.springframework.cloud.sleuth.internal.EncodingUtils;
import org.springframework.cloud.sleuth.propagation.Propagator;

/**
 * Tracing representation of a {@link RSocketProxy} for the responder.
 *
 * @author Marcin Grzejszczak
 * @author Oleh Dokuka
 * @since 3.1.0
 */
public class TracingResponderRSocketProxy extends RSocketProxy {

	private static final Log log = LogFactory.getLog(TracingResponderRSocketProxy.class);

	private final Propagator propagator;

	private final Propagator.Getter<ByteBuf> getter;

	private final Tracer tracer;

	private final ThreadLocalSpan threadLocalSpan;

	private final boolean isZipkinPropagationEnabled;

	public TracingResponderRSocketProxy(RSocket source, Propagator propagator, Propagator.Getter<ByteBuf> getter,
			Tracer tracer, boolean isZipkinPropagationEnabled) {
		super(source);
		this.propagator = propagator;
		this.getter = getter;
		this.tracer = tracer;
		this.threadLocalSpan = new ThreadLocalSpan(tracer);
		this.isZipkinPropagationEnabled = isZipkinPropagationEnabled;
	}

	@Override
	public Mono<Void> fireAndForget(Payload payload) {
		clearThreadLocal();
		// called on Netty EventLoop
		// there can't be trace context in thread local here
		Span handle = consumerSpanBuilder(payload.sliceMetadata(), FrameType.REQUEST_FNF);
		if (log.isDebugEnabled()) {
			log.debug("Created consumer span " + handle);
		}
		final Payload newPayload = PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields()));
		return ReactorSleuth.tracedMono(this.tracer, handle, () -> super.fireAndForget(newPayload));
	}

	private void clearThreadLocal() {
		this.tracer.withSpan(null);
	}

	@Override
	public Mono<Payload> requestResponse(Payload payload) {
		clearThreadLocal();
		Span handle = consumerSpanBuilder(payload.sliceMetadata(), FrameType.REQUEST_RESPONSE);
		if (log.isDebugEnabled()) {
			log.debug("Created consumer span " + handle);
		}
		final Payload newPayload = PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields()));
		return ReactorSleuth.tracedMono(this.tracer, handle, () -> super.requestResponse(newPayload));
	}

	@Override
	public Flux<Payload> requestStream(Payload payload) {
		clearThreadLocal();
		Span handle = consumerSpanBuilder(payload.sliceMetadata(), FrameType.REQUEST_STREAM);
		if (log.isDebugEnabled()) {
			log.debug("Created consumer span " + handle);
		}
		final Payload newPayload = PayloadUtils.cleanTracingMetadata(payload, new HashSet<>(propagator.fields()));
		return ReactorSleuth.tracedFlux(this.tracer, handle, () -> super.requestStream(newPayload));
	}

	@Override
	public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
		clearThreadLocal();
		return Flux.from(payloads).switchOnFirst((firstSignal, flux) -> {
			final Payload firstPayload = firstSignal.get();
			if (firstPayload != null) {
				Span handle = consumerSpanBuilder(firstPayload.sliceMetadata(), FrameType.REQUEST_CHANNEL);
				if (handle == null) {
					return super.requestChannel(flux);
				}
				if (log.isDebugEnabled()) {
					log.debug("Created consumer span " + handle);
				}
				final Payload newPayload = PayloadUtils.cleanTracingMetadata(firstPayload,
						new HashSet<>(propagator.fields()));
				return ReactorSleuth.tracedFlux(this.tracer, handle,
						() -> super.requestChannel(flux.skip(1).startWith(newPayload)));
			}
			return flux;
		});
	}

	private Span consumerSpanBuilder(ByteBuf headers, FrameType requestType) {
		Span.Builder consumerSpanBuilder = consumerSpanBuilder(headers);
		if (log.isDebugEnabled()) {
			log.debug("Extracted result from headers " + consumerSpanBuilder);
		}
		final ByteBuf extract = CompositeMetadataUtils.extract(headers,
				WellKnownMimeType.MESSAGE_RSOCKET_ROUTING.getString());
		String name = "handle";
		if (extract != null) {
			final RoutingMetadata routingMetadata = new RoutingMetadata(extract);
			final Iterator<String> iterator = routingMetadata.iterator();
			name = requestType.name() + " " + iterator.next();
		}
		return AssertingSpanBuilder
				.of(SleuthRSocketSpan.RSOCKET_RESPONDER_SPAN, consumerSpanBuilder.kind(Span.Kind.CONSUMER)).name(name)
				.start();
	}

	private Span.Builder consumerSpanBuilder(ByteBuf headers) {
		if (this.isZipkinPropagationEnabled) {
			ByteBuf extract = CompositeMetadataUtils.extract(headers,
					WellKnownMimeType.MESSAGE_RSOCKET_TRACING_ZIPKIN.getString());
			if (extract != null) {
				TracingMetadata tracingMetadata = TracingMetadataCodec.decode(extract);
				Span.Builder builder = this.tracer.spanBuilder();
				String traceId = EncodingUtils.fromLong(tracingMetadata.traceId());
				long traceIdHigh = tracingMetadata.traceIdHigh();
				if (traceIdHigh != 0L) {
					// ExtendedTraceId
					traceId = EncodingUtils.fromLong(traceIdHigh) + traceId;
				}
				TraceContext.Builder parentBuilder = this.tracer.traceContextBuilder()
						.sampled(tracingMetadata.isDebug() || tracingMetadata.isSampled()).traceId(traceId)
						.spanId(EncodingUtils.fromLong(tracingMetadata.spanId()))
						.parentId(EncodingUtils.fromLong(tracingMetadata.parentId()));
				return builder.setParent(parentBuilder.build());
			}
			else {
				return this.propagator.extract(headers, this.getter);
			}
		}
		return this.propagator.extract(headers, this.getter);
	}

}