PlanMatchPattern.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.assertions;
import com.facebook.presto.Session;
import com.facebook.presto.common.block.SortOrder;
import com.facebook.presto.common.predicate.Domain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.cost.StatsProvider;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Step;
import com.facebook.presto.spi.plan.CteConsumerNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.DataOrganizationSpecification;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.ExceptNode;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.IndexSourceNode;
import com.facebook.presto.spi.plan.IntersectNode;
import com.facebook.presto.spi.plan.JoinDistributionType;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.JoinType;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.MergeJoinNode;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PlanFragmentId;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.plan.WindowNode.Frame.BoundType;
import com.facebook.presto.spi.plan.WindowNode.Frame.WindowType;
import com.facebook.presto.spi.statistics.SourceInfo;
import com.facebook.presto.spi.statistics.SourceInfo.ConfidenceLevel;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.ParsingOptions.DecimalLiteralTreatment;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.Symbol;
import com.facebook.presto.sql.planner.iterative.GroupReference;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.ExchangeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.facebook.presto.sql.planner.plan.OffsetNode;
import com.facebook.presto.sql.planner.plan.RemoteSourceNode;
import com.facebook.presto.sql.planner.plan.SequenceNode;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SortItem;
import com.facebook.presto.sql.tree.WindowFrame;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import java.util.stream.IntStream;
import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_FIRST;
import static com.facebook.presto.common.block.SortOrder.ASC_NULLS_LAST;
import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_FIRST;
import static com.facebook.presto.common.block.SortOrder.DESC_NULLS_LAST;
import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences;
import static com.facebook.presto.sql.planner.assertions.MatchResult.NO_MATCH;
import static com.facebook.presto.sql.planner.assertions.MatchResult.match;
import static com.facebook.presto.sql.planner.assertions.StrictAssignedSymbolsMatcher.actualAssignments;
import static com.facebook.presto.sql.planner.assertions.StrictSymbolsMatcher.actualOutputs;
import static com.facebook.presto.sql.tree.SortItem.NullOrdering.FIRST;
import static com.facebook.presto.sql.tree.SortItem.NullOrdering.UNDEFINED;
import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING;
import static com.facebook.presto.sql.tree.SortItem.Ordering.DESCENDING;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
public final class PlanMatchPattern
{
private final List<Matcher> matchers = new ArrayList<>();
private final List<PlanMatchPattern> sourcePatterns;
private boolean anyTree;
public static PlanMatchPattern node(Class<? extends PlanNode> nodeClass, PlanMatchPattern... sources)
{
return any(sources).with(new PlanNodeMatcher(nodeClass));
}
public static PlanMatchPattern any(PlanMatchPattern... sources)
{
return new PlanMatchPattern(ImmutableList.copyOf(sources));
}
/**
* Matches to any tree of nodes with children matching to given source matchers.
* anyNodeTree(tableScanNode("nation")) - will match to any plan which all leafs contain
* any node containing table scan from nation table.
*
* @note anyTree does not match zero nodes. E.g. output(anyTree(tableScan)) will NOT match TableScan node followed by OutputNode.
*/
public static PlanMatchPattern anyTree(PlanMatchPattern... sources)
{
return any(sources).matchToAnyNodeTree();
}
public static PlanMatchPattern anyNot(Class<? extends PlanNode> excludeNodeClass, PlanMatchPattern... sources)
{
return any(sources).with(new NotPlanNodeMatcher(excludeNodeClass));
}
public static PlanMatchPattern tableScan(String expectedTableName)
{
return TableScanMatcher.create(expectedTableName);
}
public static PlanMatchPattern tableScan(String expectedTableName, Map<String, String> columnReferences)
{
PlanMatchPattern result = tableScan(expectedTableName);
return result.addColumnReferences(expectedTableName, columnReferences);
}
public static PlanMatchPattern strictTableScan(String expectedTableName, Map<String, String> columnReferences)
{
return tableScan(expectedTableName, columnReferences)
.withExactAssignedOutputs(columnReferences.values().stream()
.map(columnName -> columnReference(expectedTableName, columnName))
.collect(toImmutableList()));
}
public static PlanMatchPattern constrainedTableScan(String expectedTableName, Map<String, Domain> constraint)
{
return TableScanMatcher.builder(expectedTableName)
.expectedConstraint(constraint)
.build();
}
public static PlanMatchPattern constrainedTableScan(String expectedTableName, Map<String, Domain> constraint, Map<String, String> columnReferences)
{
PlanMatchPattern result = constrainedTableScan(expectedTableName, constraint);
return result.addColumnReferences(expectedTableName, columnReferences);
}
public static PlanMatchPattern constrainedTableScanWithTableLayout(String expectedTableName, Map<String, Domain> constraint, Map<String, String> columnReferences)
{
PlanMatchPattern result = TableScanMatcher.builder(expectedTableName)
.expectedConstraint(constraint)
.hasTableLayout()
.build();
return result.addColumnReferences(expectedTableName, columnReferences);
}
public static PlanMatchPattern indexSource(String expectedTableName)
{
return node(IndexSourceNode.class)
.with(new IndexSourceMatcher(expectedTableName));
}
public static PlanMatchPattern indexSource(String expectedTableName, Map<String, String> columnReferences)
{
return node(IndexSourceNode.class)
.with(new IndexSourceMatcher(expectedTableName))
.addColumnReferences(expectedTableName, columnReferences);
}
public static PlanMatchPattern strictIndexSource(String expectedTableName, Map<String, String> columnReferences)
{
return node(IndexSourceNode.class)
.with(new IndexSourceMatcher(expectedTableName))
.withExactAssignedOutputs(columnReferences.values().stream()
.map(columnName -> columnReference(expectedTableName, columnName))
.collect(toImmutableList()));
}
public static PlanMatchPattern constrainedIndexSource(String expectedTableName, Map<String, Domain> constraint, Map<String, String> columnReferences)
{
return node(IndexSourceNode.class)
.with(new IndexSourceMatcher(expectedTableName, constraint))
.addColumnReferences(expectedTableName, columnReferences);
}
private PlanMatchPattern addColumnReferences(String expectedTableName, Map<String, String> columnReferences)
{
columnReferences.forEach((key, value) -> withAlias(key, columnReference(expectedTableName, value)));
return this;
}
public static PlanMatchPattern aggregation(
Map<String, ExpectedValueProvider<FunctionCall>> aggregations,
PlanMatchPattern source)
{
PlanMatchPattern result = node(AggregationNode.class, source);
aggregations.entrySet().forEach(
aggregation -> result.withAlias(aggregation.getKey(), new AggregationFunctionMatcher(aggregation.getValue())));
return result;
}
public static PlanMatchPattern aggregation(
Map<String, ExpectedValueProvider<FunctionCall>> aggregations,
Step step,
PlanMatchPattern source)
{
PlanMatchPattern result = node(AggregationNode.class, source).with(new AggregationStepMatcher(step));
aggregations.entrySet().forEach(
aggregation -> result.withAlias(aggregation.getKey(), new AggregationFunctionMatcher(aggregation.getValue())));
return result;
}
public static PlanMatchPattern aggregation(
GroupingSetDescriptor groupingSets,
Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregations,
Map<Symbol, Symbol> masks,
Optional<Symbol> groupId,
Step step,
PlanMatchPattern source)
{
return aggregation(groupingSets, aggregations, ImmutableList.of(), masks, groupId, step, source);
}
public static PlanMatchPattern aggregation(
GroupingSetDescriptor groupingSets,
Map<Optional<String>, ExpectedValueProvider<FunctionCall>> aggregations,
List<String> preGroupedSymbols,
Map<Symbol, Symbol> masks,
Optional<Symbol> groupId,
Step step,
PlanMatchPattern source)
{
PlanMatchPattern result = node(AggregationNode.class, source);
aggregations.entrySet().forEach(
aggregation ->
{
if (aggregation.getKey().isPresent() && masks.containsKey(new Symbol(aggregation.getKey().get()))) {
result.withAlias(aggregation.getKey(), new AggregationFunctionMatcher(aggregation.getValue(), masks.get(new Symbol(aggregation.getKey().get()))));
}
else {
result.withAlias(aggregation.getKey(), new AggregationFunctionMatcher(aggregation.getValue()));
}
});
// Put the AggregationMatcher at the end as the mask mapping will use the output mapping from aggregation function calls above
result.with(new AggregationMatcher(groupingSets, preGroupedSymbols, masks, groupId, step));
return result;
}
public static PlanMatchPattern markDistinct(
String markerSymbol,
List<String> distinctSymbols,
PlanMatchPattern source)
{
return node(MarkDistinctNode.class, source).with(new MarkDistinctMatcher(
new SymbolAlias(markerSymbol),
toSymbolAliases(distinctSymbols),
Optional.empty()));
}
public static PlanMatchPattern markDistinct(
String markerSymbol,
List<String> distinctSymbols,
String hashSymbol,
PlanMatchPattern source)
{
return node(MarkDistinctNode.class, source).with(new MarkDistinctMatcher(
new SymbolAlias(markerSymbol),
toSymbolAliases(distinctSymbols),
Optional.of(new SymbolAlias(hashSymbol))));
}
public static ExpectedValueProvider<WindowNode.Frame> windowFrame(
WindowType type,
BoundType startType,
Optional<String> startValue,
BoundType endType,
Optional<String> endValue,
Optional<String> sortKey)
{
return windowFrame(type, startType, startValue, Optional.empty(), sortKey, Optional.empty(), endType, endValue, Optional.empty(), sortKey, Optional.empty());
}
public static ExpectedValueProvider<WindowNode.Frame> windowFrame(
WindowType type,
BoundType startType,
Optional<String> startValue,
Optional<Type> startValueType,
Optional<String> sortKeyForStartComparison,
Optional<Type> sortKeyForStartComparisonType,
BoundType endType,
Optional<String> endValue,
Optional<Type> endValueType,
Optional<String> sortKeyForEndComparison,
Optional<Type> sortKeyForEndComparisonType)
{
return new WindowFrameProvider(
type,
startType,
startValue.map(SymbolAlias::new),
startValueType,
sortKeyForStartComparison.map(SymbolAlias::new),
sortKeyForStartComparisonType,
endType,
endValue.map(SymbolAlias::new),
endValueType,
sortKeyForEndComparison.map(SymbolAlias::new),
sortKeyForEndComparisonType);
}
public static PlanMatchPattern window(Consumer<WindowMatcher.Builder> windowMatcherBuilderConsumer, PlanMatchPattern source)
{
WindowMatcher.Builder windowMatcherBuilder = new WindowMatcher.Builder(source);
windowMatcherBuilderConsumer.accept(windowMatcherBuilder);
return windowMatcherBuilder.build();
}
public static PlanMatchPattern rowNumber(Consumer<RowNumberMatcher.Builder> rowNumberMatcherBuilderConsumer, PlanMatchPattern source)
{
RowNumberMatcher.Builder rowNumberMatcherBuilder = new RowNumberMatcher.Builder(source);
rowNumberMatcherBuilderConsumer.accept(rowNumberMatcherBuilder);
return rowNumberMatcherBuilder.build();
}
public static PlanMatchPattern topNRowNumber(Consumer<TopNRowNumberMatcher.Builder> topNRowNumberMatcherBuilderConsumer, PlanMatchPattern source)
{
TopNRowNumberMatcher.Builder topNRowNumberMatcherBuilder = new TopNRowNumberMatcher.Builder(source);
topNRowNumberMatcherBuilderConsumer.accept(topNRowNumberMatcherBuilder);
return topNRowNumberMatcherBuilder.build();
}
public static PlanMatchPattern sort(PlanMatchPattern source)
{
return node(SortNode.class, source);
}
public static PlanMatchPattern sort(List<Ordering> orderBy, PlanMatchPattern source)
{
return node(SortNode.class, source)
.with(new SortMatcher(orderBy));
}
public static PlanMatchPattern topN(long count, List<Ordering> orderBy, PlanMatchPattern source)
{
return node(TopNNode.class, source).with(new TopNMatcher(count, orderBy));
}
public static PlanMatchPattern output(PlanMatchPattern source)
{
return node(OutputNode.class, source);
}
public static PlanMatchPattern output(List<String> outputs, PlanMatchPattern source)
{
PlanMatchPattern result = output(source);
result.withOutputs(outputs);
return result;
}
public static PlanMatchPattern strictOutput(List<String> outputs, PlanMatchPattern source)
{
return output(outputs, source).withExactOutputs(outputs);
}
public static PlanMatchPattern project(PlanMatchPattern source)
{
return node(ProjectNode.class, source);
}
public static PlanMatchPattern project(Map<String, ExpressionMatcher> assignments, PlanMatchPattern source)
{
PlanMatchPattern result = project(source);
assignments.entrySet().forEach(
assignment -> result.withAlias(assignment.getKey(), assignment.getValue()));
return result;
}
public static PlanMatchPattern strictProject(Map<String, ExpressionMatcher> assignments, PlanMatchPattern source)
{
/*
* Under the current implementation of project, all of the outputs are also in the assignment.
* If the implementation changes, this will need to change too.
*/
return project(assignments, source)
.withExactAssignedOutputs(assignments.values())
.withExactAssignments(assignments.values());
}
public static PlanMatchPattern semiJoin(String sourceSymbolAlias, String filteringSymbolAlias, String outputAlias, PlanMatchPattern source, PlanMatchPattern filtering)
{
return semiJoin(sourceSymbolAlias, filteringSymbolAlias, outputAlias, Optional.empty(), source, filtering);
}
public static PlanMatchPattern semiJoin(String sourceSymbolAlias, String filteringSymbolAlias, String outputAlias, Optional<SemiJoinNode.DistributionType> distributionType, PlanMatchPattern source, PlanMatchPattern filtering)
{
return node(SemiJoinNode.class, source, filtering).with(new SemiJoinMatcher(sourceSymbolAlias, filteringSymbolAlias, outputAlias, distributionType));
}
public static PlanMatchPattern join(PlanMatchPattern left, PlanMatchPattern right)
{
return node(JoinNode.class, left, right);
}
public static PlanMatchPattern join(JoinType joinType, List<ExpectedValueProvider<EquiJoinClause>> expectedEquiCriteria, PlanMatchPattern left, PlanMatchPattern right)
{
return join(joinType, expectedEquiCriteria, Optional.empty(), left, right);
}
public static PlanMatchPattern join(JoinType joinType, List<ExpectedValueProvider<EquiJoinClause>> expectedEquiCriteria, Optional<String> expectedFilter, PlanMatchPattern left, PlanMatchPattern right)
{
return join(joinType, expectedEquiCriteria, expectedFilter, Optional.empty(), left, right);
}
public static PlanMatchPattern join(JoinType joinType, List<ExpectedValueProvider<EquiJoinClause>> expectedEquiCriteria, Optional<String> expectedFilter, Optional<JoinDistributionType> expectedDistributionType, PlanMatchPattern left, PlanMatchPattern right)
{
return node(JoinNode.class, left, right).with(
new JoinMatcher(
joinType,
expectedEquiCriteria,
expectedFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))),
expectedDistributionType,
Optional.empty()));
}
public static PlanMatchPattern join(
JoinType joinType,
List<ExpectedValueProvider<EquiJoinClause>> expectedEquiCriteria,
Map<String, String> expectedDynamicFilter,
Optional<String> expectedStaticFilter,
PlanMatchPattern leftSource,
PlanMatchPattern right)
{
Map<SymbolAlias, SymbolAlias> expectedDynamicFilterAliases = expectedDynamicFilter.entrySet().stream()
.collect(toImmutableMap(entry -> new SymbolAlias(entry.getKey()), entry -> new SymbolAlias(entry.getValue())));
DynamicFilterMatcher dynamicFilterMatcher = new DynamicFilterMatcher(
expectedDynamicFilterAliases,
expectedStaticFilter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate))));
JoinMatcher joinMatcher = new JoinMatcher(
joinType,
expectedEquiCriteria,
Optional.empty(),
Optional.empty(),
Optional.of(dynamicFilterMatcher));
return node(JoinNode.class, anyTree(node(FilterNode.class, leftSource).with(dynamicFilterMatcher)), right)
.with(joinMatcher);
}
public static PlanMatchPattern indexJoin(PlanMatchPattern left, PlanMatchPattern right)
{
return node(IndexJoinNode.class, left, right);
}
public static PlanMatchPattern cteConsumer(String cteName)
{
CteConsumerMatcher cteConsumerMatcher = new CteConsumerMatcher(cteName);
return node(CteConsumerNode.class).with(cteConsumerMatcher);
}
public static PlanMatchPattern cteProducer(String cteName, PlanMatchPattern source)
{
CteProducerMatcher cteProducerMatcher = new CteProducerMatcher(cteName);
return node(CteProducerNode.class, source).with(cteProducerMatcher);
}
public static PlanMatchPattern sequence(PlanMatchPattern... sources)
{
return node(SequenceNode.class, sources);
}
public static PlanMatchPattern spatialJoin(String expectedFilter, PlanMatchPattern left, PlanMatchPattern right)
{
return spatialJoin(expectedFilter, Optional.empty(), left, right);
}
public static PlanMatchPattern spatialJoin(String expectedFilter, Optional<String> kdbTree, PlanMatchPattern left, PlanMatchPattern right)
{
return node(SpatialJoinNode.class, left, right).with(
new SpatialJoinMatcher(SpatialJoinNode.Type.INNER, rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedFilter, new ParsingOptions())), kdbTree));
}
public static PlanMatchPattern spatialLeftJoin(String expectedFilter, PlanMatchPattern left, PlanMatchPattern right)
{
return node(SpatialJoinNode.class, left, right).with(
new SpatialJoinMatcher(SpatialJoinNode.Type.LEFT, rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedFilter, new ParsingOptions())), Optional.empty()));
}
public static PlanMatchPattern mergeJoin(JoinType joinType, List<ExpectedValueProvider<EquiJoinClause>> expectedEquiCriteria, Optional<Expression> filter, PlanMatchPattern left, PlanMatchPattern right)
{
return node(MergeJoinNode.class, left, right).with(
new MergeJoinMatcher(
joinType,
expectedEquiCriteria,
filter));
}
public static PlanMatchPattern unnest(PlanMatchPattern source)
{
return node(UnnestNode.class, source);
}
public static PlanMatchPattern unnest(Map<String, List<String>> unnestVariables, PlanMatchPattern source)
{
return node(UnnestNode.class, source).with(new UnnestMatcher(unnestVariables));
}
public static PlanMatchPattern exchange(PlanMatchPattern... sources)
{
return node(ExchangeNode.class, sources);
}
public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, PlanMatchPattern... sources)
{
return exchange(scope, type, ImmutableList.of(), sources);
}
public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, List<Ordering> orderBy, PlanMatchPattern... sources)
{
return exchange(scope, type, orderBy, ImmutableSet.of(), sources);
}
public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, List<Ordering> orderBy, Set<String> partitionedBy, PlanMatchPattern... sources)
{
return node(ExchangeNode.class, sources)
.with(new ExchangeMatcher(scope, type, orderBy, partitionedBy));
}
public static PlanMatchPattern union(PlanMatchPattern... sources)
{
return node(UnionNode.class, sources);
}
public static PlanMatchPattern assignUniqueId(String uniqueSymbolAlias, PlanMatchPattern source)
{
return node(AssignUniqueId.class, source)
.withAlias(uniqueSymbolAlias, new AssignUniqueIdMatcher());
}
public static PlanMatchPattern intersect(PlanMatchPattern... sources)
{
return node(IntersectNode.class, sources);
}
public static PlanMatchPattern except(PlanMatchPattern... sources)
{
return node(ExceptNode.class, sources);
}
public static ExpectedValueProvider<EquiJoinClause> equiJoinClause(String left, String right)
{
return new EquiJoinClauseProvider(new SymbolAlias(left), new SymbolAlias(right));
}
public static SymbolAlias symbol(String alias)
{
return new SymbolAlias(alias);
}
public static PlanMatchPattern filter(String expectedPredicate, PlanMatchPattern source)
{
return filter(rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedPredicate)), source);
}
public static PlanMatchPattern filterWithDecimal(String expectedPredicate, PlanMatchPattern source)
{
return filter(rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(expectedPredicate, new ParsingOptions(DecimalLiteralTreatment.AS_DECIMAL))), source);
}
public static PlanMatchPattern filter(Expression expectedPredicate, PlanMatchPattern source)
{
return node(FilterNode.class, source).with(new FilterMatcher(expectedPredicate));
}
public static PlanMatchPattern filter(PlanMatchPattern source)
{
return node(FilterNode.class, source);
}
public static PlanMatchPattern apply(List<String> correlationSymbolAliases, Map<String, ExpressionMatcher> subqueryAssignments, PlanMatchPattern inputPattern, PlanMatchPattern subqueryPattern)
{
PlanMatchPattern result = node(ApplyNode.class, inputPattern, subqueryPattern)
.with(new CorrelationMatcher(correlationSymbolAliases));
subqueryAssignments.entrySet().forEach(
assignment -> result.withAlias(assignment.getKey(), assignment.getValue()));
return result;
}
public static PlanMatchPattern lateral(List<String> correlationSymbolAliases, PlanMatchPattern inputPattern, PlanMatchPattern subqueryPattern)
{
return node(LateralJoinNode.class, inputPattern, subqueryPattern)
.with(new CorrelationMatcher(correlationSymbolAliases));
}
public static PlanMatchPattern groupingSet(List<List<String>> groups, Map<String, String> identityMappings, String groupIdAlias, PlanMatchPattern source)
{
return node(GroupIdNode.class, source).with(new GroupIdMatcher(groups, identityMappings, groupIdAlias));
}
public static PlanMatchPattern groupingSet(List<List<String>> groups, Map<String, String> identityMappings, String groupIdAlias, Map<String, ExpressionMatcher> groupingColumns, PlanMatchPattern source)
{
PlanMatchPattern result = node(GroupIdNode.class, source).with(new GroupIdMatcher(groups, identityMappings, groupIdAlias));
groupingColumns.entrySet().forEach(
groupingColumn -> result.withAlias(groupingColumn.getKey(), groupingColumn.getValue()));
return result;
}
private static PlanMatchPattern values(
Map<String, Integer> aliasToIndex,
Optional<Integer> expectedOutputSymbolCount,
Optional<List<List<Expression>>> expectedRows)
{
return node(ValuesNode.class).with(new ValuesMatcher(aliasToIndex, expectedOutputSymbolCount, expectedRows));
}
private static PlanMatchPattern values(List<String> aliases, Optional<List<List<Expression>>> expectedRows)
{
return values(
Maps.uniqueIndex(IntStream.range(0, aliases.size()).boxed().iterator(), aliases::get),
Optional.of(aliases.size()),
expectedRows);
}
public static PlanMatchPattern values(Map<String, Integer> aliasToIndex)
{
return values(aliasToIndex, Optional.empty(), Optional.empty());
}
public static PlanMatchPattern values(String... aliases)
{
return values(ImmutableList.copyOf(aliases));
}
public static PlanMatchPattern values(List<String> aliases, List<List<Expression>> expectedRows)
{
return values(aliases, Optional.of(expectedRows));
}
public static PlanMatchPattern values(List<String> aliases)
{
return values(aliases, Optional.empty());
}
public static PlanMatchPattern offset(long rowCount, PlanMatchPattern source)
{
return node(OffsetNode.class, source).with(new OffsetMatcher(rowCount));
}
public static PlanMatchPattern limit(long limit, PlanMatchPattern source)
{
return limit(limit, false, source);
}
public static PlanMatchPattern limit(long limit, boolean partial, PlanMatchPattern source)
{
return node(LimitNode.class, source).with(new LimitMatcher(limit, partial));
}
public static PlanMatchPattern enforceSingleRow(PlanMatchPattern source)
{
return node(EnforceSingleRowNode.class, source);
}
public static PlanMatchPattern tableWriter(List<String> columns, List<String> columnNames, PlanMatchPattern source)
{
return node(TableWriterNode.class, source).with(new TableWriterMatcher(columns, columnNames));
}
public static PlanMatchPattern remoteSource(List<PlanFragmentId> sourceFragmentIds, Map<String, Integer> outputSymbolAliases)
{
return node(RemoteSourceNode.class).with(new RemoteSourceMatcher(sourceFragmentIds, outputSymbolAliases));
}
public PlanMatchPattern(List<PlanMatchPattern> sourcePatterns)
{
requireNonNull(sourcePatterns, "sourcePatterns are null");
this.sourcePatterns = ImmutableList.copyOf(sourcePatterns);
}
List<PlanMatchingState> shapeMatches(PlanNode node)
{
ImmutableList.Builder<PlanMatchingState> states = ImmutableList.builder();
if (anyTree) {
int sourcesCount = node.getSources().size();
if (sourcesCount > 1) {
states.add(new PlanMatchingState(nCopies(sourcesCount, this)));
}
else {
states.add(new PlanMatchingState(ImmutableList.of(this)));
}
}
if (node instanceof GroupReference) {
if (sourcePatterns.isEmpty() && shapeMatchesMatchers(node)) {
states.add(new PlanMatchingState(ImmutableList.of()));
}
}
else if (node.getSources().size() == sourcePatterns.size() && shapeMatchesMatchers(node)) {
states.add(new PlanMatchingState(sourcePatterns));
}
return states.build();
}
private boolean shapeMatchesMatchers(PlanNode node)
{
return matchers.stream().allMatch(it -> it.shapeMatches(node));
}
MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases)
{
SymbolAliases.Builder newAliases = SymbolAliases.builder();
for (Matcher matcher : matchers) {
MatchResult matchResult;
if (matcher instanceof AggregationMatcher) {
matchResult = matcher.detailMatches(node, stats, session, metadata, symbolAliases.withNewAliases(newAliases.build()));
}
else {
matchResult = matcher.detailMatches(node, stats, session, metadata, symbolAliases);
}
if (!matchResult.isMatch()) {
return NO_MATCH;
}
newAliases.putAll(matchResult.getAliases());
}
return match(newAliases.build());
}
public PlanMatchPattern with(Matcher matcher)
{
matchers.add(matcher);
return this;
}
public PlanMatchPattern withAlias(String alias)
{
return withAlias(Optional.of(alias), new AliasPresent(alias));
}
public PlanMatchPattern withAlias(String alias, RvalueMatcher matcher)
{
return withAlias(Optional.of(alias), matcher);
}
public PlanMatchPattern withAlias(Optional<String> alias, RvalueMatcher matcher)
{
matchers.add(new AliasMatcher(alias, matcher));
return this;
}
public PlanMatchPattern withNumberOfOutputColumns(int numberOfSymbols)
{
matchers.add(new SymbolCardinalityMatcher(numberOfSymbols));
return this;
}
/*
* This is useful if you already know the bindings for the aliases you expect to find
* in the outputs. This is the case for symbols that are produced by a direct or indirect
* source of the node you're applying this to.
*/
public PlanMatchPattern withExactOutputs(String... expectedAliases)
{
return withExactOutputs(ImmutableList.copyOf(expectedAliases));
}
public PlanMatchPattern withExactOutputs(List<String> expectedAliases)
{
matchers.add(new StrictSymbolsMatcher(actualOutputs(), expectedAliases));
return this;
}
/*
* withExactAssignments and withExactAssignedOutputs are needed for matching symbols
* that are produced in the node that you're matching. The name of the symbol bound to
* the alias is *not* known when the Matcher is run, and so you need to match by what
* is being assigned to it.
*/
public PlanMatchPattern withExactAssignedOutputs(RvalueMatcher... expectedAliases)
{
return withExactAssignedOutputs(ImmutableList.copyOf(expectedAliases));
}
public PlanMatchPattern withExactAssignedOutputs(Collection<? extends RvalueMatcher> expectedAliases)
{
matchers.add(new StrictAssignedSymbolsMatcher(actualOutputs(), expectedAliases));
return this;
}
public PlanMatchPattern withExactAssignments(Collection<? extends RvalueMatcher> expectedAliases)
{
matchers.add(new StrictAssignedSymbolsMatcher(actualAssignments(), expectedAliases));
return this;
}
public PlanMatchPattern withOutputRowCount(double expectedOutputRowCount)
{
matchers.add(new StatsOutputRowCountMatcher(expectedOutputRowCount));
return this;
}
public PlanMatchPattern withSourceInfo(SourceInfo sourceInfo)
{
matchers.add(new StatsSourceInfoMatcher(sourceInfo));
return this;
}
public PlanMatchPattern withConfidenceLevel(ConfidenceLevel confidenceLevel)
{
matchers.add(new StatsConfidenceLevelMatcher(confidenceLevel));
return this;
}
public PlanMatchPattern withOutputRowCount(double expectedOutputRowCount, String expectedSourceInfo)
{
matchers.add(new StatsOutputRowCountMatcher(expectedOutputRowCount, expectedSourceInfo));
return this;
}
public PlanMatchPattern withOutputRowCount(boolean exactMatch, String expectedSourceInfo)
{
matchers.add(new StatsOutputRowCountMatcher(exactMatch, expectedSourceInfo));
return this;
}
public PlanMatchPattern withApproximateOutputRowCount(double expectedOutputRowCount, double error)
{
matchers.add(new ApproximateStatsOutputRowCountMatcher(expectedOutputRowCount, error));
return this;
}
public PlanMatchPattern withOutputSize(double expectedOutputSize)
{
matchers.add(new StatsOutputSizeMatcher(expectedOutputSize));
return this;
}
public PlanMatchPattern withJoinStatistics(double expectedJoinBuildKeyCount, double expectedNullJoinBuildKeyCount, double expectedJoinProbeKeyCount, double expectedNullJoinProbeKeyCount)
{
matchers.add(new StatsJoinKeyCountMatcher(expectedJoinBuildKeyCount, expectedNullJoinBuildKeyCount, expectedJoinProbeKeyCount, expectedNullJoinProbeKeyCount));
return this;
}
public static RvalueMatcher columnReference(String tableName, String columnName)
{
return new ColumnReference(tableName, columnName);
}
public static ExpressionMatcher expression(String expression)
{
return new ExpressionMatcher(expression);
}
public static ExpressionMatcher expression(String expression, ParsingOptions.DecimalLiteralTreatment decimalLiteralTreatment)
{
return new ExpressionMatcher(expression, decimalLiteralTreatment);
}
public static ExpressionMatcher expression(Expression expression)
{
return new ExpressionMatcher(expression);
}
public PlanMatchPattern withOutputs(String... aliases)
{
return withOutputs(ImmutableList.copyOf(aliases));
}
public PlanMatchPattern withOutputs(List<String> aliases)
{
matchers.add(new OutputMatcher(aliases));
return this;
}
public PlanMatchPattern matchToAnyNodeTree()
{
anyTree = true;
return this;
}
public boolean isTerminated()
{
return sourcePatterns.isEmpty();
}
public static PlanTestSymbol anySymbol()
{
return new AnySymbol();
}
public static ExpectedValueProvider<FunctionCall> functionCall(String name, List<String> args)
{
return new FunctionCallProvider(QualifiedName.of(name), toSymbolAliases(args));
}
public static ExpectedValueProvider<FunctionCall> functionCall(String name, List<String> args, List<Ordering> orderBy)
{
return new FunctionCallProvider(QualifiedName.of(name), toSymbolAliases(args), orderBy);
}
public static ExpectedValueProvider<FunctionCall> functionCall(
String name,
Optional<WindowFrame> frame,
List<String> args)
{
return new FunctionCallProvider(QualifiedName.of(name), frame, false, toSymbolAliases(args));
}
public static ExpectedValueProvider<FunctionCall> functionCall(
String name,
boolean distinct,
List<PlanTestSymbol> args)
{
return new FunctionCallProvider(QualifiedName.of(name), distinct, args);
}
public static List<Expression> toSymbolReferences(List<PlanTestSymbol> aliases, SymbolAliases symbolAliases)
{
return aliases
.stream()
.map(arg -> arg.toSymbol(symbolAliases).toSymbolReference())
.collect(toImmutableList());
}
private static List<PlanTestSymbol> toSymbolAliases(List<String> aliases)
{
return aliases
.stream()
.map(PlanMatchPattern::symbol)
.collect(toImmutableList());
}
public static ExpectedValueProvider<DataOrganizationSpecification> specification(
List<String> partitionBy,
List<String> orderBy,
Map<String, SortOrder> orderings)
{
return new SpecificationProvider(
partitionBy
.stream()
.map(SymbolAlias::new)
.collect(toImmutableList()),
orderBy
.stream()
.map(SymbolAlias::new)
.collect(toImmutableList()),
orderings
.entrySet()
.stream()
.collect(toImmutableMap(entry -> new SymbolAlias(entry.getKey()), Map.Entry::getValue)));
}
public static Ordering sort(String field, SortItem.Ordering ordering, SortItem.NullOrdering nullOrdering)
{
return new Ordering(field, ordering, nullOrdering);
}
@Override
public String toString()
{
StringBuilder builder = new StringBuilder();
toString(builder, 0);
return builder.toString();
}
private void toString(StringBuilder builder, int indent)
{
checkState(matchers.stream().filter(PlanNodeMatcher.class::isInstance).count() <= 1);
builder.append(indentString(indent)).append("- ");
if (anyTree) {
builder.append("anyTree");
}
else {
builder.append("node");
}
Optional<PlanNodeMatcher> planNodeMatcher = matchers.stream()
.filter(PlanNodeMatcher.class::isInstance)
.map(PlanNodeMatcher.class::cast)
.findFirst();
if (planNodeMatcher.isPresent()) {
builder.append("(").append(planNodeMatcher.get().getNodeClass().getSimpleName()).append(")");
}
builder.append("\n");
List<Matcher> matchersToPrint = matchers.stream()
.filter(matcher -> !(matcher instanceof PlanNodeMatcher))
.collect(toImmutableList());
for (Matcher matcher : matchersToPrint) {
builder.append(indentString(indent + 1)).append(matcher.toString()).append("\n");
}
for (PlanMatchPattern pattern : sourcePatterns) {
pattern.toString(builder, indent + 1);
}
}
private static String indentString(int indent)
{
return Strings.repeat(" ", indent);
}
public static GroupingSetDescriptor globalAggregation()
{
return singleGroupingSet();
}
public static GroupingSetDescriptor singleGroupingSet(String... groupingKeys)
{
return singleGroupingSet(ImmutableList.copyOf(groupingKeys));
}
public static GroupingSetDescriptor singleGroupingSet(List<String> groupingKeys)
{
Set<Integer> globalGroupingSets;
if (groupingKeys.isEmpty()) {
globalGroupingSets = ImmutableSet.of(0);
}
else {
globalGroupingSets = ImmutableSet.of();
}
return new GroupingSetDescriptor(groupingKeys, 1, globalGroupingSets);
}
public static class GroupingSetDescriptor
{
private final List<String> groupingKeys;
private final int groupingSetCount;
private final Set<Integer> globalGroupingSets;
public GroupingSetDescriptor(List<String> groupingKeys, int groupingSetCount, Set<Integer> globalGroupingSets)
{
this.groupingKeys = groupingKeys;
this.groupingSetCount = groupingSetCount;
this.globalGroupingSets = globalGroupingSets;
}
public List<String> getGroupingKeys()
{
return groupingKeys;
}
public int getGroupingSetCount()
{
return groupingSetCount;
}
public Set<Integer> getGlobalGroupingSets()
{
return globalGroupingSets;
}
@Override
public String toString()
{
return toStringHelper(this)
.add("keys", groupingKeys)
.add("count", groupingSetCount)
.add("globalSets", globalGroupingSets)
.toString();
}
}
public static class Ordering
{
private final String field;
private final SortItem.Ordering ordering;
private final SortItem.NullOrdering nullOrdering;
private Ordering(String field, SortItem.Ordering ordering, SortItem.NullOrdering nullOrdering)
{
this.field = field;
this.ordering = ordering;
this.nullOrdering = nullOrdering;
}
public String getField()
{
return field;
}
public SortItem.Ordering getOrdering()
{
return ordering;
}
public SortItem.NullOrdering getNullOrdering()
{
return nullOrdering;
}
public SortOrder getSortOrder()
{
checkState(nullOrdering != UNDEFINED, "nullOrdering is undefined");
if (ordering == ASCENDING) {
if (nullOrdering == FIRST) {
return ASC_NULLS_FIRST;
}
else {
return ASC_NULLS_LAST;
}
}
else {
checkState(ordering == DESCENDING);
if (nullOrdering == FIRST) {
return DESC_NULLS_FIRST;
}
else {
return DESC_NULLS_LAST;
}
}
}
@Override
public String toString()
{
String result = field + " " + ordering;
if (nullOrdering != UNDEFINED) {
result += " NULLS " + nullOrdering;
}
return result;
}
}
}