SparqlFragment.java
/*******************************************************************************
* Copyright (c) 2020 Eclipse RDF4J contributors.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Distribution License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* SPDX-License-Identifier: BSD-3-Clause
*******************************************************************************/
package org.eclipse.rdf4j.sail.shacl.ast;
import static org.eclipse.rdf4j.sail.shacl.ast.constraintcomponents.AbstractConstraintComponent.VALUES_INJECTION_POINT;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.eclipse.rdf4j.model.Namespace;
import org.eclipse.rdf4j.model.Resource;
import org.eclipse.rdf4j.sail.shacl.ast.paths.Path;
import org.eclipse.rdf4j.sail.shacl.ast.targets.EffectiveTarget;
import org.eclipse.rdf4j.sail.shacl.wrapper.data.ConnectionsGroup;
public class SparqlFragment {
// This is currently experimental!
private static final boolean USE_UNION_PRESERVING_JOIN = false;
private final Set<Namespace> namespaces = new HashSet<>();
private final String fragment;
private final List<String> unionFragments = new ArrayList<>();
private final List<StatementMatcher> statementMatchers = new ArrayList<>();
private final TraceBack traceBackFunction;
private boolean filterCondition;
private boolean bgp;
private boolean union;
private final boolean supportsIncrementalEvaluation;
private SparqlFragment(Collection<Namespace> namespaces, String fragment, boolean filterCondition, boolean bgp,
List<StatementMatcher> statementMatchers, TraceBack traceBackFunction,
boolean supportsIncrementalEvaluation) {
this.namespaces.addAll(namespaces);
this.fragment = fragment;
this.filterCondition = filterCondition;
this.bgp = bgp;
this.statementMatchers.addAll(statementMatchers);
this.traceBackFunction = traceBackFunction;
this.supportsIncrementalEvaluation = supportsIncrementalEvaluation;
assert filterCondition != bgp;
}
private SparqlFragment(Collection<Namespace> namespaces, List<String> unionFragments,
List<StatementMatcher> statementMatchers,
TraceBack traceBackFunction, boolean supportsIncrementalEvaluation) {
this.namespaces.addAll(namespaces);
this.fragment = null;
this.unionFragments.addAll(unionFragments);
this.union = true;
this.statementMatchers.addAll(statementMatchers);
this.traceBackFunction = traceBackFunction;
this.supportsIncrementalEvaluation = supportsIncrementalEvaluation;
}
public static SparqlFragment filterCondition(Collection<Namespace> namespaces, String fragment,
List<StatementMatcher> statementMatchers) {
return new SparqlFragment(namespaces, fragment, true, false, statementMatchers, null, true);
}
public static SparqlFragment filterCondition(Collection<Namespace> namespaces, String fragment,
List<StatementMatcher> statementMatchers,
boolean supportsIncrementalEvaluation) {
return new SparqlFragment(namespaces, fragment, true, false, statementMatchers, null,
supportsIncrementalEvaluation);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String query,
boolean supportsIncrementalEvaluation) {
return new SparqlFragment(namespaces, query, false, true, List.of(), null, supportsIncrementalEvaluation);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String fragment,
List<StatementMatcher> statementMatchers) {
return new SparqlFragment(namespaces, fragment, false, true, statementMatchers, null, true);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String fragment,
List<StatementMatcher> statementMatchers,
TraceBack traceBackFunction) {
return new SparqlFragment(namespaces, fragment, false, true, statementMatchers, traceBackFunction, true);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String fragment,
List<StatementMatcher> statementMatchers,
TraceBack traceBackFunction, boolean supportsIncrementalEvaluation) {
return new SparqlFragment(namespaces, fragment, false, true, statementMatchers, traceBackFunction,
supportsIncrementalEvaluation);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String fragment,
StatementMatcher statementMatcher) {
return new SparqlFragment(namespaces, fragment, false, true, List.of(statementMatcher), null, true);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String fragment,
StatementMatcher statementMatcher, TraceBack traceBackFunction) {
return new SparqlFragment(namespaces, fragment, false, true, List.of(statementMatcher), traceBackFunction,
true);
}
public static SparqlFragment bgp(Collection<Namespace> namespaces, String fragment) {
return new SparqlFragment(namespaces, fragment, false, true, List.of(), null, true);
}
public static SparqlFragment and(List<SparqlFragment> sparqlFragments) {
String collect = sparqlFragments.stream()
.peek(s -> {
assert s.filterCondition;
})
.map(SparqlFragment::getFragment)
.collect(Collectors.joining(" ) && ( ", "( ",
" )"));
boolean supportsIncrementalEvaluation = sparqlFragments.stream()
.allMatch(SparqlFragment::supportsIncrementalEvaluation);
Set<Namespace> namespaces = sparqlFragments.stream()
.flatMap(s -> s.namespaces.stream())
.collect(Collectors.toSet());
return filterCondition(namespaces, collect,
getStatementMatchers(sparqlFragments), supportsIncrementalEvaluation);
}
public static SparqlFragment or(List<SparqlFragment> sparqlFragments) {
String collect = sparqlFragments.stream()
.peek(s -> {
assert s.filterCondition;
})
.map(SparqlFragment::getFragment)
.collect(Collectors.joining(" ) || ( ", "( ",
" )"));
boolean supportsIncrementalEvaluation = sparqlFragments.stream()
.allMatch(SparqlFragment::supportsIncrementalEvaluation);
Set<Namespace> namespaces = sparqlFragments.stream()
.flatMap(s -> s.namespaces.stream())
.collect(Collectors.toSet());
return filterCondition(namespaces, collect,
getStatementMatchers(sparqlFragments), supportsIncrementalEvaluation);
}
public static SparqlFragment join(List<SparqlFragment> sparqlFragments) {
return join(sparqlFragments, null);
}
public static SparqlFragment join(List<SparqlFragment> sparqlFragments, TraceBack traceBackFunction) {
if (USE_UNION_PRESERVING_JOIN && sparqlFragments.stream().anyMatch(s1 -> s1.union)) {
return unionPreservingJoin(sparqlFragments, traceBackFunction);
} else {
String queryFragment = sparqlFragments.stream()
.peek(s -> {
assert !s.filterCondition;
})
.map(SparqlFragment::getFragment)
.reduce((a, b) -> a + "\n" + b)
.orElse("");
boolean supportsIncrementalEvaluation = sparqlFragments.stream()
.allMatch(SparqlFragment::supportsIncrementalEvaluation);
Set<Namespace> namespaces = sparqlFragments.stream()
.flatMap(s -> s.namespaces.stream())
.collect(Collectors.toSet());
return bgp(namespaces, queryFragment, getStatementMatchers(sparqlFragments), traceBackFunction,
supportsIncrementalEvaluation);
}
}
private static SparqlFragment unionPreservingJoin(List<SparqlFragment> sparqlFragments,
TraceBack traceBackFunction) {
List<String> workingSet = new ArrayList<>();
SparqlFragment firstSparqlFragment = sparqlFragments.get(0);
if (firstSparqlFragment.union) {
workingSet.addAll(firstSparqlFragment.unionFragments);
} else {
assert firstSparqlFragment.bgp;
workingSet.add(firstSparqlFragment.fragment);
}
for (int i = 1; i < sparqlFragments.size(); i++) {
SparqlFragment sparqlFragment = sparqlFragments.get(i);
if (sparqlFragment.union) {
List<String> newWorkingSet = new ArrayList<>();
for (String unionFragment : sparqlFragment.unionFragments) {
for (String workingSetFragment : workingSet) {
newWorkingSet.add(workingSetFragment + "\n" + unionFragment);
}
}
workingSet = newWorkingSet;
} else {
assert sparqlFragment.bgp;
workingSet = workingSet
.stream()
.map(s -> sparqlFragment.fragment + "\n" + s)
.collect(Collectors.toList());
}
}
boolean supportsIncrementalEvaluation = sparqlFragments.stream()
.allMatch(SparqlFragment::supportsIncrementalEvaluation);
Set<Namespace> namespaces = sparqlFragments.stream()
.flatMap(s -> s.namespaces.stream())
.collect(Collectors.toSet());
SparqlFragment union = unionQueryStrings(namespaces, workingSet, traceBackFunction,
supportsIncrementalEvaluation);
union.addStatementMatchers(getStatementMatchers(sparqlFragments));
return union;
}
public static boolean isFilterCondition(List<SparqlFragment> sparqlFragments) {
for (SparqlFragment sparqlFragment : sparqlFragments) {
if (sparqlFragment.isFilterCondition()) {
return true;
}
}
return false;
}
public static List<StatementMatcher> getStatementMatchers(List<SparqlFragment> sparqlFragments) {
return sparqlFragments
.stream()
.flatMap(s -> s.statementMatchers.stream())
.collect(Collectors.toList());
}
public static SparqlFragment unionQueryStrings(Set<Namespace> namespaces, List<String> query,
TraceBack traceBackFunction,
boolean supportsIncrementalEvaluation) {
return new SparqlFragment(namespaces, query, Collections.emptyList(), traceBackFunction,
supportsIncrementalEvaluation);
}
public static SparqlFragment union(List<SparqlFragment> sparqlFragments) {
List<String> sparqlFragmentString = sparqlFragments
.stream()
.map(SparqlFragment::getFragment)
.collect(Collectors.toList());
boolean supportsIncrementalEvaluation = sparqlFragments.stream()
.allMatch(SparqlFragment::supportsIncrementalEvaluation);
Set<Namespace> namespaces = sparqlFragments.stream()
.flatMap(s -> s.namespaces.stream())
.collect(Collectors.toSet());
return new SparqlFragment(namespaces, sparqlFragmentString, getStatementMatchers(sparqlFragments), null,
supportsIncrementalEvaluation);
}
public static SparqlFragment union(List<SparqlFragment> sparqlFragments, TraceBack traceBackFunction) {
List<String> sparqlFragmentString = sparqlFragments
.stream()
.map(SparqlFragment::getFragment)
.collect(Collectors.toList());
boolean supportsIncrementalEvaluation = sparqlFragments.stream()
.allMatch(SparqlFragment::supportsIncrementalEvaluation);
Set<Namespace> namespaces = sparqlFragments.stream()
.flatMap(s -> s.namespaces.stream())
.collect(Collectors.toSet());
return new SparqlFragment(namespaces, sparqlFragmentString, getStatementMatchers(sparqlFragments),
traceBackFunction,
supportsIncrementalEvaluation);
}
public static SparqlFragment unionQueryStrings(Set<Namespace> namespaces, String query1, String query2,
String query3,
boolean supportsIncrementalEvaluation) {
return new SparqlFragment(namespaces, List.of(query1, query2, query3), Collections.emptyList(), null,
supportsIncrementalEvaluation);
}
public String getFragment() {
if (union) {
return unionFragments.stream()
.collect(
Collectors.joining(
"\n} UNION {\n" + VALUES_INJECTION_POINT + "\n",
"{\n" + VALUES_INJECTION_POINT + "\n",
"\n}"));
}
return fragment;
}
public boolean isFilterCondition() {
return filterCondition;
}
public List<StatementMatcher> getStatementMatchers() {
return statementMatchers;
}
public void addStatementMatchers(List<StatementMatcher> statementMatchers) {
this.statementMatchers.addAll(statementMatchers);
}
public boolean supportsIncrementalEvaluation() {
return supportsIncrementalEvaluation;
}
public String getNamespacesForSparql() {
return ShaclPrefixParser.toSparqlPrefixes(namespaces);
}
public Stream<EffectiveTarget.SubjectObjectAndMatcher> getRoot(ConnectionsGroup connectionsGroup,
Resource[] dataGraph,
Path path, StatementMatcher currentStatementMatcher,
List<EffectiveTarget.SubjectObjectAndMatcher.SubjectObject> currentStatements) {
assert traceBackFunction != null;
return traceBackFunction.getRoot(connectionsGroup, dataGraph, path, currentStatementMatcher, currentStatements);
}
@Override
public String toString() {
return "SparqlFragment{" +
"fragment='" + fragment + '\'' +
", unionFragments=" + unionFragments +
", statementMatchers=" + statementMatchers +
", traceBackFunction=" + traceBackFunction +
", filterCondition=" + filterCondition +
", bgp=" + bgp +
", union=" + union +
", supportsIncrementalEvaluation=" + supportsIncrementalEvaluation +
'}';
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
SparqlFragment that = (SparqlFragment) o;
if (filterCondition != that.filterCondition) {
return false;
}
if (bgp != that.bgp) {
return false;
}
if (union != that.union) {
return false;
}
if (supportsIncrementalEvaluation != that.supportsIncrementalEvaluation) {
return false;
}
if (!namespaces.equals(that.namespaces)) {
return false;
}
if (!Objects.equals(fragment, that.fragment)) {
return false;
}
if (!unionFragments.equals(that.unionFragments)) {
return false;
}
if (!statementMatchers.equals(that.statementMatchers)) {
return false;
}
return Objects.equals(traceBackFunction, that.traceBackFunction);
}
@Override
public int hashCode() {
int result = namespaces.hashCode();
result = 31 * result + (fragment != null ? fragment.hashCode() : 0);
result = 31 * result + unionFragments.hashCode();
result = 31 * result + statementMatchers.hashCode();
result = 31 * result + (traceBackFunction != null ? traceBackFunction.hashCode() : 0);
result = 31 * result + (filterCondition ? 1 : 0);
result = 31 * result + (bgp ? 1 : 0);
result = 31 * result + (union ? 1 : 0);
result = 31 * result + (supportsIncrementalEvaluation ? 1 : 0);
return result;
}
public interface TraceBack {
Stream<EffectiveTarget.SubjectObjectAndMatcher> getRoot(
ConnectionsGroup connectionsGroup,
Resource[] dataGraph,
Path path,
StatementMatcher currentStatementMatcher,
List<EffectiveTarget.SubjectObjectAndMatcher.SubjectObject> currentStatements);
}
}