ArrayBindingBasedQueryEvaluationContext.java

/*******************************************************************************
 * Copyright (c) 2021 Eclipse RDF4J contributors.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Distribution License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/org/documents/edl-v10.php.
 *
 * SPDX-License-Identifier: BSD-3-Clause
 *******************************************************************************/
package org.eclipse.rdf4j.query.algebra.evaluation.impl;

import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;

import org.eclipse.rdf4j.common.annotation.InternalUseOnly;
import org.eclipse.rdf4j.model.Literal;
import org.eclipse.rdf4j.model.Value;
import org.eclipse.rdf4j.query.Binding;
import org.eclipse.rdf4j.query.BindingSet;
import org.eclipse.rdf4j.query.Dataset;
import org.eclipse.rdf4j.query.MutableBindingSet;
import org.eclipse.rdf4j.query.QueryEvaluationException;
import org.eclipse.rdf4j.query.algebra.BindingSetAssignment;
import org.eclipse.rdf4j.query.algebra.ExtensionElem;
import org.eclipse.rdf4j.query.algebra.Group;
import org.eclipse.rdf4j.query.algebra.GroupElem;
import org.eclipse.rdf4j.query.algebra.MultiProjection;
import org.eclipse.rdf4j.query.algebra.Projection;
import org.eclipse.rdf4j.query.algebra.ProjectionElem;
import org.eclipse.rdf4j.query.algebra.QueryModelNode;
import org.eclipse.rdf4j.query.algebra.QueryRoot;
import org.eclipse.rdf4j.query.algebra.StatementPattern;
import org.eclipse.rdf4j.query.algebra.UnaryTupleOperator;
import org.eclipse.rdf4j.query.algebra.Var;
import org.eclipse.rdf4j.query.algebra.ZeroLengthPath;
import org.eclipse.rdf4j.query.algebra.evaluation.ArrayBindingSet;
import org.eclipse.rdf4j.query.algebra.evaluation.iterator.ZeroLengthPathIteration;
import org.eclipse.rdf4j.query.algebra.helpers.AbstractSimpleQueryModelVisitor;
import org.eclipse.rdf4j.query.impl.EmptyBindingSet;

/**
 * A {@link QueryEvaluationContext} specialised for a fixed set of variable names that is known when it is created.
 * <p>
 * For every name in {@code allVariables} the constructor precomputes the accessors later returned by
 * {@link #hasBinding(String)}, {@link #getBinding(String)}, {@link #getValue(String)}, {@link #setBinding(String)} and
 * {@link #addBinding(String)}. When such an accessor is applied to an {@link ArrayBindingSet} it uses a direct accessor
 * obtained from a template {@link ArrayBindingSet} over {@code allVariables}. Any other {@link BindingSet} is handled
 * through the generic name-based {@link BindingSet} / {@link MutableBindingSet} methods. Names that are not in
 * {@code allVariables} get the shared no-op accessors declared below. This means, for example, that setting such a
 * name is silently ignored.
 * <p>
 * NOTE(review): the direct accessors appear to assume that every {@link ArrayBindingSet} passed in has the same
 * variable layout as this context (e.g. it was produced by {@link #createBindingSet()}). Confirm this before mixing
 * binding sets created by different contexts.
 */
public final class ArrayBindingBasedQueryEvaluationContext implements QueryEvaluationContext {

	/** Accessor for variables unknown to this context: never reports a binding. */
	public static final Predicate<BindingSet> HAS_BINDING_FALSE = (bs) -> false;
	/** Accessor for variables unknown to this context: always returns {@code null}. */
	public static final Function<BindingSet, Binding> GET_BINDING_NULL = (bs) -> null;
	/** Accessor for variables unknown to this context: always returns {@code null}. */
	public static final Function<BindingSet, Value> GET_VALUE_NULL = (bs) -> null;
	/** Accessor for variables unknown to this context: ignores the value and leaves the binding set untouched. */
	public static final BiConsumer<Value, MutableBindingSet> SET_BINDING_NO_OP = (val, bs) -> {
	};
	/** The same no-op instance as {@link #SET_BINDING_NO_OP}, returned by {@link #addBinding(String)}. */
	public static final BiConsumer<Value, MutableBindingSet> ADD_BINDING_NO_OP = SET_BINDING_NO_OP;

	/** Enclosing context; only consulted by {@link #getNow()} and {@link #getDataset()}. */
	private final QueryEvaluationContext context;
	/** Every (distinct) variable name this context can bind; stored as passed in, not copied. */
	private final String[] allVariables;
	/** The same names as {@link #allVariables}, as an immutable set. */
	private final Set<String> allVariablesSet;
	/** Template binding set over {@link #allVariables}, used only to obtain direct accessors. */
	private final ArrayBindingSet defaultArrayBindingSet;

	// Precomputed accessors, indexed exactly like allVariables.
	private final Predicate<BindingSet>[] hasBinding;
	private final Function<BindingSet, Binding>[] getBinding;
	private final Function<BindingSet, Value>[] getValue;
	private final BiConsumer<Value, MutableBindingSet>[] setBinding;
	private final BiConsumer<Value, MutableBindingSet>[] addBinding;

	private final Comparator<Value> comparator;

	/**
	 * {@code false} while the constructor is still filling the accessor arrays. In that state the accessor factories
	 * build new accessors instead of looking them up.
	 */
	private final boolean initialized;

	/**
	 * Creates a context for a fixed set of variable names.
	 *
	 * @param context      the enclosing context that {@link #getNow()} and {@link #getDataset()} delegate to
	 * @param allVariables every variable name this context can bind; names must be distinct. The array is stored
	 *                     as-is, not copied.
	 * @param comparator   the value comparator returned by {@link #getComparator()}
	 */
	@InternalUseOnly
	@SuppressWarnings("unchecked") // generic arrays cannot be created directly; every slot is filled with a matching
									// accessor below
	public ArrayBindingBasedQueryEvaluationContext(QueryEvaluationContext context, String[] allVariables,
			Comparator<Value> comparator) {
		assert new HashSet<>(Arrays.asList(allVariables)).size() == allVariables.length;
		this.context = context;
		this.allVariables = allVariables;
		// Set.of also rejects duplicate names (IllegalArgumentException) and null names (NullPointerException),
		// even when assertions are disabled.
		this.allVariablesSet = Set.of(allVariables);
		this.defaultArrayBindingSet = new ArrayBindingSet(allVariables);
		this.comparator = comparator;

		int size = allVariables.length;
		hasBinding = new Predicate[size];
		getBinding = new Function[size];
		getValue = new Function[size];
		setBinding = new BiConsumer[size];
		addBinding = new BiConsumer[size];

		// 'initialized' is still false here, so each factory call builds a fresh accessor from
		// defaultArrayBindingSet rather than reading the arrays that are being filled.
		for (int i = 0; i < size; i++) {
			String name = allVariables[i];
			hasBinding[i] = hasBinding(name);
			getBinding[i] = getBinding(name);
			getValue[i] = getValue(name);
			setBinding[i] = setBinding(name);
			addBinding[i] = addBinding(name);
		}

		initialized = true;
	}

	/**
	 * Returns the position of {@code variableName} in {@link #allVariables}, or {@code -1} if it is not one of them.
	 * <p>
	 * Reference equality is tried first as a cheap fast path, and {@link String#equals(Object)} is only used if that
	 * fails. Presumably callers pass names that were canonicalised by
	 * {@link #findAllVariablesUsedInQuery(QueryRoot)}, so the first loop normally hits.
	 */
	private int indexOf(String variableName) {
		for (int i = 0; i < allVariables.length; i++) {
			if (allVariables[i] == variableName) {
				return i;
			}
		}
		for (int i = 0; i < allVariables.length; i++) {
			if (allVariables[i].equals(variableName)) {
				return i;
			}
		}
		return -1;
	}

	/** Returns the comparator supplied to the constructor. */
	@Override
	public Comparator<Value> getComparator() {
		return comparator;
	}

	/** Delegates to the enclosing context. */
	@Override
	public Literal getNow() {
		return context.getNow();
	}

	/** Delegates to the enclosing context. */
	@Override
	public Dataset getDataset() {
		return context.getDataset();
	}

	/** Creates a fresh {@link ArrayBindingSet} over this context's variable names. */
	@Override
	public ArrayBindingSet createBindingSet() {
		return new ArrayBindingSet(allVariables);
	}

	/**
	 * Returns a predicate that tests whether a binding set has a binding for {@code variableName}.
	 * <p>
	 * Returns {@link #HAS_BINDING_FALSE} for names outside this context's variable set.
	 */
	@Override
	public Predicate<BindingSet> hasBinding(String variableName) {
		if (initialized) {
			int index = indexOf(variableName);
			return index >= 0 ? hasBinding[index] : HAS_BINDING_FALSE;
		}

		assert variableName != null && !variableName.isEmpty();

		Function<ArrayBindingSet, Boolean> directHasVariable = defaultArrayBindingSet.getDirectHasBinding(variableName);

		if (directHasVariable != null) {
			return new HasBinding(variableName, directHasVariable);
		} else {
			// If the variable is not in the default set, it can never be part of this array binding
			return HAS_BINDING_FALSE;
		}
	}

	/**
	 * {@code hasBinding} accessor: uses the direct accessor for {@link ArrayBindingSet}s and
	 * {@link BindingSet#hasBinding(String)} otherwise. An empty binding set is always answered with {@code false}.
	 */
	private static class HasBinding implements Predicate<BindingSet> {

		private final String variableName;
		private final Function<ArrayBindingSet, Boolean> directHasVariable;

		public HasBinding(String variableName, Function<ArrayBindingSet, Boolean> directHasVariable) {
			this.variableName = variableName;
			this.directHasVariable = directHasVariable;
		}

		@Override
		public boolean test(BindingSet bs) {
			if (bs.isEmpty()) {
				return false;
			}
			if (bs instanceof ArrayBindingSet) {
				return directHasVariable.apply((ArrayBindingSet) bs);
			} else {
				return bs.hasBinding(variableName);
			}
		}
	}

	/**
	 * Returns a function that reads the {@link Binding} for {@code variableName} from a binding set, or
	 * {@code null} when there is none.
	 * <p>
	 * Returns {@link #GET_BINDING_NULL} for names outside this context's variable set.
	 */
	@Override
	public Function<BindingSet, Binding> getBinding(String variableName) {
		if (initialized) {
			int index = indexOf(variableName);
			return index >= 0 ? getBinding[index] : GET_BINDING_NULL;
		}

		Function<ArrayBindingSet, Binding> directAccessForVariable = defaultArrayBindingSet
				.getDirectGetBinding(variableName);

		if (directAccessForVariable != null) {
			return (bs) -> {
				if (bs.isEmpty()) {
					return null;
				} else if (bs instanceof ArrayBindingSet) {
					return directAccessForVariable.apply((ArrayBindingSet) bs);
				} else {
					return bs.getBinding(variableName);
				}
			};
		} else {
			return GET_BINDING_NULL;
		}
	}

	/**
	 * Returns a function that reads the {@link Value} bound to {@code variableName} in a binding set, or
	 * {@code null} when there is none.
	 * <p>
	 * Returns {@link #GET_VALUE_NULL} for names outside this context's variable set.
	 */
	@Override
	public Function<BindingSet, Value> getValue(String variableName) {
		if (initialized) {
			int index = indexOf(variableName);
			return index >= 0 ? getValue[index] : GET_VALUE_NULL;
		}

		Function<ArrayBindingSet, Value> directAccessForVariable = defaultArrayBindingSet
				.getDirectGetValue(variableName);

		if (directAccessForVariable != null) {
			return new ValueGetter(variableName, directAccessForVariable);
		} else {
			// If the variable is not in the default set, it can never be part of this array binding
			return GET_VALUE_NULL;
		}
	}

	/**
	 * {@code getValue} accessor: uses the direct accessor for {@link ArrayBindingSet}s and
	 * {@link BindingSet#getValue(String)} otherwise. An empty binding set always yields {@code null}.
	 */
	private static class ValueGetter implements Function<BindingSet, Value> {

		private final String variableName;
		private final Function<ArrayBindingSet, Value> directAccessForVariable;

		public ValueGetter(String variableName, Function<ArrayBindingSet, Value> directAccessForVariable) {
			this.variableName = variableName;
			this.directAccessForVariable = directAccessForVariable;
		}

		@Override
		public Value apply(BindingSet bs) {
			if (bs.isEmpty()) {
				return null;
			}
			if (bs instanceof ArrayBindingSet) {
				return directAccessForVariable.apply((ArrayBindingSet) bs);
			} else {
				return bs.getValue(variableName);
			}
		}
	}

	/**
	 * Returns a consumer that sets {@code variableName} to a value in a mutable binding set.
	 * <p>
	 * Returns {@link #SET_BINDING_NO_OP} for names outside this context's variable set, so such values are silently
	 * dropped, even for binding sets that are not {@link ArrayBindingSet}s.
	 */
	@Override
	public BiConsumer<Value, MutableBindingSet> setBinding(String variableName) {
		if (initialized) {
			int index = indexOf(variableName);
			return index >= 0 ? setBinding[index] : SET_BINDING_NO_OP;
		}

		BiConsumer<Value, ArrayBindingSet> directAccessForVariable = defaultArrayBindingSet
				.getDirectSetBinding(variableName);
		if (directAccessForVariable != null) {
			return (val, bs) -> {
				if (bs instanceof ArrayBindingSet) {
					directAccessForVariable.accept(val, (ArrayBindingSet) bs);
				} else {
					bs.setBinding(variableName, val);
				}
			};
		} else {
			return SET_BINDING_NO_OP;
		}
	}

	/**
	 * Returns a consumer that adds a binding for {@code variableName} to a mutable binding set.
	 * <p>
	 * Returns {@link #ADD_BINDING_NO_OP} for names outside this context's variable set, so such values are silently
	 * dropped, even for binding sets that are not {@link ArrayBindingSet}s.
	 */
	@Override
	public BiConsumer<Value, MutableBindingSet> addBinding(String variableName) {
		if (initialized) {
			int index = indexOf(variableName);
			return index >= 0 ? addBinding[index] : ADD_BINDING_NO_OP;
		}

		BiConsumer<Value, ArrayBindingSet> wrapped = defaultArrayBindingSet.getDirectAddBinding(variableName);
		if (wrapped != null) {
			return (val, bs) -> {
				if (bs instanceof ArrayBindingSet) {
					wrapped.accept(val, (ArrayBindingSet) bs);
				} else {
					bs.addBinding(variableName, val);
				}
			};
		} else {
			return ADD_BINDING_NO_OP;
		}
	}

	/**
	 * Creates a new {@link ArrayBindingSet} over this context's variable names, initialised from {@code bindings}.
	 * <ul>
	 * <li>an {@link ArrayBindingSet} is passed to the copying constructor;</li>
	 * <li>the {@link EmptyBindingSet} singleton (identity comparison) yields {@link #createBindingSet()};</li>
	 * <li>any other binding set is converted through the generic constructor.</li>
	 * </ul>
	 */
	@Override
	public ArrayBindingSet createBindingSet(BindingSet bindings) {
		if (bindings instanceof ArrayBindingSet) {
			return new ArrayBindingSet((ArrayBindingSet) bindings, allVariables);
		} else if (bindings == EmptyBindingSet.getInstance()) {
			return createBindingSet();
		} else {
			return new ArrayBindingSet(bindings, allVariablesSet, allVariables);
		}
	}

	/**
	 * Collects the variable names used in a query tree and canonicalises them in place.
	 * <p>
	 * Every name is registered in an insertion-ordered map from name to the first String instance seen for it. The
	 * visited {@link Var}, {@link ProjectionElem}, {@link ExtensionElem}, {@link GroupElem},
	 * {@link BindingSetAssignment} and {@link Group} nodes are then rewritten to use that shared instance, which lets
	 * the reference-equality check in this class's accessor lookups match. <strong>This method therefore mutates
	 * {@code node}.</strong>
	 * <p>
	 * The following names are also registered:
	 * <ul>
	 * <li>the binding names of every {@link MultiProjection};</li>
	 * <li>for every {@link ZeroLengthPath}, the anonymous variable names used by {@link ZeroLengthPathIteration}.</li>
	 * </ul>
	 * Constant {@link Var}s whose parent is a {@link StatementPattern} are neither registered nor replaced.
	 *
	 * @param node the root of the query model to scan (modified in place)
	 * @return the distinct variable names, in the order they were first registered during the traversal
	 */
	public static String[] findAllVariablesUsedInQuery(QueryRoot node) {
		HashMap<String, String> varNames = new LinkedHashMap<>();
		AbstractSimpleQueryModelVisitor<QueryEvaluationException> queryModelVisitorBase = new AbstractSimpleQueryModelVisitor<>(
				true) {

			/**
			 * Registers {@code name} if new and returns the canonical instance for it. A {@code null} name is returned
			 * as {@code null} and is not recorded, because {@code HashMap.computeIfAbsent} records no mapping when the
			 * mapping function yields {@code null}.
			 */
			private String canonical(String name) {
				return varNames.computeIfAbsent(name, k -> k);
			}

			@Override
			public void meet(Var node) throws QueryEvaluationException {
				super.meet(node);
				// We can skip constants that are only used in StatementPatterns since these are never added to the
				// BindingSet anyway
				if (!(node.isConstant() && node.getParentNode() instanceof StatementPattern)) {
					Var replacement = new Var(canonical(node.getName()), node.getValue(), node.isAnonymous(),
							node.isConstant());
					node.replaceWith(replacement);
				}
			}

			@Override
			public void meet(ProjectionElem node) throws QueryEvaluationException {
				super.meet(node);
				node.setName(canonical(node.getName()));
				// An absent alias maps to null and stays null (see canonical).
				node.setProjectionAlias(canonical(node.getProjectionAlias().orElse(null)));
			}

			@Override
			protected void meetUnaryTupleOperator(UnaryTupleOperator node) throws QueryEvaluationException {
				if (node instanceof Projection) {
					// Visit the projected sub-tree before the projection element list instead of using
					// visitChildren. Presumably this is so that names used below are registered first.
					node.getArg().visit(this);
					((Projection) node).getProjectionElemList().visit(this);
				} else {
					node.visitChildren(this);
				}
			}

			@Override
			public void meet(MultiProjection node) throws QueryEvaluationException {
				for (String bindingName : node.getBindingNames()) {
					canonical(bindingName);
				}
				super.meet(node);
			}

			@Override
			public void meet(ZeroLengthPath node) throws QueryEvaluationException {
				canonical(ZeroLengthPathIteration.ANON_SUBJECT_VAR);
				canonical(ZeroLengthPathIteration.ANON_PREDICATE_VAR);
				canonical(ZeroLengthPathIteration.ANON_OBJECT_VAR);
				canonical(ZeroLengthPathIteration.ANON_SEQUENCE_VAR);
				super.meet(node);
			}

			@Override
			public void meet(ExtensionElem node) throws QueryEvaluationException {
				node.setName(canonical(node.getName()));
				super.meet(node);
			}

			@Override
			public void meet(GroupElem node) throws QueryEvaluationException {
				node.setName(canonical(node.getName()));
				super.meet(node);
			}

			@Override
			public void meet(BindingSetAssignment node) throws QueryEvaluationException {
				Set<String> canonicalNames = node.getBindingNames()
						.stream()
						.map(this::canonical)
						.collect(Collectors.toSet());
				node.setBindingNames(canonicalNames);
				super.meet(node);
			}

			@Override
			public void meet(Group node) throws QueryEvaluationException {
				List<String> canonicalNames = node.getGroupBindingNames()
						.stream()
						.map(this::canonical)
						.collect(Collectors.toList());
				node.setGroupBindingNames(canonicalNames);
				super.meet(node);
			}

		};
		node.visit(queryModelVisitorBase);
		return varNames.keySet().toArray(new String[0]);
	}
}