CompositePropagationFactorySupplier.java

/*
 * Copyright 2013-2021 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.sleuth.brave.bridge;

import java.util.AbstractMap;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import brave.internal.propagation.StringPropagationAdapter;
import brave.propagation.B3Propagation;
import brave.propagation.Propagation;
import brave.propagation.TraceContext;
import brave.propagation.TraceContextOrSamplingFlags;
import brave.propagation.aws.AWSPropagation;

import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.cloud.sleuth.brave.propagation.PropagationFactorySupplier;
import org.springframework.cloud.sleuth.brave.propagation.PropagationType;

/**
 * Merges various propagation factories into a composite.
 *
 * @author Marcin Grzejszczak
 * @since 3.0.0
 */
public class CompositePropagationFactorySupplier implements PropagationFactorySupplier {

	/** Used to look up the baggage manager and any custom propagation supplier. */
	private final BeanFactory beanFactory;

	/** Local field names handed to the W3C propagation. */
	private final List<String> localFields;

	/** Propagation types to combine, in the order they are consulted. */
	private final List<PropagationType> types;

	/**
	 * Creates a supplier of a composite propagation factory.
	 * @param beanFactory bean factory used to look up a {@link BraveBaggageManager} and a
	 * custom {@link PropagationFactorySupplier}
	 * @param localFields local field names passed to the W3C propagation
	 * @param types propagation types to combine, in the order they are consulted
	 */
	public CompositePropagationFactorySupplier(BeanFactory beanFactory, List<String> localFields,
			List<PropagationType> types) {
		this.types = types;
		this.localFields = localFields;
		this.beanFactory = beanFactory;
	}

	/**
	 * Builds a new composite factory. When no {@link BraveBaggageManager} bean is
	 * available, a fresh one is created with its no-arg constructor.
	 * @return a new composite propagation factory
	 */
	@Override
	public Propagation.Factory get() {
		ObjectProvider<BraveBaggageManager> baggageManagerProvider = this.beanFactory
				.getBeanProvider(BraveBaggageManager.class);
		BraveBaggageManager baggageManager = baggageManagerProvider.getIfAvailable(BraveBaggageManager::new);
		return new CompositePropagationFactory(this.beanFactory, baggageManager, this.localFields, this.types);
	}

}

class CompositePropagationFactory extends Propagation.Factory implements Propagation<String> {

	private final Map<PropagationType, Map.Entry<Propagation.Factory, Propagation<String>>> mapping = new HashMap<>();

	private final List<PropagationType> types;

	CompositePropagationFactory(BeanFactory beanFactory, BraveBaggageManager braveBaggageManager,
			List<String> localFields, List<PropagationType> types) {
		this.types = types;
		this.mapping.put(PropagationType.AWS,
				new AbstractMap.SimpleEntry<>(AWSPropagation.FACTORY, AWSPropagation.FACTORY.get()));
		// Note: Versions <2.2.3 use injectFormat(MULTI) for non-remote (ex
		// spring-messaging)
		// See #1643
		Factory b3Factory = b3Factory();
		this.mapping.put(PropagationType.B3, new AbstractMap.SimpleEntry<>(b3Factory, b3Factory.get()));
		W3CPropagation w3CPropagation = new W3CPropagation(braveBaggageManager, localFields);
		this.mapping.put(PropagationType.W3C, new AbstractMap.SimpleEntry<>(w3CPropagation, w3CPropagation.get()));
		LazyPropagationFactory lazyPropagationFactory = new LazyPropagationFactory(
				beanFactory.getBeanProvider(PropagationFactorySupplier.class));
		this.mapping.put(PropagationType.CUSTOM,
				new AbstractMap.SimpleEntry<>(lazyPropagationFactory, lazyPropagationFactory.get()));
	}

	private Factory b3Factory() {
		return B3Propagation.newFactoryBuilder().injectFormat(B3Propagation.Format.SINGLE_NO_PARENT).build();
	}

	@Override
	public List<String> keys() {
		return this.types.stream().map(this.mapping::get).flatMap(p -> p.getValue().keys().stream())
				.collect(Collectors.toList());
	}

	@Override
	public <R> TraceContext.Injector<R> injector(Setter<R, String> setter) {
		return (traceContext, request) -> {
			this.types.stream().map(this.mapping::get)
					.forEach(p -> p.getValue().injector(setter).inject(traceContext, request));
		};
	}

	@Override
	public <R> TraceContext.Extractor<R> extractor(Getter<R, String> getter) {
		return request -> {
			for (PropagationType type : this.types) {
				Map.Entry<Factory, Propagation<String>> entry = this.mapping.get(type);
				if (entry == null) {
					continue;
				}
				Propagation<String> propagator = entry.getValue();
				if (propagator == null || propagator == NoOpPropagation.INSTANCE) {
					continue;
				}
				TraceContextOrSamplingFlags extract = propagator.extractor(getter).extract(request);
				if (extract != TraceContextOrSamplingFlags.EMPTY) {
					return extract;
				}
			}
			return TraceContextOrSamplingFlags.EMPTY;
		};
	}

	@Override
	public <K> Propagation<K> create(KeyFactory<K> keyFactory) {
		return StringPropagationAdapter.create(this, keyFactory);
	}

	@Override
	public boolean supportsJoin() {
		return this.types.stream().map(this.mapping::get).allMatch(e -> e.getKey().supportsJoin());
	}

	@Override
	public boolean requires128BitTraceId() {
		return this.types.stream().map(this.mapping::get).allMatch(e -> e.getKey().requires128BitTraceId());
	}

	@Override
	public TraceContext decorate(TraceContext context) {
		for (PropagationType type : this.types) {
			Map.Entry<Factory, Propagation<String>> entry = this.mapping.get(type);
			if (entry == null) {
				continue;
			}
			TraceContext decorate = entry.getKey().decorate(context);
			if (decorate != context) {
				return decorate;
			}
		}
		return super.decorate(context);
	}

	@SuppressWarnings("unchecked")
	private static final class LazyPropagationFactory extends Propagation.Factory {

		private final ObjectProvider<PropagationFactorySupplier> delegate;

		private volatile Propagation.Factory propagationFactory;

		private LazyPropagationFactory(ObjectProvider<PropagationFactorySupplier> delegate) {
			this.delegate = delegate;
		}

		private Propagation.Factory propagationFactory() {
			if (this.propagationFactory == null) {
				this.propagationFactory = this.delegate.getIfAvailable(() -> () -> NoOpPropagation.INSTANCE).get();
			}
			return this.propagationFactory;
		}

		@Override
		public <K> Propagation<K> create(KeyFactory<K> keyFactory) {
			return propagationFactory().create(keyFactory);
		}

		@Override
		public boolean supportsJoin() {
			return propagationFactory().supportsJoin();
		}

		@Override
		public boolean requires128BitTraceId() {
			return propagationFactory().requires128BitTraceId();
		}

		@Override
		public Propagation<String> get() {
			return new LazyPropagation(this);
		}

		@Override
		public TraceContext decorate(TraceContext context) {
			return propagationFactory().decorate(context);
		}

	}

	@SuppressWarnings("unchecked")
	private static final class LazyPropagation implements Propagation<String> {

		private final LazyPropagationFactory delegate;

		private volatile Propagation<String> propagation;

		private LazyPropagation(LazyPropagationFactory delegate) {
			this.delegate = delegate;
		}

		private Propagation<String> propagation() {
			if (this.propagation == null) {
				this.propagation = this.delegate.propagationFactory().get();
			}
			return this.propagation;
		}

		@Override
		public List<String> keys() {
			return propagation().keys();
		}

		@Override
		public <R> TraceContext.Injector<R> injector(Setter<R, String> setter) {
			return propagation().injector(setter);
		}

		@Override
		public <R> TraceContext.Extractor<R> extractor(Getter<R, String> getter) {
			return propagation().extractor(getter);
		}

	}

	private static class NoOpPropagation extends Propagation.Factory implements Propagation<String> {

		static final NoOpPropagation INSTANCE = new NoOpPropagation();

		@Override
		public List<String> keys() {
			return Collections.emptyList();
		}

		@Override
		public <R> TraceContext.Injector<R> injector(Setter<R, String> setter) {
			return (traceContext, request) -> {

			};
		}

		@Override
		public <R> TraceContext.Extractor<R> extractor(Getter<R, String> getter) {
			return request -> TraceContextOrSamplingFlags.EMPTY;
		}

		@Override
		public <K> Propagation<K> create(KeyFactory<K> keyFactory) {
			return StringPropagationAdapter.create(this, keyFactory);
		}

	}

}