CursorReadingTask.java

/*
 * Copyright 2018-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.mongodb.core.messaging;

import java.time.Duration;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Supplier;

import org.jspecify.annotations.Nullable;
import org.springframework.dao.DataAccessResourceFailureException;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.messaging.Message.MessageProperties;
import org.springframework.data.mongodb.core.messaging.SubscriptionRequest.RequestOptions;
import org.springframework.data.util.Lock;
import org.springframework.util.Assert;
import org.springframework.util.ErrorHandler;

import com.mongodb.client.MongoCursor;

/**
 * @author Christoph Strobl
 * @author Mark Paluch
 * @param <T> type of objects returned by the cursor.
 * @param <R> conversion target type.
 * @since 2.1
 */
abstract class CursorReadingTask<T, R> implements Task {

	private final Lock lock = Lock.of(new ReentrantLock());

	private final MongoTemplate template;
	private final SubscriptionRequest<T, R, RequestOptions> request;
	private final Class<R> targetType;
	private final ErrorHandler errorHandler;
	private final CountDownLatch awaitStart = new CountDownLatch(1);

	private State state = State.CREATED;

	private @Nullable MongoCursor<T> cursor;

	/**
	 * @param template must not be {@literal null}.
	 * @param request must not be {@literal null}.
	 * @param targetType must not be {@literal null}.
	 */
	@SuppressWarnings({ "unchecked", "rawtypes" })
	CursorReadingTask(MongoTemplate template, SubscriptionRequest<?, ? super T, ? extends RequestOptions> request,
			Class<R> targetType, ErrorHandler errorHandler) {

		this.template = template;
		this.request = (SubscriptionRequest) request;
		this.targetType = targetType;
		this.errorHandler = errorHandler;
	}

	@Override
	public void run() {

		try {

			start();

			while (isRunning()) {

				try {

					T next = execute(this::getNext);

					if (next != null) {
						emitMessage(createMessage(next, targetType, request.getRequestOptions()));
					} else {
						Thread.sleep(10);
					}
				} catch (InterruptedException e) {

					lock.executeWithoutResult(() -> state = State.CANCELLED);
					Thread.currentThread().interrupt();
					break;
				}
			}
		} catch (RuntimeException e) {

			lock.executeWithoutResult(() -> state = State.CANCELLED);
			errorHandler.handleError(e);
		}
	}

	/**
	 * Initialize the Task by 1st setting the current state to {@link State#STARTING starting} indicating the
	 * initialization procedure. <br />
	 * Moving on the underlying {@link MongoCursor} gets {@link #initCursor(MongoTemplate, RequestOptions, Class) created}
	 * and is {@link #isValidCursor(MongoCursor) health checked}. Once a valid {@link MongoCursor} is created the
	 * {@link #state} is set to {@link State#RUNNING running}. If the health check is not passed the {@link MongoCursor}
	 * is immediately {@link MongoCursor#close() closed} and a new {@link MongoCursor} is requested until a valid one is
	 * retrieved or the {@link #state} changes.
	 */
	@SuppressWarnings("NullAway")
	private void start() {

		lock.executeWithoutResult(() -> {
			if (!State.RUNNING.equals(state)) {
				state = State.STARTING;
			}
		});

		do {

			boolean valid = lock.execute(() -> {

				if (!State.STARTING.equals(state)) {
					return false;
				}

				MongoCursor<T> cursor = execute(() -> initCursor(template, request.getRequestOptions(), targetType));
				boolean isValid = isValidCursor(cursor);
				if (isValid) {
					this.cursor = cursor;
					state = State.RUNNING;
				} else if (cursor != null) {
					cursor.close();
				}
				return isValid;
			});

			if (!valid) {

				try {
					Thread.sleep(100);
				} catch (InterruptedException e) {

					lock.executeWithoutResult(() -> state = State.CANCELLED);
					Thread.currentThread().interrupt();
				}
			}
		} while (State.STARTING.equals(getState()));

		if (awaitStart.getCount() == 1) {
			awaitStart.countDown();
		}
	}

	protected abstract MongoCursor<T> initCursor(MongoTemplate template, RequestOptions options, Class<?> targetType);

	@Override
	public void cancel() throws DataAccessResourceFailureException {

		lock.executeWithoutResult(() -> {

			if (State.RUNNING.equals(state) || State.STARTING.equals(state)) {
				this.state = State.CANCELLED;
				if (cursor != null) {
					cursor.close();
				}
			}
		});
	}

	@Override
	public boolean isLongLived() {
		return true;
	}

	@Override
	public State getState() {
		return lock.execute(() -> state);
	}

	@Override
	public boolean awaitStart(Duration timeout) throws InterruptedException {

		Assert.notNull(timeout, "Timeout must not be null");
		Assert.isTrue(!timeout.isNegative(), "Timeout must not be negative");

		return awaitStart.await(timeout.toNanos(), TimeUnit.NANOSECONDS);
	}

	@SuppressWarnings("NullAway")
	protected Message<T, R> createMessage(T source, Class<R> targetType, RequestOptions options) {

		SimpleMessage<T, T> message = new SimpleMessage<>(source, source, MessageProperties.builder()
				.databaseName(template.getDb().getName()).collectionName(options.getCollectionName()).build());

		return new LazyMappingDelegatingMessage<>(message, targetType, template.getConverter());
	}

	private boolean isRunning() {
		return State.RUNNING.equals(getState());
	}

	@SuppressWarnings("unchecked")
	private void emitMessage(Message<T, R> message) {
		try {
			request.getMessageListener().onMessage((Message) message);
		} catch (Exception e) {
			errorHandler.handleError(e);
		}
	}

	private @Nullable T getNext() {

		return lock.execute(() -> {
			if (cursor != null && State.RUNNING.equals(state)) {
				return cursor.tryNext();
			}
			throw new IllegalStateException(String.format("Cursor %s is not longer open", cursor));
		});
	}

	private static boolean isValidCursor(@Nullable MongoCursor<?> cursor) {

		if (cursor == null) {
			return false;
		}

		return cursor.getServerCursor() != null && cursor.getServerCursor().getId() != 0;
	}

	/**
	 * Execute an operation and take care of translating exceptions using the {@link MongoTemplate templates}
	 * {@link org.springframework.data.mongodb.core.MongoExceptionTranslator} rethrowing the potentially translated
	 * exception.
	 *
	 * @param callback must not be {@literal null}.
	 * @param <V>
	 * @return can be {@literal null}.
	 * @throws RuntimeException The potentially translated exception.
	 */
	private <V> @Nullable V execute(Supplier<V> callback) {

		try {
			return callback.get();
		} catch (RuntimeException e) {

			RuntimeException translated = template.getExceptionTranslator().translateExceptionIfPossible(e);
			throw translated != null ? translated : e;
		}
	}

}