// ArrayBindingBasedQueryEvaluationContext.java
/*******************************************************************************
* Copyright (c) 2021 Eclipse RDF4J contributors.
*
* All rights reserved. This program and the accompanying materials
* are made available under the terms of the Eclipse Distribution License v1.0
* which accompanies this distribution, and is available at
* http://www.eclipse.org/org/documents/edl-v10.php.
*
* SPDX-License-Identifier: BSD-3-Clause
*******************************************************************************/
package org.eclipse.rdf4j.query.algebra.evaluation.impl;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.eclipse.rdf4j.common.annotation.InternalUseOnly;
import org.eclipse.rdf4j.model.Literal;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.query.Binding;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.Dataset;
import org.eclipse.rdf4j.query.MutableBindingSet;
import org.eclipse.rdf4j.query.QueryEvaluationException;
import org.eclipse.rdf4j.query.algebra.BindingSetAssignment;
import org.eclipse.rdf4j.query.algebra.ExtensionElem;
import org.eclipse.rdf4j.query.algebra.Group;
import org.eclipse.rdf4j.query.algebra.GroupElem;
import org.eclipse.rdf4j.query.algebra.MultiProjection;
import org.eclipse.rdf4j.query.algebra.Projection;
import org.eclipse.rdf4j.query.algebra.ProjectionElem;
import org.eclipse.rdf4j.query.algebra.QueryModelNode;
import org.eclipse.rdf4j.query.algebra.QueryRoot;
import org.eclipse.rdf4j.query.algebra.StatementPattern;
import org.eclipse.rdf4j.query.algebra.UnaryTupleOperator;
import org.eclipse.rdf4j.query.algebra.Var;
import org.eclipse.rdf4j.query.algebra.ZeroLengthPath;
import org.eclipse.rdf4j.query.algebra.evaluation.ArrayBindingSet;
import org.eclipse.rdf4j.query.algebra.evaluation.iterator.ZeroLengthPathIteration;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractSimpleQueryModelVisitor;
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;
/**
 * A {@link QueryEvaluationContext} that hands out binding accessors specialised for {@link ArrayBindingSet}.
 * <p>
 * {@link #getNow()} and {@link #getDataset()} are delegated to a wrapped context. For every variable name passed to
 * the constructor, one accessor of each kind (has/get binding, get value, set/add binding) is created up front and
 * cached; later requests for the same variable return the cached instance. Each accessor uses the direct
 * {@link ArrayBindingSet} access path when it is handed an {@link ArrayBindingSet} and falls back to the generic
 * {@link BindingSet}/{@link MutableBindingSet} API otherwise. Variables that were not passed to the constructor get
 * the shared constant accessors ({@link #HAS_BINDING_FALSE}, {@link #GET_BINDING_NULL}, ...).
 */
public final class ArrayBindingBasedQueryEvaluationContext implements QueryEvaluationContext {

	/** Accessor returned for unknown variables: always reports the variable as unbound. */
	public static final Predicate<BindingSet> HAS_BINDING_FALSE = (bs) -> false;

	/** Accessor returned for unknown variables: always yields a {@code null} binding. */
	public static final Function<BindingSet, Binding> GET_BINDING_NULL = (bs) -> null;

	/** Accessor returned for unknown variables: always yields a {@code null} value. */
	public static final Function<BindingSet, Value> GET_VALUE_NULL = (bs) -> null;

	/** Accessor returned for unknown variables: ignores the value and leaves the binding set untouched. */
	public static final BiConsumer<Value, MutableBindingSet> SET_BINDING_NO_OP = (val, bs) -> {
	};

	/** Same instance as {@link #SET_BINDING_NO_OP}; kept as a separate name for the add-binding case. */
	public static final BiConsumer<Value, MutableBindingSet> ADD_BINDING_NO_OP = SET_BINDING_NO_OP;

	private final QueryEvaluationContext context;
	private final String[] allVariables;
	private final Set<String> allVariablesSet;

	/** Template binding set that is only used to obtain the direct per-variable accessors. */
	private final ArrayBindingSet defaultArrayBindingSet;

	// Accessor caches; index i holds the accessor for allVariables[i].
	private final Predicate<BindingSet>[] hasBinding;
	private final Function<BindingSet, Binding>[] getBinding;
	private final Function<BindingSet, Value>[] getValue;
	private final BiConsumer<Value, MutableBindingSet>[] setBinding;
	private final BiConsumer<Value, MutableBindingSet>[] addBinding;

	private final Comparator<Value> comparator;

	/**
	 * Still holds its default value {@code false} while the constructor is filling the accessor caches, which makes
	 * the accessor factory methods build fresh accessors. It is assigned {@code true} as the very last constructor
	 * statement; from then on the factory methods answer from the caches.
	 */
	private final boolean initialized;

	/**
	 * Creates a context for a fixed set of variable names.
	 *
	 * @param context      the context that {@link #getNow()} and {@link #getDataset()} are delegated to
	 * @param allVariables the variable names to create cached accessors for; must not contain duplicates (asserted
	 *                     here, and {@link Set#of(Object[])} additionally rejects duplicates and {@code null}
	 *                     elements)
	 * @param comparator   the comparator returned by {@link #getComparator()}
	 */
	@InternalUseOnly
	// Generic arrays cannot be created directly; the raw arrays below are only ever filled with correctly typed
	// accessors by the loop in this constructor.
	@SuppressWarnings("unchecked")
	public ArrayBindingBasedQueryEvaluationContext(QueryEvaluationContext context, String[] allVariables,
			Comparator<Value> comparator) {
		assert new HashSet<>(Arrays.asList(allVariables)).size() == allVariables.length;

		this.context = context;
		this.allVariables = allVariables;
		this.allVariablesSet = Set.of(allVariables);
		this.defaultArrayBindingSet = new ArrayBindingSet(allVariables);
		this.comparator = comparator;

		int count = allVariables.length;
		hasBinding = new Predicate[count];
		getBinding = new Function[count];
		getValue = new Function[count];
		setBinding = new BiConsumer[count];
		addBinding = new BiConsumer[count];

		// 'initialized' is still false here, so each call below constructs a new accessor.
		for (int i = 0; i < count; i++) {
			String variableName = allVariables[i];
			hasBinding[i] = hasBinding(variableName);
			getBinding[i] = getBinding(variableName);
			getValue[i] = getValue(variableName);
			setBinding[i] = setBinding(variableName);
			addBinding[i] = addBinding(variableName);
		}

		initialized = true;
	}

	/**
	 * Finds the position of a variable name in {@link #allVariables}.
	 * <p>
	 * A first pass compares by reference, which is cheap and succeeds whenever the caller passes the very same
	 * {@code String} instance that is stored in the array. Only if that pass finds nothing is a second pass with
	 * {@link String#equals(Object)} performed.
	 *
	 * @param variableName the name to look for
	 * @return the index of the variable, or {@code -1} if it is not one of the known variables
	 */
	private int indexOf(String variableName) {
		for (int i = 0; i < allVariables.length; i++) {
			// Intentional reference comparison: fast path for identical String instances.
			if (allVariables[i] == variableName) {
				return i;
			}
		}
		for (int i = 0; i < allVariables.length; i++) {
			if (allVariables[i].equals(variableName)) {
				return i;
			}
		}
		return -1;
	}

	/**
	 * Returns the cached accessor for a variable, or the given fallback if the variable is unknown.
	 *
	 * @param accessors    one of the accessor caches, indexed like {@link #allVariables}
	 * @param variableName the variable whose accessor is requested
	 * @param fallback     the accessor to return for unknown variables
	 * @return the cached accessor or {@code fallback}
	 */
	private <T> T cached(T[] accessors, String variableName, T fallback) {
		int index = indexOf(variableName);
		return index < 0 ? fallback : accessors[index];
	}

	/** @return the comparator supplied to the constructor */
	@Override
	public Comparator<Value> getComparator() {
		return comparator;
	}

	/** @return the result of {@code getNow()} on the wrapped context */
	@Override
	public Literal getNow() {
		return context.getNow();
	}

	/** @return the result of {@code getDataset()} on the wrapped context */
	@Override
	public Dataset getDataset() {
		return context.getDataset();
	}

	/** @return a new {@link ArrayBindingSet} created for all variables known to this context */
	@Override
	public ArrayBindingSet createBindingSet() {
		return new ArrayBindingSet(allVariables);
	}

	/**
	 * Returns a predicate that tests whether the given variable is bound in a binding set.
	 *
	 * @param variableName the variable to test for
	 * @return the cached predicate for a known variable, otherwise {@link #HAS_BINDING_FALSE}
	 */
	@Override
	public Predicate<BindingSet> hasBinding(String variableName) {
		if (initialized) {
			return cached(hasBinding, variableName, HAS_BINDING_FALSE);
		}

		// Constructor phase: build the accessor that will be cached.
		assert variableName != null && !variableName.isEmpty();
		Function<ArrayBindingSet, Boolean> directHasVariable = defaultArrayBindingSet.getDirectHasBinding(variableName);
		if (directHasVariable == null) {
			// If the variable is not in the default set, it can never be part of this array binding
			return HAS_BINDING_FALSE;
		}
		return new HasBinding(variableName, directHasVariable);
	}

	/**
	 * Predicate that checks a single variable, using the direct accessor for {@link ArrayBindingSet}s and
	 * {@link BindingSet#hasBinding(String)} for every other implementation.
	 */
	private static final class HasBinding implements Predicate<BindingSet> {

		private final String variableName;
		private final Function<ArrayBindingSet, Boolean> directHasVariable;

		public HasBinding(String variableName, Function<ArrayBindingSet, Boolean> directHasVariable) {
			this.variableName = variableName;
			this.directHasVariable = directHasVariable;
		}

		@Override
		public boolean test(BindingSet bs) {
			// An empty binding set cannot contain the variable; skip the type check entirely.
			if (bs.isEmpty()) {
				return false;
			}
			if (bs instanceof ArrayBindingSet) {
				return directHasVariable.apply((ArrayBindingSet) bs);
			}
			return bs.hasBinding(variableName);
		}
	}

	/**
	 * Returns a function that retrieves the {@link Binding} of the given variable from a binding set.
	 *
	 * @param variableName the variable to retrieve
	 * @return the cached function for a known variable, otherwise {@link #GET_BINDING_NULL}
	 */
	@Override
	public Function<BindingSet, Binding> getBinding(String variableName) {
		if (initialized) {
			return cached(getBinding, variableName, GET_BINDING_NULL);
		}

		// Constructor phase: build the accessor that will be cached.
		Function<ArrayBindingSet, Binding> directAccessForVariable = defaultArrayBindingSet
				.getDirectGetBinding(variableName);
		if (directAccessForVariable == null) {
			return GET_BINDING_NULL;
		}
		return (bs) -> {
			if (bs.isEmpty()) {
				return null;
			} else if (bs instanceof ArrayBindingSet) {
				return directAccessForVariable.apply((ArrayBindingSet) bs);
			} else {
				return bs.getBinding(variableName);
			}
		};
	}

	/**
	 * Returns a function that retrieves the {@link Value} of the given variable from a binding set.
	 *
	 * @param variableName the variable to retrieve
	 * @return the cached function for a known variable, otherwise {@link #GET_VALUE_NULL}
	 */
	@Override
	public Function<BindingSet, Value> getValue(String variableName) {
		if (initialized) {
			return cached(getValue, variableName, GET_VALUE_NULL);
		}

		// Constructor phase: build the accessor that will be cached.
		Function<ArrayBindingSet, Value> directAccessForVariable = defaultArrayBindingSet
				.getDirectGetValue(variableName);
		if (directAccessForVariable == null) {
			// If the variable is not in the default set, it can never be part of this array binding
			return GET_VALUE_NULL;
		}
		return new ValueGetter(variableName, directAccessForVariable);
	}

	/**
	 * Function that reads a single variable's value, using the direct accessor for {@link ArrayBindingSet}s and
	 * {@link BindingSet#getValue(String)} for every other implementation.
	 */
	private static final class ValueGetter implements Function<BindingSet, Value> {

		private final String variableName;
		private final Function<ArrayBindingSet, Value> directAccessForVariable;

		public ValueGetter(String variableName, Function<ArrayBindingSet, Value> directAccessForVariable) {
			this.variableName = variableName;
			this.directAccessForVariable = directAccessForVariable;
		}

		@Override
		public Value apply(BindingSet bs) {
			// An empty binding set has no value for any variable.
			if (bs.isEmpty()) {
				return null;
			}
			if (bs instanceof ArrayBindingSet) {
				return directAccessForVariable.apply((ArrayBindingSet) bs);
			}
			return bs.getValue(variableName);
		}
	}

	/**
	 * Returns a consumer that sets the given variable in a mutable binding set.
	 *
	 * @param variableName the variable to set
	 * @return the cached consumer for a known variable, otherwise {@link #SET_BINDING_NO_OP}
	 */
	@Override
	public BiConsumer<Value, MutableBindingSet> setBinding(String variableName) {
		if (initialized) {
			return cached(setBinding, variableName, SET_BINDING_NO_OP);
		}

		// Constructor phase: build the accessor that will be cached.
		BiConsumer<Value, ArrayBindingSet> directAccessForVariable = defaultArrayBindingSet
				.getDirectSetBinding(variableName);
		if (directAccessForVariable == null) {
			return SET_BINDING_NO_OP;
		}
		return (val, bs) -> {
			if (bs instanceof ArrayBindingSet) {
				directAccessForVariable.accept(val, (ArrayBindingSet) bs);
			} else {
				bs.setBinding(variableName, val);
			}
		};
	}

	/**
	 * Returns a consumer that adds a binding for the given variable to a mutable binding set.
	 *
	 * @param variableName the variable to add
	 * @return the cached consumer for a known variable, otherwise {@link #ADD_BINDING_NO_OP}
	 */
	@Override
	public BiConsumer<Value, MutableBindingSet> addBinding(String variableName) {
		if (initialized) {
			return cached(addBinding, variableName, ADD_BINDING_NO_OP);
		}

		// Constructor phase: build the accessor that will be cached.
		BiConsumer<Value, ArrayBindingSet> wrapped = defaultArrayBindingSet.getDirectAddBinding(variableName);
		if (wrapped == null) {
			return ADD_BINDING_NO_OP;
		}
		return (val, bs) -> {
			if (bs instanceof ArrayBindingSet) {
				wrapped.accept(val, (ArrayBindingSet) bs);
			} else {
				bs.addBinding(variableName, val);
			}
		};
	}

	/**
	 * Creates a new {@link ArrayBindingSet} for this context's variables from an existing binding set.
	 *
	 * @param bindings the binding set to start from
	 * @return a new {@link ArrayBindingSet}; built via the {@link ArrayBindingSet}-specific constructor when
	 *         {@code bindings} is itself an {@link ArrayBindingSet}, via {@link #createBindingSet()} when it is the
	 *         {@link EmptyBindingSet} singleton, and via the generic {@link BindingSet} constructor otherwise
	 */
	@Override
	public ArrayBindingSet createBindingSet(BindingSet bindings) {
		if (bindings instanceof ArrayBindingSet) {
			return new ArrayBindingSet((ArrayBindingSet) bindings, allVariables);
		}
		if (bindings == EmptyBindingSet.getInstance()) {
			return createBindingSet();
		}
		return new ArrayBindingSet(bindings, allVariablesSet, allVariables);
	}

	/**
	 * Collects the names of the variables used in a query, in the order in which they are first encountered.
	 * <p>
	 * <strong>This method modifies the query model.</strong> While visiting, every name it records is written back
	 * into the visited node (variables are replaced, names of projection/extension/group elements and binding name
	 * collections are re-set) using the single {@code String} instance stored for that name. Afterwards, nodes that
	 * use equal names share the same {@code String} instance as the corresponding element of the returned array.
	 *
	 * @param node the root of the query to inspect and rewrite
	 * @return the distinct variable names found, in first-encounter order
	 */
	public static String[] findAllVariablesUsedInQuery(QueryRoot node) {
		// Maps each name to its canonical instance; LinkedHashMap keeps first-encounter order for the result.
		HashMap<String, String> varNames = new LinkedHashMap<>();

		AbstractSimpleQueryModelVisitor<QueryEvaluationException> visitor = new AbstractSimpleQueryModelVisitor<>(
				true) {

			/**
			 * Registers the name if it is new and returns the canonical instance stored for it. For a {@code null}
			 * name nothing is stored and {@code null} is returned, because computeIfAbsent records no mapping when
			 * the mapping function yields {@code null}.
			 */
			private String canonical(String name) {
				return varNames.computeIfAbsent(name, Function.identity());
			}

			@Override
			public void meet(Var node) throws QueryEvaluationException {
				super.meet(node);
				// We can skip constants that are only used in StatementPatterns since these are never added to the
				// BindingSet anyway
				boolean constantInStatementPattern = node.isConstant()
						&& node.getParentNode() instanceof StatementPattern;
				if (!constantInStatementPattern) {
					Var replacement = new Var(canonical(node.getName()), node.getValue(), node.isAnonymous(),
							node.isConstant());
					node.replaceWith(replacement);
				}
			}

			@Override
			public void meet(ProjectionElem node) throws QueryEvaluationException {
				super.meet(node);
				node.setName(canonical(node.getName()));
				// An absent alias is passed on as null: canonical(null) stores nothing and returns null.
				node.setProjectionAlias(canonical(node.getProjectionAlias().orElse(null)));
			}

			@Override
			protected void meetUnaryTupleOperator(UnaryTupleOperator node) throws QueryEvaluationException {
				if (node instanceof Projection) {
					// Visit the argument before the projection element list.
					node.getArg().visit(this);
					((Projection) node).getProjectionElemList().visit(this);
				} else {
					node.visitChildren(this);
				}
			}

			@Override
			public void meet(MultiProjection node) throws QueryEvaluationException {
				for (String bindingName : node.getBindingNames()) {
					canonical(bindingName);
				}
				super.meet(node);
			}

			@Override
			public void meet(ZeroLengthPath node) throws QueryEvaluationException {
				// Register the anonymous variable names declared by ZeroLengthPathIteration.
				canonical(ZeroLengthPathIteration.ANON_SUBJECT_VAR);
				canonical(ZeroLengthPathIteration.ANON_PREDICATE_VAR);
				canonical(ZeroLengthPathIteration.ANON_OBJECT_VAR);
				canonical(ZeroLengthPathIteration.ANON_SEQUENCE_VAR);
				super.meet(node);
			}

			@Override
			public void meet(ExtensionElem node) throws QueryEvaluationException {
				node.setName(canonical(node.getName()));
				super.meet(node);
			}

			@Override
			public void meet(GroupElem node) throws QueryEvaluationException {
				node.setName(canonical(node.getName()));
				super.meet(node);
			}

			@Override
			public void meet(BindingSetAssignment node) throws QueryEvaluationException {
				Set<String> canonicalNames = node.getBindingNames()
						.stream()
						.map(this::canonical)
						.collect(Collectors.toSet());
				node.setBindingNames(canonicalNames);
				super.meet(node);
			}

			@Override
			public void meet(Group node) throws QueryEvaluationException {
				List<String> canonicalNames = node.getGroupBindingNames()
						.stream()
						.map(this::canonical)
						.collect(Collectors.toList());
				node.setGroupBindingNames(canonicalNames);
				super.meet(node);
			}
		};

		node.visit(visitor);
		return varNames.keySet().toArray(new String[0]);
	}
}