AggregationUpdate.java

/*
 * Copyright 2019-present the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.mongodb.core.aggregation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;

import org.bson.Document;
import org.jspecify.annotations.Nullable;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.data.mongodb.core.query.SerializationUtils;
import org.springframework.data.mongodb.core.query.UpdateDefinition;
import org.springframework.lang.Contract;
import org.springframework.util.Assert;

/**
 * Abstraction for {@code db.collection.update()} using an aggregation pipeline. Aggregation pipeline updates allow
 * more expressive update statements, such as conditional updates based on current field values or updating one field
 * using the value(s) of other fields.
 *
 * <pre class="code">
 * AggregationUpdate update = AggregationUpdate.update().set("average")
 * 		.toValue(ArithmeticOperators.valueOf("tests").avg()).set("grade")
 * 		.toValue(ConditionalOperators
 * 				.switchCases(CaseOperator.when(Gte.valueOf("average").greaterThanEqualToValue(90)).then("A"),
 * 						CaseOperator.when(Gte.valueOf("average").greaterThanEqualToValue(80)).then("B"),
 * 						CaseOperator.when(Gte.valueOf("average").greaterThanEqualToValue(70)).then("C"),
 * 						CaseOperator.when(Gte.valueOf("average").greaterThanEqualToValue(60)).then("D"))
 * 				.defaultTo("F"));
 * </pre>
 *
 * The above sample is equivalent to the JSON update statement:
 *
 * <pre class="code">
 * db.collection.update(
 *    { },
 *    [
 *      { $set: { average : { $avg: "$tests" } } },
 *      { $set: { grade: { $switch: {
 *                            branches: [
 *                                { case: { $gte: [ "$average", 90 ] }, then: "A" },
 *                                { case: { $gte: [ "$average", 80 ] }, then: "B" },
 *                                { case: { $gte: [ "$average", 70 ] }, then: "C" },
 *                                { case: { $gte: [ "$average", 60 ] }, then: "D" }
 *                            ],
 *                            default: "F"
 *      } } } }
 *    ],
 *    { multi: true }
 * )
 * </pre>
 *
 * @author Christoph Strobl
 * @author Mark Paluch
 * @see <a href="https://docs.mongodb.com/manual/reference/method/db.collection.update/#update-with-aggregation-pipeline">MongoDB
 *      Reference Documentation</a>
 * @since 3.0
 */
public class AggregationUpdate extends Aggregation implements UpdateDefinition {

	/** Flag reported by {@link #isIsolated()}; switched on via {@link #isolated()}. */
	private boolean isolated = false;

	/** Names of the fields registered as modified by this update; backs {@link #modifies(String)}. */
	private final Set<String> keysTouched = new HashSet<>();

	/**
	 * Create new {@link AggregationUpdate} starting with an empty pipeline.
	 */
	protected AggregationUpdate() {
		this(new ArrayList<>());
	}

	/**
	 * Create new {@link AggregationUpdate} with the given aggregation pipeline to apply. The names of the fields exposed
	 * by {@link FieldsExposingAggregationOperation}s as well as the fields removed by {@link UnsetOperation}s are
	 * registered as modified, so that {@link #modifies(String)} reports the same keys as if the operations had been added
	 * via {@link #set(SetOperation)} / {@link #unset(UnsetOperation)}.
	 *
	 * @param pipeline must not be {@literal null}.
	 */
	protected AggregationUpdate(List<AggregationOperation> pipeline) {

		super(pipeline);

		for (AggregationOperation operation : pipeline) {

			if (operation instanceof FieldsExposingAggregationOperation exposingAggregationOperation) {
				exposingAggregationOperation.getFields().forEach(it -> keysTouched.add(it.getName()));
			}

			// Keep in sync with unset(UnsetOperation): removed fields count as modified too. This is purely additive,
			// names already collected above are de-duplicated by the set.
			if (operation instanceof UnsetOperation unsetOperation) {
				keysTouched.addAll(unsetOperation.removedFieldNames());
			}
		}
	}

	/**
	 * Start defining the update pipeline to execute.
	 *
	 * @return new instance of {@link AggregationUpdate}.
	 */
	public static AggregationUpdate update() {
		return new AggregationUpdate();
	}

	/**
	 * Create a new AggregationUpdate from the given {@link AggregationOperation}s.
	 *
	 * @param pipeline must not be {@literal null}.
	 * @return new instance of {@link AggregationUpdate}.
	 */
	public static AggregationUpdate from(List<AggregationOperation> pipeline) {
		return new AggregationUpdate(pipeline);
	}

	/**
	 * Adds new fields to documents. {@code $set} outputs documents that contain all existing fields from the input
	 * documents and newly added fields. The names of the fields of the given operation are registered as modified.
	 *
	 * @param setOperation must not be {@literal null}.
	 * @return this.
	 * @see <a href="https://docs.mongodb.com/manual/reference/operator/aggregation/set/">$set Aggregation Reference</a>
	 */
	@Contract("_ -> this")
	public AggregationUpdate set(SetOperation setOperation) {

		Assert.notNull(setOperation, "SetOperation must not be null");

		setOperation.getFields().forEach(it -> keysTouched.add(it.getName()));
		pipeline.add(setOperation);
		return this;
	}

	/**
	 * {@code $unset} removes/excludes fields from documents. The removed field names are registered as modified.
	 *
	 * @param unsetOperation must not be {@literal null}.
	 * @return this.
	 * @see <a href="https://docs.mongodb.com/manual/reference/operator/aggregation/unset/">$unset Aggregation
	 *      Reference</a>
	 */
	@Contract("_ -> this")
	public AggregationUpdate unset(UnsetOperation unsetOperation) {

		Assert.notNull(unsetOperation, "UnsetOperation must not be null");

		pipeline.add(unsetOperation);
		keysTouched.addAll(unsetOperation.removedFieldNames());
		return this;
	}

	/**
	 * {@code $replaceWith} replaces the input document with the specified document. The operation replaces all existing
	 * fields in the input document, including the <strong>_id</strong> field. <br />
	 * Note that this method does not register any key as modified for {@link #modifies(String)}.
	 *
	 * @param replaceWithOperation must not be {@literal null}.
	 * @return this.
	 * @see <a href="https://docs.mongodb.com/manual/reference/operator/aggregation/replaceWith/">$replaceWith Aggregation
	 *      Reference</a>
	 */
	@Contract("_ -> this")
	public AggregationUpdate replaceWith(ReplaceWithOperation replaceWithOperation) {

		Assert.notNull(replaceWithOperation, "ReplaceWithOperation must not be null");
		pipeline.add(replaceWithOperation);
		return this;
	}

	/**
	 * {@code $replaceWith} replaces the input document with the value. Short for
	 * {@link #replaceWith(ReplaceWithOperation)} using {@link ReplaceWithOperation#replaceWithValue(Object)}.
	 *
	 * @param value must not be {@literal null}.
	 * @return this.
	 */
	@Contract("_ -> this")
	public AggregationUpdate replaceWith(Object value) {

		Assert.notNull(value, "Value must not be null");
		return replaceWith(ReplaceWithOperation.replaceWithValue(value));
	}

	/**
	 * Fluent API variant for {@code $set} adding a single {@link SetOperation pipeline operation} every time. To update
	 * multiple fields within one {@link SetOperation} use {@link #set(SetOperation)}.
	 *
	 * @param key must not be {@literal null}.
	 * @return new instance of {@link SetValueAppender}.
	 * @see #set(SetOperation)
	 */
	@Contract("_ -> new")
	public SetValueAppender set(String key) {

		Assert.notNull(key, "Key must not be null");

		// Both callbacks delegate to set(SetOperation), which appends the stage and registers the key as modified.
		return new SetValueAppender() {

			@Override
			public AggregationUpdate toValue(@Nullable Object value) {
				return set(SetOperation.builder().set(key).toValue(value));
			}

			@Override
			public AggregationUpdate toValueOf(Object value) {

				Assert.notNull(value, "Value must not be null");
				return set(SetOperation.builder().set(key).toValueOf(value));
			}
		};
	}

	/**
	 * Short for {@link #unset(UnsetOperation)}.
	 *
	 * @param keys the fields to remove. Must not be {@literal null} nor contain {@literal null} elements.
	 * @return this.
	 */
	@Contract("_ -> this")
	public AggregationUpdate unset(String... keys) {

		Assert.notNull(keys, "Keys must not be null");
		Assert.noNullElements(keys, "Keys must not contain null elements");

		return unset(new UnsetOperation(Arrays.stream(keys).map(Fields::field).collect(Collectors.toList())));
	}

	/**
	 * Prevents a write operation that affects <strong>multiple</strong> documents from yielding to other reads or writes
	 * once the first document is written. <br />
	 * Use with {@link org.springframework.data.mongodb.core.MongoOperations#updateMulti(Query, UpdateDefinition, Class)}.
	 *
	 * @return this.
	 */
	@Contract("-> this")
	public AggregationUpdate isolated() {

		isolated = true;
		return this;
	}

	/**
	 * @return {@literal true} once {@link #isolated()} has been called.
	 */
	@Override
	public boolean isIsolated() {
		return isolated;
	}

	/**
	 * @return a {@link Document} holding the pipeline rendered with {@link Aggregation#DEFAULT_CONTEXT} under an empty
	 *         key.
	 */
	@Override
	public Document getUpdateObject() {
		return new Document("", toPipeline(Aggregation.DEFAULT_CONTEXT));
	}

	/**
	 * Check if the given key has been registered as modified, which happens for fields of operations passed on
	 * construction and for fields added via the {@code set}, {@code unset} and {@link #inc(String)} methods.
	 *
	 * @param key the field name to check.
	 * @return {@literal true} if the key has been registered as modified.
	 */
	@Override
	public boolean modifies(String key) {
		return keysTouched.contains(key);
	}

	/**
	 * Appends a {@link SetOperation} assigning {@code ArithmeticOperators.valueOf(key).add(1)} to the given key.
	 *
	 * @param key the field to increment.
	 * @return this.
	 */
	@Override
	public UpdateDefinition inc(String key) {

		set(new SetOperation(key, ArithmeticOperators.valueOf(key).add(1)));
		return this;
	}

	/**
	 * @return always an empty {@link List}.
	 */
	@Override
	public List<ArrayFilter> getArrayFilters() {
		return Collections.emptyList();
	}

	/**
	 * Renders the pipeline stages (using {@link Aggregation#DEFAULT_CONTEXT}) one per line, separated by commas and
	 * enclosed in square brackets.
	 */
	@Override
	public String toString() {

		StringJoiner joiner = new StringJoiner(",\n", "[\n", "\n]");
		toPipeline(Aggregation.DEFAULT_CONTEXT).stream().map(SerializationUtils::serializeToJsonSafely)
				.forEach(joiner::add);
		return joiner.toString();
	}

	/**
	 * Fluent API AggregationUpdate builder.
	 *
	 * @author Christoph Strobl
	 */
	public interface SetValueAppender {

		/**
		 * Define the target value as is.
		 *
		 * @param value can be {@literal null}.
		 * @return never {@literal null}.
		 */
		AggregationUpdate toValue(@Nullable Object value);

		/**
		 * Define the target value as value, an {@link AggregationExpression} or a {@link Field} reference.
		 *
		 * @param value must not be {@literal null}.
		 * @return never {@literal null}.
		 */
		AggregationUpdate toValueOf(Object value);
	}
}