TargetChainRetriever.java
/*******************************************************************************
* Copyright (c) 2020 Eclipse RDF4J contributors.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Distribution License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* SPDX-License-Identifier: BSD-3-Clause
*******************************************************************************/
package org.eclipse.rdf4j.sail.shacl.ast.targets;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.eclipse.rdf4j.common.iteration.CloseableIteration;
import org.eclipse.rdf4j.model.Resource;
import org.eclipse.rdf4j.model.Statement;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.query.Binding;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.Dataset;
import org.eclipse.rdf4j.query.MalformedQueryException;
import org.eclipse.rdf4j.query.QueryLanguage;
import org.eclipse.rdf4j.query.algebra.BindingSetAssignment;
import org.eclipse.rdf4j.query.algebra.QueryRoot;
import org.eclipse.rdf4j.query.algebra.evaluation.impl.ArrayBindingBasedQueryEvaluationContext;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractSimpleQueryModelVisitor;
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
import org.eclipse.rdf4j.query.impl.MapBindingSet;
import org.eclipse.rdf4j.query.impl.SimpleBinding;
import org.eclipse.rdf4j.query.parser.ParsedQuery;
import org.eclipse.rdf4j.query.parser.QueryParserFactory;
import org.eclipse.rdf4j.query.parser.QueryParserRegistry;
import org.eclipse.rdf4j.sail.SailConnection;
import org.eclipse.rdf4j.sail.shacl.ast.SparqlFragment;
import org.eclipse.rdf4j.sail.shacl.ast.StatementMatcher;
import org.eclipse.rdf4j.sail.shacl.ast.StatementMatcher.Variable;
import org.eclipse.rdf4j.sail.shacl.ast.constraintcomponents.ConstraintComponent;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.LoggingCloseableIteration;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.PlanNode;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.PlanNodeHelper;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.SimpleBindingSet;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.SingletonBindingSet;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.ValidationExecutionLogger;
import org.eclipse.rdf4j.sail.shacl.ast.planNodes.ValidationTuple;
import org.eclipse.rdf4j.sail.shacl.wrapper.data.ConnectionsGroup;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Plan node that runs the query representing the target, binding the query's variables to the values of those
* statements in the added/removed sail connections that match the statement patterns.
*/
public class TargetChainRetriever implements PlanNode {
private static final Logger logger = LoggerFactory.getLogger(TargetChainRetriever.class);
private static final int BULK_SIZE = 1000;
private final ConnectionsGroup connectionsGroup;
private final List<StatementMatcher> statementMatchers;
private final List<StatementMatcher> removedStatementMatchers;
private final String queryFragment;
private final QueryParserFactory queryParserFactory;
private final ConstraintComponent.Scope scope;
private final Resource[] dataGraph;
private final Dataset dataset;
private final Set<String> varNames;
private final String sparqlProjection;
private final EffectiveTarget.EffectiveTargetFragment removedStatementTarget;
private final boolean hasValue;
private final Set<String> varNamesInQueryFragment;
private final String queryStr;
private final Set<StatementMatcher> originalStatementMatchers;
private StackTraceElement[] stackTrace;
private ValidationExecutionLogger validationExecutionLogger;
/**
 * Creates a retriever for the targets described by {@code queryFragment}.
 *
 * @param connectionsGroup         source of the added-statements, removed-statements and base connections used
 *                                 while iterating
 * @param dataGraph                graphs that are passed as contexts when reading statements and that are turned
 *                                 into the query dataset via {@code PlanNodeHelper.asDefaultGraphDataset}
 * @param statementMatchers        matchers evaluated against the added statements; a reduced form (see
 *                                 {@code StatementMatcher.reduce}) is what gets iterated
 * @param removedStatementMatchers matchers evaluated against the removed statements; may be {@code null}, which
 *                                 is treated as "none"
 * @param removedStatementTarget   used to look up root statements for removed statements; may be {@code null}
 * @param queryFragment            SPARQL fragment for the target; its namespaces are prepended and the fragment
 *                                 is normalized against {@code vars} and all matchers
 * @param vars                     variables that are projected by the generated query; must not be empty
 * @param scope                    scope given to every produced {@code ValidationTuple}
 * @param hasValue                 flag given to every produced {@code ValidationTuple}; asserted to be
 *                                 {@code false} unless {@code scope} is {@code propertyShape}
 * @throws IllegalStateException if {@code vars} is empty (with assertions enabled the non-empty assertion fails
 *                               first) or if no SPARQL query parser factory is registered
 */
public TargetChainRetriever(ConnectionsGroup connectionsGroup,
		Resource[] dataGraph, List<StatementMatcher> statementMatchers,
		List<StatementMatcher> removedStatementMatchers,
		EffectiveTarget.EffectiveTargetFragment removedStatementTarget, SparqlFragment queryFragment,
		List<Variable<Value>> vars, ConstraintComponent.Scope scope, boolean hasValue) {
	this.connectionsGroup = connectionsGroup;
	this.dataGraph = dataGraph;
	this.varNames = vars.stream().map(StatementMatcher.Variable::getName).collect(Collectors.toSet());
	assert !this.varNames.isEmpty();

	this.dataset = PlanNodeHelper.asDefaultGraphDataset(this.dataGraph);

	// Both matcher lists are handed to normalize(...). Copy before adding so that the caller's list is left
	// unmodified.
	var union = statementMatchers;
	if (removedStatementMatchers != null) {
		union = new ArrayList<>(statementMatchers);
		union.addAll(removedStatementMatchers);
	}

	this.queryFragment = queryFragment.getNamespacesForSparql()
			+ StatementMatcher.StableRandomVariableProvider.normalize(queryFragment.getFragment(), vars, union);

	// Remember the matchers as passed in: the iteration compares against them to decide whether the root of
	// a removed statement has to be looked up.
	this.originalStatementMatchers = new HashSet<>(statementMatchers);
	this.statementMatchers = StatementMatcher.reduce(statementMatchers);
	assert originalStatementMatchers.containsAll(this.statementMatchers);

	this.scope = scope;

	this.sparqlProjection = vars.stream()
			.map(StatementMatcher.Variable::asSparqlVariable)
			.reduce((a, b) -> a + " " + b)
			.orElseThrow(IllegalStateException::new);

	// Fail with a descriptive message instead of a bare NoSuchElementException from Optional.get().
	this.queryParserFactory = QueryParserRegistry.getInstance()
			.get(QueryLanguage.SPARQL)
			.orElseThrow(() -> new IllegalStateException("No SPARQL query parser factory is registered"));

	// Parse the fragment once up front to collect every variable name it uses.
	this.queryStr = "select * where {\n" + this.queryFragment + "\n}";
	this.varNamesInQueryFragment = Set.of(ArrayBindingBasedQueryEvaluationContext
			.findAllVariablesUsedInQuery(((QueryRoot) queryParserFactory.getParser()
					.parseQuery(queryStr, null)
					.getTupleExpr())));
	assert !varNamesInQueryFragment.isEmpty();

	this.removedStatementMatchers = removedStatementMatchers != null
			? StatementMatcher.reduce(removedStatementMatchers)
			: Collections.emptyList();
	assert removedStatementMatchers == null || removedStatementMatchers.containsAll(this.removedStatementMatchers);

	this.removedStatementTarget = removedStatementTarget;
	this.hasValue = hasValue;
	assert scope == ConstraintComponent.Scope.propertyShape || !this.hasValue;

	// The creation stack trace is only captured when debug logging is enabled.
	if (logger.isDebugEnabled()) {
		this.stackTrace = Thread.currentThread().getStackTrace();
	}
}
/**
 * Creates the iteration that produces the target tuples.
 * <p>
 * The iteration walks the statement matchers one at a time: first the matchers for added statements, then the
 * matchers for removed statements. For the current matcher it reads matching statements in bulks, turns them into
 * binding sets, injects those into the matching {@link BindingSetAssignment} of the parsed target query, evaluates
 * the query on the base connection and converts each result into a {@link ValidationTuple}.
 *
 * @return a new iteration over the target tuples; it keeps its own state, so every call starts from the beginning
 */
@Override
public CloseableIteration<? extends ValidationTuple> iterator() {
return new LoggingCloseableIteration(this, validationExecutionLogger) {
// Matchers evaluated against the added statements; these are consumed first.
private final Iterator<StatementMatcher> statementPatternIterator = statementMatchers.iterator();
// Matchers evaluated against the removed statements; only used once statementPatternIterator is exhausted.
private final Iterator<StatementMatcher> removedStatementIterator = removedStatementMatchers.iterator();
// Matcher whose statements are currently being processed.
private StatementMatcher currentStatementMatcher;
// String from currentStatementMatcher.getSparqlValuesDecl(...) that is placed in front of the query fragment
// (presumably a SPARQL VALUES declaration - confirm in StatementMatcher).
private String sparqlValuesDecl;
// Variable names for the current matcher; used to build binding sets and to find the BindingSetAssignment.
private Set<String> currentVarNames;
// Statements matching currentStatementMatcher in the currently selected connection.
private CloseableIteration<? extends Statement> statements;
// Look-ahead element; null when nothing has been computed yet or the iteration is exhausted.
private ValidationTuple next;
// Results of the most recently evaluated query.
private CloseableIteration<? extends BindingSet> results;
// Parsed query for the current matcher; reset to null when a new matcher is selected so it is rebuilt lazily.
private ParsedQuery parsedQuery;
// True when currentStatementMatcher was taken from removedStatementIterator.
private boolean removedStatement = false;
// Reusable buffer for the binding sets of one bulk; cleared at the start of every readStatementsInBulk call.
private final List<BindingSet> bulk = new ArrayList<>(BULK_SIZE);
@Override
protected void init() {
// no-op
}
/**
 * Moves on to the next statement matcher that has at least one matching statement.
 * <p>
 * Does nothing while the current {@code statements} iteration still has elements. When the method returns,
 * {@code statements} is either an open iteration with at least one element or {@code null} if no further matcher
 * produced any statements.
 */
public void calculateNextStatementMatcher() {
if (statements != null && statements.hasNext()) {
return;
}
// Both matcher iterators are exhausted: release the last statements iteration and signal "done" with null.
if (!statementPatternIterator.hasNext() && !removedStatementIterator.hasNext()) {
if (statements != null) {
statements.close();
statements = null;
}
return;
}
do {
// Close the iteration of the previous matcher before opening the next one.
if (statements != null) {
statements.close();
statements = null;
}
if (!statementPatternIterator.hasNext() && !removedStatementIterator.hasNext()) {
break;
}
SailConnection connection;
if (statementPatternIterator.hasNext()) {
currentStatementMatcher = statementPatternIterator.next();
connection = connectionsGroup.getAddedStatements();
removedStatement = false;
} else {
// Only removed-statement matchers are left; skip them all if the stats report that nothing was removed.
if (!connectionsGroup.getStats().hasRemoved()) {
break;
}
currentStatementMatcher = removedStatementIterator.next();
connection = connectionsGroup.getRemovedStatements();
removedStatement = true;
}
// we need to add the inherited names if we are going to chase the root of the
// currentStatementMatcher later
boolean addInherited = chaseRoot();
this.sparqlValuesDecl = currentStatementMatcher.getSparqlValuesDecl(varNames, addInherited,
varNamesInQueryFragment);
this.currentVarNames = currentStatementMatcher.getVarNames(varNames, addInherited,
varNamesInQueryFragment);
if (currentVarNames.isEmpty()) {
logger.error("currentVarNames should not be empty!");
throw new IllegalStateException("currentVarNames should not be empty!");
}
statements = connection.getStatements(
currentStatementMatcher.getSubjectValue(),
currentStatementMatcher.getPredicateValue(),
currentStatementMatcher.getObjectValue(), false, dataGraph);
} while (!statements.hasNext());
// The cached query belongs to the previous matcher; force calculateNextResult() to build a new one.
parsedQuery = null;
}
/**
 * Tells whether the root statements have to be looked up for the current matcher: only when a
 * removedStatementTarget exists, the current matcher is a removed-statement matcher and it is not one of the
 * statement matchers originally passed to the constructor.
 */
private boolean chaseRoot() {
return removedStatementTarget != null && removedStatement
&& !originalStatementMatchers.contains(currentStatementMatcher);
}
/**
 * Fills the look-ahead element {@code next} unless it is already set. Keeps reading bulks of statements and
 * evaluating the query until an evaluation yields a result or no statements are left, in which case {@code next}
 * stays {@code null}.
 */
private void calculateNextResult() {
if (next != null) {
return;
}
while (results == null || !results.hasNext()) {
try {
// The previous results are exhausted; close them before evaluating the next bulk.
if (results != null) {
results.close();
results = null;
}
while (statements == null || !statements.hasNext()) {
calculateNextStatementMatcher();
// No matcher produced any further statements: the iteration is finished.
if (statements == null) {
return;
}
}
// Build and parse the query once per matcher; it is reused for every bulk of that matcher.
if (parsedQuery == null) {
String query = "select " + sparqlProjection + " where {\n" +
sparqlValuesDecl +
queryFragment + "\n" +
"}";
parsedQuery = queryParserFactory.getParser().parseQuery(query, null);
}
// Note: this local variable refers to the same list as the field named bulk.
List<BindingSet> bulk = readStatementsInBulk(currentVarNames);
setBindings(currentVarNames, bulk);
results = connectionsGroup.getBaseConnection()
.evaluate(parsedQuery.getTupleExpr(), dataset,
EmptyBindingSet.getInstance(), true);
} catch (MalformedQueryException e) {
// Only the query fragment is logged here, not the complete generated query string.
logger.error("Malformed query:\n{}", queryFragment);
throw e;
}
}
if (results.hasNext()) {
BindingSet nextBinding = results.next();
if (nextBinding.size() == 1) {
// Single binding: the tuple is built from that one value.
Iterator<Binding> iterator = nextBinding.iterator();
if (iterator.hasNext()) {
next = new ValidationTuple(iterator.next().getValue(), scope, hasValue, dataGraph);
} else {
next = new ValidationTuple((Value) null, scope, hasValue, dataGraph);
}
} else {
// Any other size: the values are ordered by binding name before they are put into the tuple.
Value[] values = StreamSupport.stream(nextBinding.spliterator(), false)
.sorted(Comparator.comparing(Binding::getName))
.map(Binding::getValue)
.toArray(Value[]::new);
next = new ValidationTuple(values, scope, hasValue, dataGraph);
}
}
}
/**
 * Reads statements from {@code statements} and converts them into binding sets, stopping once the buffer holds
 * at least BULK_SIZE binding sets or the statements are exhausted. Because one statement can contribute several
 * binding sets when its root statements are looked up, the buffer may end up larger than BULK_SIZE.
 *
 * @param variableNames names of the variables to bind; the only call site passes {@code currentVarNames}
 * @return the shared {@code bulk} buffer, which is cleared and refilled on every call
 */
private List<BindingSet> readStatementsInBulk(Set<String> variableNames) {
bulk.clear();
while (bulk.size() < BULK_SIZE && statements.hasNext()) {
Statement next = statements.next();
// Default: the statement itself, paired with the current matcher.
Stream<EffectiveTarget.SubjectObjectAndMatcher> rootStatements = Stream
.of(new EffectiveTarget.SubjectObjectAndMatcher(
List.of(new EffectiveTarget.SubjectObjectAndMatcher.SubjectObject(next)),
currentStatementMatcher));
if (chaseRoot()) {
// we only need to find the root if the currentStatementMatcher doesn't match anything in the
// query
Stream<EffectiveTarget.SubjectObjectAndMatcher> root = removedStatementTarget.getRoot(
connectionsGroup,
dataGraph, currentStatementMatcher,
next);
// A null result keeps the default stream built above.
if (root != null) {
rootStatements = root;
}
}
rootStatements
.filter(EffectiveTarget.SubjectObjectAndMatcher::hasStatements)
.flatMap(statementsAndMatcher -> {
StatementMatcher newCurrentStatementMatcher = statementsAndMatcher
.getStatementMatcher();
return statementsAndMatcher.getStatements()
.stream()
.map(temp -> {
// One slot per variable name; j counts the slots filled so far.
Binding[] bindings = new Binding[variableNames.size()];
int j = 0;
assert newCurrentStatementMatcher.getPredicateValue() != null
|| !currentVarNames
.contains(newCurrentStatementMatcher.getPredicateName());
// Subject and object are bound only when the matcher has no fixed value for them and their
// variable name is one of the current variable names.
if (newCurrentStatementMatcher.getSubjectValue() == null
&& currentVarNames
.contains(newCurrentStatementMatcher.getSubjectName())) {
bindings[j++] = new SimpleBinding(
newCurrentStatementMatcher.getSubjectName(),
temp.getSubject());
}
if (newCurrentStatementMatcher.getObjectValue() == null
&& currentVarNames
.contains(newCurrentStatementMatcher.getObjectName())) {
bindings[j++] = new SimpleBinding(
newCurrentStatementMatcher.getObjectName(),
temp.getObject());
}
if (bindings.length == 1) {
if (bindings[0] == null) {
throw new IllegalStateException("Binding is null!");
}
return new SingletonBindingSet(bindings[0].getName(),
bindings[0].getValue());
} else {
return new SimpleBindingSet(variableNames, bindings);
}
});
})
// distinct() is applied per source statement, not across the whole bulk.
.distinct()
.forEach(bulk::add);
}
return bulk;
}
/**
 * Installs {@code bulk} as the binding sets of every {@link BindingSetAssignment} in the parsed query whose
 * binding names equal {@code varNames}.
 */
private void setBindings(Set<String> varNames, List<BindingSet> bulk) {
parsedQuery.getTupleExpr()
.visit(new AbstractSimpleQueryModelVisitor<>(false) {
@Override
public void meet(BindingSetAssignment node) {
Set<String> bindingNames = node.getBindingNames();
if (bindingNames.equals(varNames)) {
node.setBindingSets(bulk);
}
super.meet(node);
}
});
}
/**
 * Closes the statements iteration and the results iteration. The results are closed in a finally block so that
 * they are released even if closing the statements throws.
 */
@Override
public void localClose() {
try {
if (statements != null) {
statements.close();
}
} finally {
if (results != null) {
results.close();
}
}
}
/**
 * Returns the look-ahead element and clears it. Returns {@code null} when no further element could be computed.
 */
@Override
protected ValidationTuple loggingNext() {
calculateNextResult();
ValidationTuple temp = next;
next = null;
return temp;
}
/**
 * Computes the look-ahead element if necessary and reports whether one is available.
 */
@Override
protected boolean localHasNext() {
calculateNextResult();
return next != null;
}
};
}
/**
 * Compares two binding sets on the variable positions of a statement matcher.
 * <p>
 * A position (subject, predicate, object) is compared only when the matcher has no fixed value for it and the
 * position is not a wildcard. The bindings found under that position's variable name must then be equal in both
 * sets.
 *
 * @return {@code false} if any argument is {@code null} or any compared position differs, {@code true} otherwise
 */
private static boolean bindingsEquivalent(StatementMatcher currentStatementMatcher, MapBindingSet bindings,
		MapBindingSet previousBindings) {
	if (currentStatementMatcher == null || bindings == null || previousBindings == null) {
		return false;
	}

	// Subject position
	if (currentStatementMatcher.getSubjectValue() == null
			&& !currentStatementMatcher.subjectIsWildcard()
			&& !Objects.equals(bindings.getBinding(currentStatementMatcher.getSubjectName()),
					previousBindings.getBinding(currentStatementMatcher.getSubjectName()))) {
		return false;
	}

	// Predicate position
	if (currentStatementMatcher.getPredicateValue() == null
			&& !currentStatementMatcher.predicateIsWildcard()
			&& !Objects.equals(bindings.getBinding(currentStatementMatcher.getPredicateName()),
					previousBindings.getBinding(currentStatementMatcher.getPredicateName()))) {
		return false;
	}

	// Object position
	if (currentStatementMatcher.getObjectValue() == null
			&& !currentStatementMatcher.objectIsWildcard()
			&& !Objects.equals(bindings.getBinding(currentStatementMatcher.getObjectName()),
					previousBindings.getBinding(currentStatementMatcher.getObjectName()))) {
		return false;
	}

	return true;
}
/**
 * Always returns 0.
 * NOTE(review): presumably because this node has no parent plan node - confirm against the PlanNode contract.
 */
@Override
public int depth() {
return 0;
}
/**
 * Currently contributes nothing to the Graphviz output: the supplied builder is left untouched.
 */
@Override
public void getPlanAsGraphvizDot(StringBuilder stringBuilder) {
}
/**
 * Returns the identity hash code of this instance, rendered as a decimal string.
 */
@Override
public String getId() {
	int identity = System.identityHashCode(this);
	return String.valueOf(identity);
}
/**
 * Stores the execution logger. It is passed to every iteration that is subsequently created by iterator().
 */
@Override
public void receiveLogger(ValidationExecutionLogger validationExecutionLogger) {
this.validationExecutionLogger = validationExecutionLogger;
}
/**
 * Always returns {@code false}: this node does not claim that its tuples are produced in sorted order.
 */
@Override
public boolean producesSorted() {
return false;
}
/**
 * Always returns {@code false}.
 */
@Override
public boolean requiresSorted() {
return false;
}
/**
 * Two retrievers are equal when they are of the same class and agree on the reduced statement matchers, the
 * reduced removed-statement matchers, the query fragment, the dataset and the scope.
 */
@Override
public boolean equals(Object o) {
	if (o == this) {
		return true;
	}
	if (o == null || o.getClass() != getClass()) {
		return false;
	}

	TargetChainRetriever other = (TargetChainRetriever) o;

	if (!statementMatchers.equals(other.statementMatchers)) {
		return false;
	}
	if (!removedStatementMatchers.equals(other.removedStatementMatchers)) {
		return false;
	}
	if (!queryFragment.equals(other.queryFragment)) {
		return false;
	}
	if (!Objects.equals(dataset, other.dataset)) {
		return false;
	}
	return scope == other.scope;
}
/**
 * Hash code over the same fields that equals(Object) compares. The fields are combined with the 31-based
 * scheme, in the order: statement matchers, removed-statement matchers, query fragment, scope, dataset.
 */
@Override
public int hashCode() {
	int result = 1;
	result = 31 * result + Objects.hashCode(statementMatchers);
	result = 31 * result + Objects.hashCode(removedStatementMatchers);
	result = 31 * result + Objects.hashCode(queryFragment);
	result = 31 * result + Objects.hashCode(scope);
	result = 31 * result + Objects.hashCode(dataset);
	return result;
}
/**
 * Renders the statement matchers, the removed-statement matchers, the query fragment (with line breaks replaced
 * by tabs) and the scope.
 */
@Override
public String toString() {
	String singleLineQuery = queryFragment.replace("\n", "\t");

	StringBuilder text = new StringBuilder();
	text.append("TargetChainRetriever{");
	text.append("statementPatterns=").append(statementMatchers);
	text.append(", removedStatementMatchers=").append(removedStatementMatchers);
	text.append(", query='").append(singleLineQuery).append('\'');
	text.append(", scope=").append(scope);
	text.append('}');
	return text.toString();
}
}