// PlanRemoteProjections.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.ExternalCallExpressionChecker;
import com.facebook.presto.sql.planner.optimizations.SymbolMapper;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static com.facebook.presto.SystemSessionProperties.isRemoteFunctionsEnabled;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_USER_ERROR;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.LOCAL;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.REMOTE;
import static com.facebook.presto.spi.plan.ProjectNode.Locality.UNKNOWN;
import static com.facebook.presto.sql.planner.plan.Patterns.project;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
/**
 * Optimizer rule that decides the locality of a {@link ProjectNode}.
 * <p>
 * A project node whose locality is {@code UNKNOWN} is either re-created as {@code LOCAL}
 * (when none of its assignments contains an externally executed call) or replaced by a chain
 * of project nodes, each marked {@code LOCAL} or {@code REMOTE}, stacked on the original source.
 */
public class PlanRemoteProjections
implements Rule<ProjectNode>
{
// Pattern obtained from project(); apply() additionally skips nodes whose locality is already known.
private static final Pattern<ProjectNode> PATTERN = project();
// Handed to ExternalCallExpressionChecker and to the rewriting Visitor to inspect function metadata.
private final FunctionAndTypeManager functionAndTypeManager;
/**
 * Creates the rule.
 *
 * @param functionAndTypeManager manager used to look up function metadata; must not be null
 * @throws NullPointerException if {@code functionAndTypeManager} is null
 */
public PlanRemoteProjections(FunctionAndTypeManager functionAndTypeManager)
{
    // The message names the actual parameter so a failure points at the right argument.
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
}
/**
 * Returns the shared pattern this rule is matched against.
 */
@Override
public Pattern<ProjectNode> getPattern()
{
return PATTERN;
}
/**
 * Plans the locality of {@code node}.
 * <ul>
 * <li>Locality other than {@code UNKNOWN}: nothing to do, returns an empty result.</li>
 * <li>No assignment contains an external call: the node is re-created with {@code LOCAL} locality,
 * keeping its id, source and assignments.</li>
 * <li>Otherwise the assignments are split into stages by {@link #planRemoteAssignments} and one
 * project node per stage is stacked on top of the original source.</li>
 * </ul>
 *
 * @throws PrestoException with {@code GENERIC_USER_ERROR} when an external call is present but
 * remote functions are disabled for the session
 */
@Override
public Result apply(ProjectNode node, Captures captures, Rule.Context context)
{
    if (!node.getLocality().equals(UNKNOWN)) {
        // Locality was already decided
        return Result.empty();
    }

    Assignments assignments = node.getAssignments();

    // Cheap scan: does any assigned expression call an externally executed function?
    boolean containsExternalCall = assignments.getExpressions().stream()
            .anyMatch(expression -> expression.accept(new ExternalCallExpressionChecker(functionAndTypeManager), null));
    if (!containsExternalCall) {
        return Result.ofPlanNode(new ProjectNode(node.getSourceLocation(), node.getId(), node.getSource(), assignments, LOCAL));
    }

    if (!isRemoteFunctionsEnabled(context.getSession())) {
        throw new PrestoException(GENERIC_USER_ERROR, "Remote functions are not enabled");
    }

    List<ProjectionContext> stages = planRemoteAssignments(assignments, context.getVariableAllocator());
    checkState(!stages.isEmpty(), "Expect non-empty projectionContexts");

    // Stack one project node per stage, bottom-up, starting from the original source
    PlanNode current = node.getSource();
    for (ProjectionContext stage : stages) {
        current = new ProjectNode(
                node.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                current,
                Assignments.builder().putAll(stage.getProjections()).build(),
                stage.isRemote() ? REMOTE : LOCAL);
    }
    return Result.ofPlanNode(current);
}
/**
 * Splits a set of assignments into an ordered list of projection stages.
 * <p>
 * Every assignment is rewritten independently into its own chain of stages. The chain's final
 * stage is re-keyed so that it assigns the original output variable. The per-assignment chains
 * are then merged into common stages and duplicate variables are removed.
 *
 * @param assignments the assignments of the project node being planned
 * @param variableAllocator allocator for the intermediate variables introduced by the rewrite
 * @return the stages in the order they have to be stacked on the source
 */
@VisibleForTesting
public List<ProjectionContext> planRemoteAssignments(Assignments assignments, VariableAllocator variableAllocator)
{
    ImmutableList.Builder<List<ProjectionContext>> chains = ImmutableList.builder();
    for (Map.Entry<VariableReferenceExpression, RowExpression> assignment : assignments.getMap().entrySet()) {
        VariableReferenceExpression output = assignment.getKey();
        RowExpression expression = assignment.getValue();

        List<ProjectionContext> chain = expression.accept(new Visitor(functionAndTypeManager, variableAllocator), null);
        if (chain.isEmpty()) {
            // Nothing to split: keep the assignment as a single local projection
            chains.add(ImmutableList.of(new ProjectionContext(ImmutableMap.of(output, expression), false)));
            continue;
        }

        int lastIndex = chain.size() - 1;
        checkState(chain.get(lastIndex).getProjections().size() == 1, "Expect at most 1 assignment from last projection in rewrite");
        ProjectionContext tail = chain.get(lastIndex);

        // Keep the leading stages and make the final one assign the original output variable
        chains.add(ImmutableList.<ProjectionContext>builder()
                .addAll(chain.subList(0, lastIndex))
                .add(new ProjectionContext(ImmutableMap.of(output, getOnlyElement(tail.getProjections().values())), tail.isRemote()))
                .build());
    }
    return dedupVariables(mergeProjectionContexts(chains.build()));
}
/**
 * Removes redundant variables from a list of merged projection stages.
 * <p>
 * Stages are processed in order. Each stage's expressions are first rewritten with the variable
 * mapping produced while processing the previous stage (if that stage produced one). Then, within
 * the stage, variables that are assigned the same expression are collapsed into one, and variables
 * that are assigned a bare variable reference are mapped to that referenced variable. The resulting
 * mapping is applied to the next stage only.
 *
 * @param projectionContexts merged stages; the last element is read unconditionally, so the list
 * must not be empty
 * @return a list with one stage per input stage, in the same order and with the same remote flag
 */
private static List<ProjectionContext> dedupVariables(List<ProjectionContext> projectionContexts)
{
ImmutableList.Builder<ProjectionContext> dedupedProjectionContexts = ImmutableList.builder();
// Output variables of the final stage; these are the names preferred when collapsing duplicates below
Set<VariableReferenceExpression> originalVariable = projectionContexts.get(projectionContexts.size() - 1).getProjections().keySet();
// Mapping produced by the previous stage, or null when that stage needed no deduplication
SymbolMapper mapper = null;
for (int i = 0; i < projectionContexts.size(); i++) {
Map<VariableReferenceExpression, RowExpression> projections = projectionContexts.get(i).getProjections();
// Apply mapping from previous projection
if (mapper != null) {
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> newProjections = ImmutableMap.builder();
for (Map.Entry<VariableReferenceExpression, RowExpression> entry : projections.entrySet()) {
newProjections.put(entry.getKey(), mapper.map(entry.getValue()));
}
projections = newProjections.build();
}
// Dedup
// Invert the (mapped) projections: expression -> all variables assigned that expression
ImmutableMultimap.Builder<RowExpression, VariableReferenceExpression> reverseProjectionsBuilder = ImmutableMultimap.builder();
projections.forEach((key, value) -> reverseProjectionsBuilder.put(value, key));
ImmutableMultimap<RowExpression, VariableReferenceExpression> reverseProjections = reverseProjectionsBuilder.build();
// The stage is kept as-is only if every expression is distinct AND none is a bare variable reference
if (reverseProjections.keySet().size() == projectionContexts.get(i).getProjections().size() && reverseProjections.keySet().stream().noneMatch(VariableReferenceExpression.class::isInstance)) {
// No duplication
dedupedProjectionContexts.add(new ProjectionContext(projections, projectionContexts.get(i).isRemote()));
mapper = null;
}
else {
SymbolMapper.Builder mapperBuilder = SymbolMapper.builder();
ImmutableMap.Builder<VariableReferenceExpression, RowExpression> dedupedProjectionsBuilder = ImmutableMap.builder();
for (RowExpression key : reverseProjections.keySet()) {
List<VariableReferenceExpression> values = ImmutableList.copyOf(reverseProjections.get(key));
if (key instanceof VariableReferenceExpression) {
// Pass-through or rename: map every assigned variable to the referenced variable and emit
// a single identity projection for it.
// NOTE(review): this also applies to the final stage, where a rename output is replaced by the
// referenced variable — confirm consumers of the final stage do not rely on the renamed variable.
values.forEach(variable -> mapperBuilder.put(variable, (VariableReferenceExpression) key));
dedupedProjectionsBuilder.put((VariableReferenceExpression) key, key);
}
else if (values.size() > 1) {
// Consolidate to one variable, prefer variables from original plan
// NOTE(review): getOnlyElement fails when more than one final-stage variable shares this
// expression — confirm that cannot happen for plans reaching this rule.
List<VariableReferenceExpression> fromOriginal = originalVariable.stream().filter(values::contains).collect(toImmutableList());
VariableReferenceExpression variable = fromOriginal.isEmpty() ? values.get(0) : getOnlyElement(fromOriginal);
// Every other variable with the same expression is redirected to the chosen one
for (int j = 0; j < values.size(); j++) {
if (!values.get(j).equals(variable)) {
mapperBuilder.put(values.get(j), variable);
}
}
dedupedProjectionsBuilder.put(variable, key);
}
else {
// Unique, non-trivial expression: keep the assignment unchanged
checkState(values.size() == 1, "Expect only 1 value");
dedupedProjectionsBuilder.put(values.get(0), key);
}
}
dedupedProjectionContexts.add(new ProjectionContext(dedupedProjectionsBuilder.build(), projectionContexts.get(i).isRemote()));
// Hand the renames of this stage to the next iteration
mapper = mapperBuilder.build();
}
}
return dedupedProjectionContexts.build();
}
/**
 * Merges independently built projection chains into one list of stages.
 * <p>
 * Stages are produced round by round, alternating between local and remote and starting with
 * local. In each round every chain whose next projection has the round's locality contributes
 * that projection and advances. A chain that is already fully consumed contributes identity
 * projections for the variables of its last projection. A chain whose next projection is local
 * contributes identity projections for that projection's variables during a remote round.
 * A round in which no chain advanced produces no stage.
 *
 * @param projectionContexts one chain of projections per argument or assignment
 * @return the merged stages, in execution order
 */
private static List<ProjectionContext> mergeProjectionContexts(List<List<ProjectionContext>> projectionContexts)
{
    int chainCount = projectionContexts.size();
    int[] cursors = new int[chainCount];
    ImmutableList.Builder<ProjectionContext> stages = ImmutableList.builder();
    boolean remoteRound = false;

    while (true) {
        // Stop as soon as every chain has been fully consumed
        boolean pending = false;
        for (int i = 0; i < projectionContexts.size() && !pending; i++) {
            pending = cursors[i] < projectionContexts.get(i).size();
        }
        if (!pending) {
            return stages.build();
        }

        ImmutableMap.Builder<VariableReferenceExpression, RowExpression> stageProjections = ImmutableMap.builder();
        boolean advanced = false;
        for (int i = 0; i < projectionContexts.size(); i++) {
            List<ProjectionContext> chain = projectionContexts.get(i);

            if (cursors[i] >= chain.size()) {
                // Chain is exhausted: carry the variables of its last projection through this stage
                for (VariableReferenceExpression variable : chain.get(chain.size() - 1).getProjections().keySet()) {
                    stageProjections.put(variable, variable);
                }
                continue;
            }

            ProjectionContext next = chain.get(cursors[i]);
            if (next.isRemote() == remoteRound) {
                stageProjections.putAll(next.getProjections());
                cursors[i]++;
                advanced = true;
            }
            else if (remoteRound) {
                // Remote round with a local projection pending: identity projections for its variables
                for (VariableReferenceExpression variable : next.getProjections().keySet()) {
                    stageProjections.put(variable, variable);
                }
            }
        }

        // Built unconditionally so that key conflicts surface even when the stage is dropped
        ImmutableMap<VariableReferenceExpression, RowExpression> merged = stageProjections.build();
        if (advanced) {
            stages.add(new ProjectionContext(merged, remoteRound));
        }
        remoteRound = !remoteRound;
    }
}
/**
 * Returns the single variable assigned by the last projection of a chain.
 *
 * @param projectionContexts a non-empty chain whose last projection has exactly one assignment
 * @throws IllegalStateException if the last projection does not have exactly one assignment
 */
private static VariableReferenceExpression getAssignedArgument(List<ProjectionContext> projectionContexts)
{
    Map<VariableReferenceExpression, RowExpression> lastProjections = projectionContexts.get(projectionContexts.size() - 1).getProjections();
    checkState(lastProjections.size() == 1, "Expect only 1 projection for argument");
    return getOnlyElement(lastProjections.keySet());
}
/**
 * Rewrites a row expression into a chain of projections.
 * <p>
 * Each visit method returns the chain needed to compute the visited expression. An empty list
 * means the expression needs no splitting and can be projected as-is; callers then wrap it in a
 * single local projection themselves. The visitor context argument is unused.
 */
private static class Visitor
implements RowExpressionVisitor<List<ProjectionContext>, Void>
{
// Used to find out whether a called function is executed externally
private final FunctionAndTypeManager functionAndTypeManager;
// Source of the intermediate variables introduced by the rewrite
private final VariableAllocator variableAllocator;
/**
 * Creates a visitor.
 *
 * @param functionAndTypeManager manager used to look up function metadata; must not be null
 * @param variableAllocator allocator for intermediate variables; must not be null
 * @throws NullPointerException if either argument is null
 */
public Visitor(FunctionAndTypeManager functionAndTypeManager, VariableAllocator variableAllocator)
{
    // The message names the actual parameter so a failure points at the right argument.
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
}
/**
 * Rewrites a function call into a chain of projections.
 * <p>
 * The arguments are rewritten and merged first. The call is then placed as follows:
 * <ul>
 * <li>externally executed call: appended as a new remote stage;</li>
 * <li>local call whose arguments are entirely local: empty list (no rewrite needed);</li>
 * <li>local call whose last argument stage is local: folded into that stage, with the argument
 * variables replaced by the expressions that stage assigned to them;</li>
 * <li>local call whose last argument stage is remote: appended as a new local stage.</li>
 * </ul>
 */
@Override
public List<ProjectionContext> visitCall(CallExpression call, Void context)
{
    FunctionMetadata metadata = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle());
    boolean local = !metadata.getImplementationType().isExternalExecution();

    // Split the arguments into local and remote stages before placing the call itself
    ImmutableList.Builder<RowExpression> rewrittenArguments = ImmutableList.builder();
    List<ProjectionContext> argumentStages = processArguments(call.getArguments(), rewrittenArguments, local);
    List<RowExpression> newArguments = rewrittenArguments.build();
    CallExpression newCall = new CallExpression(
            call.getSourceLocation(),
            call.getDisplayName(),
            call.getFunctionHandle(),
            call.getType(),
            newArguments);

    if (!local) {
        // Remote call: always gets a stage of its own on top of the argument stages
        // TODO if all arguments are input reference or constant (maybe variable reference?) we could skip a projection
        return ImmutableList.<ProjectionContext>builder()
                .addAll(argumentStages)
                .add(new ProjectionContext(ImmutableMap.of(variableAllocator.newVariable(newCall), newCall), true))
                .build();
    }

    boolean argumentsAreLocalOnly = argumentStages.isEmpty() || (argumentStages.size() == 1 && !argumentStages.get(0).isRemote());
    if (argumentsAreLocalOnly) {
        // The call and everything below it is local
        return ImmutableList.of();
    }

    int lastIndex = argumentStages.size() - 1;
    ProjectionContext lastStage = argumentStages.get(lastIndex);
    if (lastStage.isRemote()) {
        // Local call on top of a remote stage: add one more local stage
        return ImmutableList.<ProjectionContext>builder()
                .addAll(argumentStages)
                .add(new ProjectionContext(ImmutableMap.of(variableAllocator.newVariable(newCall), newCall), false))
                .build();
    }

    // Last argument stage is local: fold the call into it by inlining that stage's expressions
    VariableReferenceExpression output = variableAllocator.newVariable(call);
    List<RowExpression> inlinedArguments = newArguments.stream()
            .map(argument -> argument instanceof VariableReferenceExpression ? lastStage.getProjections().get(argument) : argument)
            .collect(toImmutableList());
    CallExpression foldedCall = new CallExpression(
            call.getSourceLocation(),
            call.getDisplayName(),
            call.getFunctionHandle(),
            call.getType(),
            inlinedArguments);
    return ImmutableList.<ProjectionContext>builder()
            .addAll(argumentStages.subList(0, lastIndex))
            .add(new ProjectionContext(ImmutableMap.of(output, foldedCall), false))
            .build();
}
/**
 * Always fails: this visitor does not accept {@link InputReferenceExpression}.
 *
 * @throws IllegalStateException unconditionally
 */
@Override
public List<ProjectionContext> visitInputReference(InputReferenceExpression reference, Void context)
{
throw new IllegalStateException("Optimizers should not see InputReferenceExpression");
}
/**
 * A constant needs no splitting; returns an empty list.
 */
@Override
public List<ProjectionContext> visitConstant(ConstantExpression literal, Void context)
{
return ImmutableList.of();
}
/**
 * Returns an empty list without looking inside the lambda, so the lambda is never split.
 */
@Override
public List<ProjectionContext> visitLambda(LambdaDefinitionExpression lambda, Void context)
{
return ImmutableList.of();
}
/**
 * A variable reference needs no splitting; returns an empty list.
 */
@Override
public List<ProjectionContext> visitVariableReference(VariableReferenceExpression reference, Void context)
{
return ImmutableList.of();
}
/**
 * Rewrites a special form into a chain of projections. The special form itself is always placed
 * in a local stage.
 * <ul>
 * <li>arguments entirely local: empty list (no rewrite needed);</li>
 * <li>last argument stage is local: the special form is folded into that stage, with the argument
 * variables replaced by the expressions that stage assigned to them;</li>
 * <li>last argument stage is remote: the special form is appended as a new local stage.</li>
 * </ul>
 */
@Override
public List<ProjectionContext> visitSpecialForm(SpecialFormExpression specialForm, Void context)
{
    ImmutableList.Builder<RowExpression> rewrittenArguments = ImmutableList.builder();
    List<ProjectionContext> argumentStages = processArguments(specialForm.getArguments(), rewrittenArguments, true);
    List<RowExpression> newArguments = rewrittenArguments.build();

    boolean argumentsAreLocalOnly = argumentStages.isEmpty() || (argumentStages.size() == 1 && !argumentStages.get(0).isRemote());
    if (argumentsAreLocalOnly) {
        // Arguments do not contain remote projection
        return ImmutableList.of();
    }

    int lastIndex = argumentStages.size() - 1;
    ProjectionContext lastStage = argumentStages.get(lastIndex);
    ImmutableList.Builder<ProjectionContext> result = ImmutableList.builder();

    if (lastStage.isRemote()) {
        // Last projection is remote, add another level of projection
        result.addAll(argumentStages);
        VariableReferenceExpression output = variableAllocator.newVariable(specialForm);
        SpecialFormExpression rewritten = new SpecialFormExpression(specialForm.getForm(), specialForm.getType(), newArguments);
        result.add(new ProjectionContext(ImmutableMap.of(output, rewritten), false));
        return result.build();
    }

    // There are remote projections, but the last stage is local: fold the special form into it
    result.addAll(argumentStages.subList(0, lastIndex));
    VariableReferenceExpression output = variableAllocator.newVariable(specialForm);
    List<RowExpression> inlinedArguments = newArguments.stream()
            .map(argument -> argument instanceof VariableReferenceExpression ? lastStage.getProjections().get(argument) : argument)
            .collect(toImmutableList());
    SpecialFormExpression folded = new SpecialFormExpression(specialForm.getForm(), specialForm.getType(), inlinedArguments);
    result.add(new ProjectionContext(ImmutableMap.of(output, folded), false));
    return result.build();
}
/**
 * Rewrites the arguments of a call or special form and merges their chains into common stages.
 * <p>
 * For every argument exactly one expression is appended to {@code newArguments}: the constant
 * itself when {@code local} is true and the argument is a constant, otherwise the variable assigned
 * by the last projection of the argument's chain. An argument whose rewrite is empty gets a fresh
 * variable and a single local projection.
 *
 * @param arguments the original arguments
 * @param newArguments receives the replacement arguments, in order
 * @param local whether the enclosing expression is evaluated locally
 * @return the merged stages that compute the non-inlined arguments
 */
private List<ProjectionContext> processArguments(List<RowExpression> arguments, ImmutableList.Builder<RowExpression> newArguments, boolean local)
{
    ImmutableList.Builder<List<ProjectionContext>> chains = ImmutableList.builder();
    for (RowExpression argument : arguments) {
        if (local && argument instanceof ConstantExpression) {
            // Constants stay inline for locally evaluated expressions
            newArguments.add(argument);
            continue;
        }

        List<ProjectionContext> chain = argument.accept(this, null);
        if (chain.isEmpty()) {
            // Nothing to split below this argument: project it once under a fresh variable
            VariableReferenceExpression variable = variableAllocator.newVariable(argument);
            chain = ImmutableList.of(new ProjectionContext(ImmutableMap.of(variable, argument), false));
        }
        chains.add(chain);
        newArguments.add(getAssignedArgument(chain));
    }
    return mergeProjectionContexts(chains.build());
}
}
/**
 * One projection stage: a set of assignments plus a flag telling whether the stage is remote.
 */
public static class ProjectionContext
{
// Assignments of this stage, keyed by output variable; stored as passed in, without copying
private final Map<VariableReferenceExpression, RowExpression> projections;
// true for a remote stage, false for a local one
private final boolean remote;
/**
 * @param projections assignments of the stage; must not be null
 * @param remote whether the stage is remote
 */
ProjectionContext(Map<VariableReferenceExpression, RowExpression> projections, boolean remote)
{
this.projections = requireNonNull(projections, "projections is null");
this.remote = remote;
}
/**
 * Returns the assignments of this stage (the same map instance given to the constructor).
 */
public Map<VariableReferenceExpression, RowExpression> getProjections()
{
return projections;
}
/**
 * Returns whether this stage is remote.
 */
public boolean isRemote()
{
return remote;
}
}
}