SymbolMapper.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.optimizations;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.spi.PrestoWarning;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Ordering;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.PartitioningScheme;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.StatisticAggregations;
import com.facebook.presto.spi.plan.StatisticAggregationsDescriptor;
import com.facebook.presto.spi.plan.TableFinishNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.StatisticsWriterNode;
import com.facebook.presto.sql.planner.plan.TableWriterMergeNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.SymbolReference;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY;
import static com.facebook.presto.spi.plan.AggregationNode.groupingSets;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation;
import static com.facebook.presto.sql.planner.optimizations.PartitioningUtils.translateVariable;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
public class SymbolMapper
{
private final Map<String, String> mapping;
private final TypeProvider types;
private final WarningCollector warningCollector;
public SymbolMapper(Map<VariableReferenceExpression, VariableReferenceExpression> mapping, WarningCollector warningCollector)
{
requireNonNull(mapping, "mapping is null");
this.mapping = mapping.entrySet().stream().collect(toImmutableMap(entry -> entry.getKey().getName(), entry -> entry.getValue().getName()));
ImmutableSet.Builder<VariableReferenceExpression> variables = ImmutableSet.builder();
mapping.entrySet().forEach(entry -> {
variables.add(entry.getKey());
variables.add(entry.getValue());
});
this.types = TypeProvider.fromVariables(variables.build());
this.warningCollector = warningCollector;
}
public SymbolMapper(Map<String, String> mapping, TypeProvider types, WarningCollector warningCollector)
{
requireNonNull(mapping, "mapping is null");
this.mapping = ImmutableMap.copyOf(mapping);
this.types = requireNonNull(types, "types is null");
this.warningCollector = warningCollector;
}
public Symbol map(Symbol symbol)
{
String canonical = symbol.getName();
while (mapping.containsKey(canonical) && !mapping.get(canonical).equals(canonical)) {
canonical = mapping.get(canonical);
}
return new Symbol(canonical);
}
public VariableReferenceExpression map(VariableReferenceExpression variable)
{
String canonical = variable.getName();
while (mapping.containsKey(canonical) && !mapping.get(canonical).equals(canonical)) {
canonical = mapping.get(canonical);
}
if (canonical.equals(variable.getName())) {
return variable;
}
return new VariableReferenceExpression(variable.getSourceLocation(), canonical, types.get(new SymbolReference(getNodeLocation(variable.getSourceLocation()), canonical)));
}
public Expression map(Expression value)
{
return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>()
{
@Override
public Expression rewriteSymbolReference(SymbolReference node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
Symbol canonical = map(Symbol.from(node));
return canonical.toSymbolReference();
}
}, value);
}
public RowExpression map(RowExpression value)
{
return RowExpressionTreeRewriter.rewriteWith(new RowExpressionRewriter<Void>()
{
@Override
public RowExpression rewriteVariableReference(VariableReferenceExpression variable, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
{
return map(variable);
}
}, value);
}
public OrderingScheme map(OrderingScheme orderingScheme)
{
// SymbolMapper inlines symbol with multiple level reference (SymbolInliner only inline single level).
ImmutableList.Builder<VariableReferenceExpression> orderBy = ImmutableList.builder();
HashMap<VariableReferenceExpression, SortOrder> orderingMap = new HashMap<>();
for (VariableReferenceExpression variable : orderingScheme.getOrderByVariables()) {
VariableReferenceExpression translated = map(variable);
// Some variables may become duplicates after canonicalization, so we put them only once.
if (!orderingMap.containsKey(translated)) {
orderBy.add(translated);
orderingMap.put(translated, orderingScheme.getOrdering(variable));
}
else if (orderingMap.get(translated) != orderingScheme.getOrdering(variable)) {
warningCollector.add(new PrestoWarning(
MULTIPLE_ORDER_BY,
"Multiple ORDER BY for a variable were given, only first provided will be considered"));
}
}
return new OrderingScheme(orderBy.build().stream().map(variable -> new Ordering(variable, orderingMap.get(variable))).collect(toImmutableList()));
}
public AggregationNode map(AggregationNode node, PlanNode source)
{
return map(node, source, node.getId());
}
public AggregationNode map(AggregationNode node, PlanNode source, PlanNodeIdAllocator idAllocator)
{
return map(node, source, idAllocator.getNextId());
}
private AggregationNode map(AggregationNode node, PlanNode source, PlanNodeId newNodeId)
{
ImmutableMap.Builder<VariableReferenceExpression, Aggregation> aggregations = ImmutableMap.builder();
for (Entry<VariableReferenceExpression, Aggregation> entry : node.getAggregations().entrySet()) {
aggregations.put(map(entry.getKey()), map(entry.getValue()));
}
return new AggregationNode(
source.getSourceLocation(),
newNodeId,
source,
aggregations.build(),
groupingSets(
mapAndDistinctVariable(node.getGroupingKeys()),
node.getGroupingSetCount(),
node.getGlobalGroupingSets()),
mapAndDistinctVariable(node.getPreGroupedVariables()),
node.getStep(),
node.getHashVariable().map(this::map),
node.getGroupIdVariable().map(this::map),
node.getAggregationId());
}
private Aggregation map(Aggregation aggregation)
{
return new Aggregation(
new CallExpression(
aggregation.getCall().getSourceLocation(),
aggregation.getCall().getDisplayName(),
aggregation.getCall().getFunctionHandle(),
aggregation.getCall().getType(),
aggregation.getArguments().stream().map(this::map).collect(toImmutableList())),
aggregation.getFilter().map(this::map),
aggregation.getOrderBy().map(this::map),
aggregation.isDistinct(),
aggregation.getMask().map(this::map));
}
public TopNNode map(TopNNode node, PlanNode source, PlanNodeId newNodeId)
{
ImmutableList.Builder<VariableReferenceExpression> variables = ImmutableList.builder();
ImmutableMap.Builder<VariableReferenceExpression, SortOrder> orderings = ImmutableMap.builder();
Set<VariableReferenceExpression> seenCanonicals = new HashSet<>(node.getOrderingScheme().getOrderByVariables().size());
for (VariableReferenceExpression variable : node.getOrderingScheme().getOrderByVariables()) {
VariableReferenceExpression canonical = map(variable);
if (seenCanonicals.add(canonical)) {
seenCanonicals.add(canonical);
variables.add(canonical);
orderings.put(canonical, node.getOrderingScheme().getOrdering(variable));
}
}
ImmutableMap<VariableReferenceExpression, SortOrder> orderingMap = orderings.build();
return new TopNNode(
node.getSourceLocation(),
newNodeId,
source,
node.getCount(),
new OrderingScheme(variables.build().stream().map(variable -> new Ordering(variable, orderingMap.get(variable))).collect(toImmutableList())),
node.getStep());
}
public TableWriterNode map(TableWriterNode node, PlanNode source)
{
return map(node, source, node.getId());
}
public TableWriterNode map(TableWriterNode node, PlanNode source, PlanNodeId newNodeId)
{
// Intentionally does not use canonicalizeAndDistinct as that would remove columns
ImmutableList<VariableReferenceExpression> columns = node.getColumns().stream()
.map(this::map)
.collect(toImmutableList());
Set<VariableReferenceExpression> notNullColumnVariables = node.getNotNullColumnVariables().stream()
.map(this::map)
.collect(toImmutableSet());
return new TableWriterNode(
source.getSourceLocation(),
newNodeId,
node.getStatsEquivalentPlanNode(),
source,
node.getTarget(),
map(node.getRowCountVariable()),
map(node.getFragmentVariable()),
map(node.getTableCommitContextVariable()),
columns,
node.getColumnNames(),
notNullColumnVariables,
node.getTablePartitioningScheme().map(partitioningScheme -> canonicalize(partitioningScheme, source)),
node.getStatisticsAggregation().map(this::map),
node.getTaskCountIfScaledWriter(),
node.getIsTemporaryTableWriter());
}
public StatisticsWriterNode map(StatisticsWriterNode node, PlanNode source)
{
return new StatisticsWriterNode(
node.getSourceLocation(),
node.getId(),
source,
node.getTableHandle(),
node.getRowCountVariable(),
node.isRowCountEnabled(),
node.getDescriptor().map(this::map));
}
public TableFinishNode map(TableFinishNode node, PlanNode source)
{
return new TableFinishNode(
node.getSourceLocation(),
node.getId(),
source,
node.getTarget(),
map(node.getRowCountVariable()),
node.getStatisticsAggregation().map(this::map),
node.getStatisticsAggregationDescriptor().map(descriptor -> descriptor.map(this::map)),
node.getCteMaterializationInfo());
}
public TableWriterMergeNode map(TableWriterMergeNode node, PlanNode source)
{
return new TableWriterMergeNode(
node.getSourceLocation(),
node.getId(),
source,
map(node.getRowCountVariable()),
map(node.getFragmentVariable()),
map(node.getTableCommitContextVariable()),
node.getStatisticsAggregation().map(this::map));
}
private PartitioningScheme canonicalize(PartitioningScheme scheme, PlanNode source)
{
return new PartitioningScheme(translateVariable(scheme.getPartitioning(), this::map),
mapAndDistinctVariable(source.getOutputVariables()),
scheme.getHashColumn().map(this::map),
scheme.isReplicateNullsAndAny(),
scheme.isScaleWriters(),
scheme.getEncoding(),
scheme.getBucketToPartition());
}
private StatisticAggregations map(StatisticAggregations statisticAggregations)
{
Map<VariableReferenceExpression, Aggregation> aggregations = statisticAggregations.getAggregations().entrySet().stream()
.collect(toImmutableMap(entry -> map(entry.getKey()), entry -> map(entry.getValue())));
return new StatisticAggregations(aggregations, mapAndDistinctVariable(statisticAggregations.getGroupingVariables()));
}
private StatisticAggregationsDescriptor<VariableReferenceExpression> map(StatisticAggregationsDescriptor<VariableReferenceExpression> descriptor)
{
return descriptor.map(this::map);
}
private List<Symbol> mapAndDistinctSymbol(List<Symbol> outputs)
{
Set<Symbol> added = new HashSet<>();
ImmutableList.Builder<Symbol> builder = ImmutableList.builder();
for (Symbol symbol : outputs) {
Symbol canonical = map(symbol);
if (added.add(canonical)) {
builder.add(canonical);
}
}
return builder.build();
}
private List<VariableReferenceExpression> mapAndDistinctVariable(List<VariableReferenceExpression> outputs)
{
Set<VariableReferenceExpression> added = new HashSet<>();
ImmutableList.Builder<VariableReferenceExpression> builder = ImmutableList.builder();
for (VariableReferenceExpression variable : outputs) {
VariableReferenceExpression canonical = map(variable);
if (added.add(canonical)) {
builder.add(canonical);
}
}
return builder.build();
}
public static SymbolMapper.Builder builder(WarningCollector warningCollector)
{
return new Builder(warningCollector);
}
public static SymbolMapper.Builder builder()
{
return new Builder(WarningCollector.NOOP);
}
public static class Builder
{
private final ImmutableMap.Builder<VariableReferenceExpression, VariableReferenceExpression> mappingsBuilder;
private final WarningCollector warningCollector;
public Builder(WarningCollector warningCollector)
{
this.warningCollector = warningCollector;
this.mappingsBuilder = ImmutableMap.builder();
}
public SymbolMapper build()
{
return new SymbolMapper(mappingsBuilder.build(), warningCollector);
}
public void put(VariableReferenceExpression from, VariableReferenceExpression to)
{
mappingsBuilder.put(from, to);
}
}
}