DiscoveryClientBasedReactiveLoadBalancer.java

/*
 * Copyright 2012-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.cloud.client.loadbalancer.reactive;

import java.util.ArrayList;
import java.util.List;
import java.util.random.RandomGenerator;
import java.util.random.RandomGeneratorFactory;

import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;

import org.springframework.cloud.client.ServiceInstance;
import org.springframework.cloud.client.discovery.DiscoveryClient;
import org.springframework.cloud.client.loadbalancer.DefaultResponse;
import org.springframework.cloud.client.loadbalancer.EmptyResponse;
import org.springframework.cloud.client.loadbalancer.Request;
import org.springframework.cloud.client.loadbalancer.Response;
import org.springframework.cloud.client.loadbalancer.RetryableRequestContext;

/**
 * A {@link ReactiveLoadBalancer} implementation used for tests. It picks a uniformly
 * random instance from those returned by a {@link DiscoveryClient} for a single service.
 * For retryable requests, it avoids the previously used instance when another one is
 * available.
 *
 * @author Olga Maciaszek-Sharma
 * @since 3.0.0
 */
class DiscoveryClientBasedReactiveLoadBalancer implements ReactiveLoadBalancer<ServiceInstance> {

	private final RandomGenerator random = RandomGeneratorFactory.getDefault().create();

	private final String serviceId;

	private final DiscoveryClient discoveryClient;

	/**
	 * Creates a load balancer for the given service.
	 * @param serviceId the id of the service whose instances are looked up
	 * @param discoveryClient the client used to look up instances on every call
	 */
	DiscoveryClientBasedReactiveLoadBalancer(String serviceId, DiscoveryClient discoveryClient) {
		this.serviceId = serviceId;
		this.discoveryClient = discoveryClient;
	}

	/**
	 * Chooses a random instance of the service.
	 * @return a publisher emitting a {@link DefaultResponse} with the chosen instance, or
	 * an {@link EmptyResponse} if the service has no instances
	 */
	@Override
	public Publisher<Response<ServiceInstance>> choose() {
		return chooseRandom(discoveryClient.getInstances(serviceId));
	}

	/**
	 * Chooses a random instance of the service. If the request carries a
	 * {@link RetryableRequestContext} with a previous instance, that instance is excluded
	 * whenever at least one other instance remains.
	 * @param request the load-balancer request; may be {@code null}, in which case no
	 * instance is excluded
	 * @return a publisher emitting a {@link DefaultResponse} with the chosen instance, or
	 * an {@link EmptyResponse} if the service has no instances
	 */
	@Override
	public Publisher<Response<ServiceInstance>> choose(Request request) {
		List<ServiceInstance> instances = discoveryClient.getInstances(serviceId);
		if (request != null && request.getContext() instanceof RetryableRequestContext context
				&& context.getPreviousServiceInstance() != null) {
			// Filter a private copy: the list returned by the DiscoveryClient may be its
			// internal state or unmodifiable, so it must never be mutated directly.
			List<ServiceInstance> candidates = new ArrayList<>(instances);
			candidates.remove(context.getPreviousServiceInstance());
			// Retry on the same instance rather than failing if it is the only one.
			if (!candidates.isEmpty()) {
				instances = candidates;
			}
		}
		return chooseRandom(instances);
	}

	/**
	 * Picks a uniformly random element of {@code instances}.
	 * @param instances the candidate instances; not modified
	 * @return a publisher emitting the chosen instance, or an {@link EmptyResponse} if
	 * {@code instances} is empty
	 */
	private Publisher<Response<ServiceInstance>> chooseRandom(List<ServiceInstance> instances) {
		if (instances.isEmpty()) {
			return Mono.just(new EmptyResponse());
		}
		int instanceIdx = random.nextInt(instances.size());
		return Mono.just(new DefaultResponse(instances.get(instanceIdx)));
	}

}