W3CPropagation.java

/*
 * Copyright 2013-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.sleuth.brave.bridge;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

import brave.baggage.BaggageField;
import brave.baggage.BaggagePropagation;
import brave.baggage.BaggagePropagationConfig;
import brave.internal.baggage.BaggageFields;
import brave.internal.propagation.StringPropagationAdapter;
import brave.propagation.Propagation;
import brave.propagation.TraceContext;
import brave.propagation.TraceContextOrSamplingFlags;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import org.springframework.cloud.sleuth.BaggageInScope;
import org.springframework.cloud.sleuth.internal.EncodingUtils;
import org.springframework.util.StringUtils;

import static java.util.Collections.singletonList;

/**
 * Adopted from OpenTelemetry API.
 *
 * Implementation of the TraceContext propagation protocol. See <a
 * href=https://github.com/w3c/distributed-tracing>w3c/distributed-tracing</a>.
 *
 * @author OpenTelemetry Authors
 * @author Marcin Grzejszczak
 * @since 3.0.0
 */
class W3CPropagation extends Propagation.Factory implements Propagation<String> {

	private static final Log logger = LogFactory.getLog(W3CPropagation.class.getName());

	static final String TRACE_PARENT = "traceparent";

	static final String TRACE_STATE = "tracestate";

	private static final List<String> FIELDS = Collections.unmodifiableList(Arrays.asList(TRACE_PARENT, TRACE_STATE));

	private static final String VERSION = "00";

	private static final int VERSION_SIZE = 2;

	private static final char TRACEPARENT_DELIMITER = '-';

	private static final int TRACEPARENT_DELIMITER_SIZE = 1;

	private static final int LONG_BYTES = Long.SIZE / Byte.SIZE;

	private static final int BYTE_BASE16 = 2;

	private static final int LONG_BASE16 = BYTE_BASE16 * LONG_BYTES;

	private static final int TRACE_ID_HEX_SIZE = 2 * LONG_BASE16;

	private static final int SPAN_ID_SIZE = 8;

	private static final int SPAN_ID_HEX_SIZE = 2 * SPAN_ID_SIZE;

	private static final int FLAGS_SIZE = 1;

	private static final int TRACE_OPTION_HEX_SIZE = 2 * FLAGS_SIZE;

	private static final int TRACE_ID_OFFSET = VERSION_SIZE + TRACEPARENT_DELIMITER_SIZE;

	private static final int SPAN_ID_OFFSET = TRACE_ID_OFFSET + TRACE_ID_HEX_SIZE + TRACEPARENT_DELIMITER_SIZE;

	private static final int TRACE_OPTION_OFFSET = SPAN_ID_OFFSET + SPAN_ID_HEX_SIZE + TRACEPARENT_DELIMITER_SIZE;

	private static final int TRACEPARENT_HEADER_SIZE = TRACE_OPTION_OFFSET + TRACE_OPTION_HEX_SIZE;

	private static final String INVALID_TRACE_ID = "00000000000000000000000000000000";

	private static final String INVALID_SPAN_ID = "0000000000000000";

	// private static final char TRACESTATE_ENTRY_DELIMITER = ',';

	private static final Set<String> VALID_VERSIONS;

	private static final String VERSION_00 = "00";

	static {
		// A valid version is 1 byte representing an 8-bit unsigned integer, version ff is
		// invalid.
		VALID_VERSIONS = new HashSet<>();
		for (int i = 0; i < 255; i++) {
			String version = Long.toHexString(i);
			if (version.length() < 2) {
				version = '0' + version;
			}
			VALID_VERSIONS.add(version);
		}
	}

	private final W3CBaggagePropagator baggagePropagator;

	private final BraveBaggageManager braveBaggageManager;

	W3CPropagation(BraveBaggageManager braveBaggageManager, List<String> localFields) {
		this.baggagePropagator = new W3CBaggagePropagator(braveBaggageManager, localFields);
		this.braveBaggageManager = braveBaggageManager;
	}

	@Override
	public <K> Propagation<K> create(KeyFactory<K> keyFactory) {
		return StringPropagationAdapter.create(this, keyFactory);
	}

	@Override
	public List<String> keys() {
		return FIELDS;
	}

	@Override
	public <R> TraceContext.Injector<R> injector(Setter<R, String> setter) {
		return (context, carrier) -> {
			Objects.requireNonNull(context, "context");
			Objects.requireNonNull(setter, "setter");
			char[] chars = TemporaryBuffers.chars(TRACEPARENT_HEADER_SIZE);
			chars[0] = VERSION.charAt(0);
			chars[1] = VERSION.charAt(1);
			chars[2] = TRACEPARENT_DELIMITER;
			String traceId = padLeftWithZeros(context.traceIdString(), TRACE_ID_HEX_SIZE);
			for (int i = 0; i < traceId.length(); i++) {
				chars[TRACE_ID_OFFSET + i] = traceId.charAt(i);
			}
			chars[SPAN_ID_OFFSET - 1] = TRACEPARENT_DELIMITER;
			String spanId = context.spanIdString();
			for (int i = 0; i < spanId.length(); i++) {
				chars[SPAN_ID_OFFSET + i] = spanId.charAt(i);
			}
			chars[TRACE_OPTION_OFFSET - 1] = TRACEPARENT_DELIMITER;
			copyTraceFlagsHexTo(chars, TRACE_OPTION_OFFSET, context);
			setter.put(carrier, TRACE_PARENT, new String(chars, 0, TRACEPARENT_HEADER_SIZE));
			addTraceState(setter, context, carrier);
			this.baggagePropagator.injector(setter).inject(context, carrier);
		};
	}

	private <R> void addTraceState(Setter<R, String> setter, TraceContext context, R carrier) {
		if (carrier != null) {
			BaggageInScope baggage = this.braveBaggageManager.getBaggage(BraveTraceContext.fromBrave(context),
					TRACE_STATE);
			if (baggage == null) {
				return;
			}
			String traceState = baggage.get(BraveTraceContext.fromBrave(context));
			if (StringUtils.hasText(traceState)) {
				setter.put(carrier, TRACE_STATE, traceState);
			}
		}
	}

	private String padLeftWithZeros(String string, int length) {
		if (string.length() >= length) {
			return string;
		}
		else {
			StringBuilder sb = new StringBuilder(length);
			for (int i = string.length(); i < length; i++) {
				sb.append('0');
			}

			return sb.append(string).toString();
		}
	}

	void copyTraceFlagsHexTo(char[] dest, int destOffset, TraceContext context) {
		dest[destOffset] = '0';
		dest[destOffset + 1] = Boolean.TRUE.equals(context.sampled()) ? '1' : '0';
	}

	@Override
	public <R> TraceContext.Extractor<R> extractor(Getter<R, String> getter) {
		Objects.requireNonNull(getter, "getter");
		return carrier -> {
			String traceParent = getter.get(carrier, TRACE_PARENT);
			if (traceParent == null) {
				return withBaggage(TraceContextOrSamplingFlags.EMPTY, carrier, getter);
			}
			TraceContext contextFromParentHeader = extractContextFromTraceParent(traceParent);
			if (contextFromParentHeader == null) {
				return withBaggage(TraceContextOrSamplingFlags.EMPTY, carrier, getter);
			}
			String traceStateHeader = getter.get(carrier, TRACE_STATE);
			return withBaggage(context(contextFromParentHeader, traceStateHeader), carrier, getter);
		};
	}

	private <R> TraceContextOrSamplingFlags withBaggage(TraceContextOrSamplingFlags context, R carrier,
			Getter<R, String> getter) {
		if (context.context() == null) {
			return context;
		}
		return this.baggagePropagator.contextWithBaggage(carrier, context, getter);
	}

	TraceContextOrSamplingFlags context(TraceContext contextFromParentHeader, String traceStateHeader) {
		if (!StringUtils.hasText(traceStateHeader)) {
			return TraceContextOrSamplingFlags.create(contextFromParentHeader);
		}
		try {
			return TraceContextOrSamplingFlags
					.newBuilder(TraceContext.newBuilder().traceId(contextFromParentHeader.traceId())
							.traceIdHigh(contextFromParentHeader.traceIdHigh()).spanId(contextFromParentHeader.spanId())
							.sampled(contextFromParentHeader.sampled()).shared(true).build())
					.build();
		}
		catch (IllegalArgumentException e) {
			// logger.info("Unparseable tracestate header. Returning span context without
			// state.");
			return TraceContextOrSamplingFlags.create(contextFromParentHeader);
		}
	}

	private static boolean isTraceIdValid(CharSequence traceId) {
		return (traceId.length() == TRACE_ID_HEX_SIZE) && !INVALID_TRACE_ID.contentEquals(traceId)
				&& EncodingUtils.isValidBase16String(traceId);
	}

	private static boolean isSpanIdValid(String spanId) {
		return (spanId.length() == SPAN_ID_HEX_SIZE) && !INVALID_SPAN_ID.equals(spanId)
				&& EncodingUtils.isValidBase16String(spanId);
	}

	private static TraceContext extractContextFromTraceParent(String traceparent) {
		// TODO(bdrutu): Do we need to verify that version is hex and that
		// for the version the length is the expected one?
		boolean isValid = (traceparent.length() == TRACEPARENT_HEADER_SIZE
				|| (traceparent.length() > TRACEPARENT_HEADER_SIZE
						&& traceparent.charAt(TRACEPARENT_HEADER_SIZE) == TRACEPARENT_DELIMITER))
				&& traceparent.charAt(TRACE_ID_OFFSET - 1) == TRACEPARENT_DELIMITER
				&& traceparent.charAt(SPAN_ID_OFFSET - 1) == TRACEPARENT_DELIMITER
				&& traceparent.charAt(TRACE_OPTION_OFFSET - 1) == TRACEPARENT_DELIMITER;
		if (!isValid) {
			// logger.info("Unparseable traceparent header. Returning INVALID span
			// context.");
			return null;
		}

		try {
			String version = traceparent.substring(0, 2);
			if (!VALID_VERSIONS.contains(version)) {
				return null;
			}
			if (version.equals(VERSION_00) && traceparent.length() > TRACEPARENT_HEADER_SIZE) {
				return null;
			}

			String traceId = traceparent.substring(TRACE_ID_OFFSET, TRACE_ID_OFFSET + TRACE_ID_HEX_SIZE);
			String spanId = traceparent.substring(SPAN_ID_OFFSET, SPAN_ID_OFFSET + SPAN_ID_HEX_SIZE);
			if (isTraceIdValid(traceId) && isSpanIdValid(spanId)) {
				String traceIdHigh = traceId.substring(0, traceId.length() / 2);
				String traceIdLow = traceId.substring(traceId.length() / 2);
				byte isSampled = TraceFlags.byteFromHex(traceparent, TRACE_OPTION_OFFSET);
				return TraceContext.newBuilder().shared(true)
						.traceIdHigh(EncodingUtils.longFromBase16String(traceIdHigh))
						.traceId(EncodingUtils.longFromBase16String(traceIdLow))
						.spanId(EncodingUtils.longFromBase16String(spanId)).sampled(isSampled == TraceFlags.IS_SAMPLED)
						.build();
			}
			return null;
		}
		catch (IllegalArgumentException e) {
			// logger.info("Unparseable traceparent header. Returning INVALID span
			// context.");
			return null;
		}
	}

}

/**
 * Taken from OpenTelemetry API.
 * <p>
 * Injects and extracts the W3C {@code baggage} header for Brave contexts. Fields listed
 * in {@code localFields} are never written to the header. The {@code tracestate} entry
 * is also left out of the {@code baggage} header, because it is sent in its own header.
 */
class W3CBaggagePropagator {

	private static final Log log = LogFactory.getLog(W3CBaggagePropagator.class);

	private static final String TRACE_STATE = "tracestate";

	private static final BaggageField TRACE_STATE_BAGGAGE = BaggageField.create(TRACE_STATE);

	private static final String FIELD = "baggage";

	private static final List<String> FIELDS = singletonList(FIELD);

	private final BraveBaggageManager braveBaggageManager;

	private final List<String> localFields;

	/**
	 * @param braveBaggageManager creates the baggage entries parsed from the header
	 * @param localFields baggage field names that must never be propagated
	 */
	W3CBaggagePropagator(BraveBaggageManager braveBaggageManager, List<String> localFields) {
		this.braveBaggageManager = braveBaggageManager;
		this.localFields = localFields;
	}

	/**
	 * Starts a baggage factory builder around a delegate that returns no propagation of its
	 * own. Only the resulting factory's {@code decorate} is used, in
	 * {@link #contextWithBaggage}.
	 */
	private BaggagePropagation.FactoryBuilder factory() {
		return BaggagePropagation.newFactoryBuilder(new Propagation.Factory() {
			@Override
			public <K> Propagation<K> create(Propagation.KeyFactory<K> keyFactory) {
				return null;
			}
		});
	}

	public List<String> keys() {
		return FIELDS;
	}

	/**
	 * Returns an injector that writes the non-local baggage fields of the context as a
	 * comma-separated {@code key=value} list into the {@code baggage} header. Nothing is
	 * written when there is no baggage to propagate.
	 */
	public <R> TraceContext.Injector<R> injector(Propagation.Setter<R, String> setter) {
		return (context, carrier) -> {
			BaggageFields extra = context.findExtra(BaggageFields.class);
			if (extra == null || extra.getAllFields().isEmpty()) {
				return;
			}
			// We ignore local keys - they won't get propagated
			Map<String, String> filtered = extra.toMapFilteringFieldNames(this.localFields.toArray(new String[0]));
			StringBuilder headerContent = new StringBuilder();
			for (Map.Entry<String, String> entry : filtered.entrySet()) {
				if (TRACE_STATE.equalsIgnoreCase(entry.getKey())) {
					continue;
				}
				if (headerContent.length() > 0) {
					headerContent.append(',');
				}
				// TODO: [OTEL] No metadata support - entries are written as key=value only
				headerContent.append(entry.getKey()).append('=').append(entry.getValue());
			}
			if (headerContent.length() > 0) {
				setter.put(carrier, FIELD, headerContent.toString());
			}
		};
	}

	/**
	 * Decorates the extracted context with remote baggage fields. These are the
	 * {@code tracestate} header (stored as a baggage entry) when present, and every entry
	 * of the {@code baggage} header.
	 * @param carrier the request carrier
	 * @param flags the extracted trace context; its {@code context()} is expected to be
	 * non-null
	 * @param getter reads headers from the carrier
	 * @return a new result wrapping the decorated context
	 */
	<R> TraceContextOrSamplingFlags contextWithBaggage(R carrier, TraceContextOrSamplingFlags flags,
			Propagation.Getter<R, String> getter) {
		BaggagePropagation.FactoryBuilder factoryBuilder = factory();
		String traceState = getter.get(carrier, TRACE_STATE);
		boolean hasTraceState = StringUtils.hasText(traceState);
		if (hasTraceState) {
			factoryBuilder = factoryBuilder
					.add(BaggagePropagationConfig.SingleBaggageField.remote(TRACE_STATE_BAGGAGE));
		}
		String baggageHeader = getter.get(carrier, FIELD);
		List<AbstractMap.SimpleEntry<BaggageInScope, String>> pairs = baggageHeader == null || baggageHeader.isEmpty()
				? Collections.emptyList() : addBaggageToContext(baggageHeader);
		// Each field must be registered before decoration so it can hold a value.
		Set<String> names = pairs.stream().map(e -> e.getKey().name()).collect(Collectors.toSet());
		for (String name : names) {
			factoryBuilder = factoryBuilder.add(BaggagePropagationConfig.SingleBaggageField
					.remote(((BraveBaggageInScope) this.braveBaggageManager.createBaggage(name)).unwrap()));
		}
		TraceContext decoratedContext = factoryBuilder.build().decorate(flags.context());
		if (hasTraceState) {
			BaggageInScope baggageInScope = this.braveBaggageManager.createBaggage(TRACE_STATE);
			baggageInScope.set(new BraveTraceContext(decoratedContext), traceState);
		}
		pairs.forEach(e -> {
			BaggageField baggage = ((BraveBaggageInScope) e.getKey()).unwrap();
			baggage.updateValue(decoratedContext, e.getValue());
		});
		return TraceContextOrSamplingFlags.create(decoratedContext);
	}

	/**
	 * Parses a W3C {@code baggage} header into (baggage, value) pairs.
	 * <p>
	 * Entries are separated by ',', and any {@code ;metadata} suffix is dropped. Each entry
	 * is split on its FIRST '=' only, because values may themselves contain '='. Keys and
	 * values are trimmed. Entries without '=' or with an empty key are skipped, as are
	 * entries whose baggage cannot be created.
	 * @param baggageHeader raw header value
	 * @return parsed pairs in header order; never null
	 */
	List<AbstractMap.SimpleEntry<BaggageInScope, String>> addBaggageToContext(String baggageHeader) {
		List<AbstractMap.SimpleEntry<BaggageInScope, String>> pairs = new ArrayList<>();
		for (String entry : baggageHeader.split(",")) {
			int beginningOfMetadata = entry.indexOf(';');
			String keyAndValue = beginningOfMetadata >= 0 ? entry.substring(0, beginningOfMetadata) : entry;
			int separator = keyAndValue.indexOf('=');
			String key = separator < 0 ? "" : keyAndValue.substring(0, separator).trim();
			if (key.isEmpty()) {
				if (log.isDebugEnabled()) {
					log.debug("Baggage entry [" + entry + "] has no key. Will ignore that entry.");
				}
				continue;
			}
			String value = keyAndValue.substring(separator + 1).trim();
			try {
				BaggageInScope baggage = this.braveBaggageManager.createBaggage(key);
				pairs.add(new AbstractMap.SimpleEntry<>(baggage, value));
			}
			catch (Exception e) {
				if (log.isDebugEnabled()) {
					log.debug("Exception occurred while trying to parse baggage with key value [" + keyAndValue
							+ "]. Will ignore that entry.", e);
				}
			}
		}
		return pairs;
	}

}

/**
 * Taken from OpenTelemetry API.
 *
 * Per-thread scratch buffers for building derived objects such as {@link String}s.
 * <p>
 * Each thread reuses one {@code char[]}. It is therefore _not safe_ to build two derived
 * objects from the buffer at the same time, since both would write to the same memory.
 * Obtain the buffer, fill it, and convert it to its final form within a single method.
 */
final class TemporaryBuffers {

	private static final ThreadLocal<char[]> CHAR_ARRAY = new ThreadLocal<>();

	private TemporaryBuffers() {
	}

	/**
	 * Returns this thread's {@code char[]}, holding at least {@code len} chars.
	 * <p>
	 * The cached array is reused when it is already big enough. Otherwise a new array of
	 * exactly {@code len} chars replaces it and stays cached for the thread's lifetime, so
	 * avoid very large sizes. The contents are not cleared and the array may be longer
	 * than requested: overwrite every slot you need, and pass an explicit length when
	 * building a {@link String}.
	 */
	public static char[] chars(int len) {
		char[] cached = CHAR_ARRAY.get();
		if (cached != null && cached.length >= len) {
			return cached;
		}
		char[] fresh = new char[len];
		CHAR_ARRAY.set(fresh);
		return fresh;
	}

	// Visible for testing: drops this thread's cached buffer.
	static void clearChars() {
		CHAR_ARRAY.remove();
	}

}

/**
 * Taken from OpenTelemetry API.
 *
 * Helpers for the single-byte trace-flags field of a {@code traceparent} header.
 */
final class TraceFlags {

	/** Mask of the "sampled" flag: the lowest bit of the trace-flags byte. */
	static final byte IS_SAMPLED = 1;

	private TraceFlags() {
	}

	/**
	 * Decodes the hex-encoded trace-flags byte found in {@code src} at {@code srcOffset}.
	 * Decoding is delegated to {@code EncodingUtils.byteFromBase16String}.
	 */
	static byte byteFromHex(CharSequence src, int srcOffset) {
		byte decoded = EncodingUtils.byteFromBase16String(src, srcOffset);
		return decoded;
	}

}