BindSelect.java

/*******************************************************************************
 * Copyright (c) 2020 Eclipse RDF4J contributors.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Distribution License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/org/documents/edl-v10.php.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 *******************************************************************************/

package org.eclipse.rdf4j.sail.shacl.ast.planNodes;

import static java.util.stream.Collectors.toCollection;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.apache.commons.text.StringEscapeUtils;
import org.eclipse.rdf4j.common.iteration.CloseableIteration;
import org.eclipse.rdf4j.model.Resource;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.Dataset;
import org.eclipse.rdf4j.query.MalformedQueryException;
import org.eclipse.rdf4j.query.algebra.BindingSetAssignment;
import org.eclipse.rdf4j.query.algebra.TupleExpr;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractQueryModelVisitor;
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
import org.eclipse.rdf4j.sail.SailConnection;
import org.eclipse.rdf4j.sail.memory.MemoryStoreConnection;
import org.eclipse.rdf4j.sail.shacl.ast.SparqlFragment;
import org.eclipse.rdf4j.sail.shacl.ast.SparqlQueryParserCache;
import org.eclipse.rdf4j.sail.shacl.ast.StatementMatcher;
import org.eclipse.rdf4j.sail.shacl.ast.StatementMatcher.Variable;
import org.eclipse.rdf4j.sail.shacl.ast.constraintcomponents.AbstractConstraintComponent;
import org.eclipse.rdf4j.sail.shacl.ast.constraintcomponents.ConstraintComponent;
import org.eclipse.rdf4j.sail.shacl.ast.targets.EffectiveTarget;
import org.eclipse.rdf4j.sail.shacl.wrapper.data.ConnectionsGroup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Takes a plan node as a source and for each tuple in the source it will build a BindingSet from the vars and the tuple
 * and inject it into the query.
 * <p>
 * Source tuples are consumed in batches of at most {@code bulkSize}. For every batch the tuples that share the chain
 * size of the first tuple are converted to binding sets, written into the VALUES clauses of the parsed query and the
 * query is evaluated against the connection. Tuples with a different chain size stay buffered for a later round.
 *
 * @author Håvard Ottestad
 */
public class BindSelect implements PlanNode {

	private static final Logger logger = LoggerFactory.getLogger(BindSelect.class);

	private final SailConnection connection;
	private final Dataset dataset;
	private final Function<BindingSet, ValidationTuple> mapper;

	private final String query;
	private final List<Variable<Value>> vars;
	private final int bulkSize;
	private final PlanNode source;
	private final EffectiveTarget.Extend direction;
	private final boolean includePropertyShapeValues;
	private final List<String> varNames;
	private final ConstraintComponent.Scope scope;
	private final String prefixes;
	// Only referenced by the commented-out debugging line at the end of the constructor.
	private StackTraceElement[] stackTrace;
	private boolean printed = false;
	private ValidationExecutionLogger validationExecutionLogger;

	/**
	 * Creates a node that evaluates {@code query} once per batch of tuples read from {@code source}.
	 *
	 * @param connection                 connection the query is evaluated against
	 * @param dataGraph                  graphs used to build the default-graph dataset and passed on to every
	 *                                   produced {@link ValidationTuple}
	 * @param query                      fragment that becomes the body of the generated select query
	 * @param vars                       variables that are bound from the target chain of the source tuples
	 * @param source                     node that supplies the tuples to bind
	 * @param varNames                   binding names handed to each produced {@link ValidationTuple}
	 * @param scope                      scope handed to each produced {@link ValidationTuple}
	 * @param bulkSize                   maximum number of source tuples buffered per query evaluation
	 * @param direction                  {@code right} binds the leading variables of {@code vars}, {@code left} the
	 *                                   trailing ones
	 * @param includePropertyShapeValues whether the value of a property shape tuple is part of the bound chain
	 * @param connectionsGroup           passed to {@code PlanNodeHelper.handleSorting} together with the source
	 * @throws IllegalStateException if the query fragment is empty or consists of whitespace only
	 */
	public BindSelect(SailConnection connection, Resource[] dataGraph, SparqlFragment query,
			List<Variable<Value>> vars, PlanNode source,
			List<String> varNames, ConstraintComponent.Scope scope, int bulkSize, EffectiveTarget.Extend direction,
			boolean includePropertyShapeValues, ConnectionsGroup connectionsGroup) {
		this.connection = connection;
		assert this.connection != null;
		this.mapper = (bindingSet) -> new ValidationTuple(bindingSet, varNames, scope, includePropertyShapeValues,
				dataGraph);
		this.varNames = varNames;
		this.scope = scope;
		this.vars = vars;
		this.bulkSize = bulkSize;
		this.source = PlanNodeHelper.handleSorting(this, source, connectionsGroup);

		if (query.getFragment().trim().isEmpty()) {
			throw new IllegalStateException("Query fragment is empty");
		}

		this.query = StatementMatcher.StableRandomVariableProvider.normalize(query.getFragment(), vars, List.of());
		this.prefixes = query.getNamespacesForSparql();
		this.direction = direction;
		this.includePropertyShapeValues = includePropertyShapeValues;

		dataset = PlanNodeHelper.asDefaultGraphDataset(dataGraph);

		// this.stackTrace = Thread.currentThread().getStackTrace();

	}

	/**
	 * Replaces the binding sets of every {@link BindingSetAssignment} in {@code parsedQuery} that has exactly
	 * {@code expectedSize} binding names.
	 *
	 * @param parsedQuery     query whose VALUES clauses are updated in place
	 * @param newBindingSets  binding sets for the current batch
	 * @param expectedSize    number of binding names a VALUES clause must have to be updated
	 */
	private void updateQuery(TupleExpr parsedQuery, List<BindingSet> newBindingSets, int expectedSize) {
		try {
			parsedQuery.visit(new AbstractQueryModelVisitor<Exception>() {
				@Override
				public void meet(BindingSetAssignment node) throws Exception {
					Set<String> bindingNames = node.getBindingNames();
					// TODO consider checking if bindingnames is equal to vars
					if (bindingNames.size() == expectedSize) {
						node.setBindingSets(newBindingSets);
					}
					super.meet(node);
				}
			});
		} catch (Exception e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Number of values of {@code tuple} that get bound into the query.
	 */
	private int chainSize(ValidationTuple tuple) {
		if (includePropertyShapeValues || tuple.getScope() != ConstraintComponent.Scope.propertyShape) {
			return tuple.getFullChainSize(true);
		}
		// includePropertyShapeValues is necessarily false on this path
		return tuple.getFullChainSize(false);
	}

	/**
	 * Converts the target chain of {@code tuple} into a binding set over {@code names}.
	 */
	private BindingSet toBindingSet(ValidationTuple tuple, List<String> names, Set<String> nameSet) {
		List<Value> targetChain = tuple.getTargetChain(includePropertyShapeValues);
		if (targetChain.size() == 1) {
			return new SingletonBindingSet(names.get(0), targetChain.get(0));
		}
		return new SimpleBindingSet(nameSet, names, targetChain);
	}

	/**
	 * Returns an iteration that reads the source in batches, evaluates the query for each batch and maps every query
	 * result to a {@link ValidationTuple}.
	 */
	@Override
	public CloseableIteration<? extends ValidationTuple> iterator() {
		return new LoggingCloseableIteration(this, validationExecutionLogger) {

			// results of the query evaluated for the current batch
			CloseableIteration<? extends BindingSet> bindingSet;

			private CloseableIteration<? extends ValidationTuple> iterator;
			// source tuples that have been read but not yet bound into a query
			List<ValidationTuple> bulk;

			TupleExpr parsedQuery = null;
			// chain size that parsedQuery was generated for, only meaningful while parsedQuery != null
			int parsedQueryChainSize = -1;

			@Override
			protected void init() {
				iterator = source.iterator();
				bulk = new ArrayList<>(bulkSize);
			}

			public void calculateNext() {

				while (bindingSet == null || !bindingSet.hasNext()) {

					if (bindingSet != null) {
						bindingSet.close();
					}

					if (bulk.isEmpty() && !iterator.hasNext()) {
						return;
					}

					// the first buffered tuple decides which chain size this round handles
					ValidationTuple next;
					if (bulk.isEmpty()) {
						next = iterator.next();
						bulk.add(next);
					} else {
						next = bulk.get(0);
					}

					if (includePropertyShapeValues) {
						assert next.getScope() == ConstraintComponent.Scope.propertyShape;
						assert next.hasValue();
					}

					int targetChainSize = BindSelect.this.chainSize(next);

					// The VALUES header depends on the chain size, so a query parsed for another size can not be
					// reused: updateQuery would find no matching VALUES clause and leave the old bindings in place.
					if (parsedQuery == null || parsedQueryChainSize != targetChainSize) {
						parsedQuery = getParsedQuery(targetChainSize);
						parsedQueryChainSize = targetChainSize;
					}

					while (bulk.size() < bulkSize && iterator.hasNext()) {
						bulk.add(iterator.next());
					}

					List<String> boundVarNames;
					if (direction == EffectiveTarget.Extend.right) {
						boundVarNames = vars
								.stream()
								.limit(targetChainSize)
								.map(StatementMatcher.Variable::getName)
								.collect(Collectors.toList());
					} else {
						boundVarNames = vars
								.stream()
								.skip(vars.size() - targetChainSize)
								.map(StatementMatcher.Variable::getName)
								.collect(Collectors.toList());
					}

					Set<String> boundVarNamesSet = new HashSet<>(boundVarNames);

					// tuples with the current chain size are bound now, the rest stay buffered for a later round
					List<BindingSet> bindingSets = new ArrayList<>(bulk.size());
					List<ValidationTuple> deferred = new ArrayList<>();
					for (ValidationTuple tuple : bulk) {
						if (BindSelect.this.chainSize(tuple) == targetChainSize) {
							bindingSets.add(BindSelect.this.toBindingSet(tuple, boundVarNames, boundVarNamesSet));
						} else {
							deferred.add(tuple);
						}
					}
					bulk = deferred;

					updateQuery(parsedQuery, bindingSets, targetChainSize);

					bindingSet = connection.evaluate(parsedQuery, dataset,
							EmptyBindingSet.getInstance(), true);
				}
			}

			@Override
			public void localClose() {
				try {
					bulk = null;
					parsedQuery = null;
					if (iterator != null) {
						try {
							assert !iterator.hasNext();
						} finally {
							// close the source even when the assertion above fails
							iterator.close();
						}
					}
				} finally {
					if (bindingSet != null) {
						bindingSet.close();
					}
				}
			}

			@Override
			protected boolean localHasNext() {
				calculateNext();
				return bindingSet != null && bindingSet.hasNext();
			}

			@Override
			protected ValidationTuple loggingNext() {
				calculateNext();
				return mapper.apply(bindingSet.next());
			}

		};
	}

	/**
	 * Generates the select query for the given chain size and fetches its parsed form from
	 * {@link SparqlQueryParserCache}.
	 * <p>
	 * An empty VALUES clause over the bound variables is placed at the start of the where block and at every
	 * occurrence of {@code AbstractConstraintComponent.VALUES_INJECTION_POINT} in the fragment.
	 *
	 * @param targetChainSize number of variables to bind; taken from the start of {@code vars} when extending right
	 *                        and from the end when extending left
	 * @return the parsed query
	 * @throws IllegalStateException   if the direction is neither {@code right} nor {@code left}
	 * @throws MalformedQueryException if the generated query can not be parsed
	 */
	private TupleExpr getParsedQuery(int targetChainSize) {

		StringBuilder values = new StringBuilder("\nVALUES( ");
		if (direction == EffectiveTarget.Extend.right) {
			for (int i = 0; i < targetChainSize; i++) {
				values.append(vars.get(i).asSparqlVariable()).append(" ");
			}
		} else if (direction == EffectiveTarget.Extend.left) {
			for (int i = vars.size() - targetChainSize; i < vars.size(); i++) {
				values.append(vars.get(i).asSparqlVariable()).append(" ");
			}
		} else {
			throw new IllegalStateException("Unknown direction: " + direction);
		}

		values.append("){}\n");

		String body = query.replace(AbstractConstraintComponent.VALUES_INJECTION_POINT, values.toString());
		String fullQuery = prefixes + "select * where { " + values + body + "\n}";

		try {
			return SparqlQueryParserCache.get(fullQuery);
		} catch (MalformedQueryException e) {
			logger.error("Malformed query:\n{}", fullQuery);
			throw e;
		}
	}

	@Override
	public int depth() {
		return 0;
	}

	/**
	 * Appends this node and an edge from its connection to the Graphviz output. Does nothing after the first call.
	 */
	@Override
	public void getPlanAsGraphvizDot(StringBuilder stringBuilder) {
		if (printed) {
			return;
		}
		printed = true;
		stringBuilder.append(getId())
				.append(" [label=\"")
				.append(StringEscapeUtils.escapeJava(this.toString()))
				.append("\"];")
				.append("\n");

		// the edge is keyed on the identity hash of the connection object itself
		stringBuilder.append(System.identityHashCode(connection))
				.append(" -> ")
				.append(getId())
				.append("\n");

	}

	@Override
	public String getId() {
		return System.identityHashCode(this) + "";
	}

	@Override
	public void receiveLogger(ValidationExecutionLogger validationExecutionLogger) {
		this.validationExecutionLogger = validationExecutionLogger;
		source.receiveLogger(validationExecutionLogger);
	}

	@Override
	public boolean producesSorted() {
		return false;
	}

	@Override
	public boolean requiresSorted() {
		return true;
	}

	/**
	 * Compares two connections for the purpose of {@link #equals(Object)}.
	 * <p>
	 * Added/removed connections are always newly minted per plan node, so for two memory store connections the
	 * underlying sails are compared instead.
	 */
	private static boolean sameConnection(SailConnection a, SailConnection b) {
		if (a instanceof MemoryStoreConnection && b instanceof MemoryStoreConnection) {
			return ((MemoryStoreConnection) a).getSail().equals(((MemoryStoreConnection) b).getSail());
		}
		return Objects.equals(a, b);
	}

	/**
	 * Object that represents the connection in {@link #hashCode()}: the sail for a memory store connection, otherwise
	 * the connection itself.
	 */
	private static Object connectionHashKey(SailConnection connection) {
		if (connection instanceof MemoryStoreConnection) {
			return ((MemoryStoreConnection) connection).getSail();
		}
		return connection;
	}

	@Override
	public boolean equals(Object o) {
		if (this == o) {
			return true;
		}
		if (o == null || getClass() != o.getClass()) {
			return false;
		}
		BindSelect that = (BindSelect) o;

		return bulkSize == that.bulkSize
				&& includePropertyShapeValues == that.includePropertyShapeValues
				&& sameConnection(connection, that.connection)
				&& varNames.equals(that.varNames)
				&& scope.equals(that.scope)
				&& query.equals(that.query)
				&& vars.equals(that.vars)
				&& source.equals(that.source)
				&& Objects.equals(dataset, that.dataset)
				&& direction == that.direction;
	}

	@Override
	public int hashCode() {
		return Objects.hash(connectionHashKey(connection), varNames, scope, query, vars, bulkSize, source,
				direction, includePropertyShapeValues, dataset);
	}

	@Override
	public String toString() {
		return "BindSelect{" +
				"query='" + query.replace("\n", "\t") + '\'' +
				", vars=" + vars +
				", bulkSize=" + bulkSize +
				", source=" + source +
				", direction=" + direction +
				", includePropertyShapeValues=" + includePropertyShapeValues +
				", varNames=" + varNames +
				", scope=" + scope +
				'}';
	}

}