RowExpressionTreeRewriter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.expressions;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.ExistsExpression;
import com.facebook.presto.spi.relation.InSubqueryExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.QuantifiedComparisonExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionVisitor;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.UnresolvedSymbolExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
public final class RowExpressionTreeRewriter<C>
{
// Callback object consulted first by every visit method, unless a default rewrite was requested for that node.
private final RowExpressionRewriter<C> rewriter;
// Internal traversal visitor; a RewritingVisitor instance created once in the constructor.
private final RowExpressionVisitor<RowExpression, Context<C>> visitor;
/**
 * Rewrites {@code node} using a freshly created tree rewriter built around {@code rewriter}.
 * The user context handed to the rewriter is {@code null}.
 *
 * @param rewriter callback object consulted for each visited node
 * @param node root of the expression tree to rewrite
 * @return the result of rewriting {@code node}
 */
public static <C, T extends RowExpression> T rewriteWith(RowExpressionRewriter<C> rewriter, T node)
{
    RowExpressionTreeRewriter<C> treeRewriter = new RowExpressionTreeRewriter<>(rewriter);
    return treeRewriter.rewrite(node, null);
}
/**
 * Rewrites {@code node} using a freshly created tree rewriter built around {@code rewriter},
 * passing {@code context} as the user context.
 *
 * @param rewriter callback object consulted for each visited node
 * @param node root of the expression tree to rewrite
 * @param context user context forwarded to the rewriter callbacks
 * @return the result of rewriting {@code node}
 */
public static <C, T extends RowExpression> T rewriteWith(RowExpressionRewriter<C> rewriter, T node, C context)
{
    RowExpressionTreeRewriter<C> treeRewriter = new RowExpressionTreeRewriter<>(rewriter);
    return treeRewriter.rewrite(node, context);
}
/**
 * Creates a tree rewriter whose internal visitor consults {@code rewriter} for each visited node.
 *
 * @param rewriter callback object invoked by the visit methods; must not be null
 * @throws NullPointerException if {@code rewriter} is null
 */
public RowExpressionTreeRewriter(RowExpressionRewriter<C> rewriter)
{
    // Fail fast: without this check a null rewriter would only surface later,
    // as a NullPointerException thrown from inside a visit method.
    this.rewriter = java.util.Objects.requireNonNull(rewriter, "rewriter is null");
    this.visitor = new RewritingVisitor();
}
/**
 * Rewrites every expression in {@code items}, in iteration order, and returns the results
 * wrapped in an unmodifiable list.
 *
 * Only the user context is taken from {@code context}; each element goes through the public
 * {@code rewrite(node, context)} entry point, so the default-rewrite flag of the parent
 * is not propagated to the children.
 */
private List<RowExpression> rewrite(List<RowExpression> items, Context<C> context)
{
    C userContext = context.get();
    List<RowExpression> results = new ArrayList<>();
    for (RowExpression item : items) {
        RowExpression rewrittenItem = rewrite(item, userContext);
        results.add(rewrittenItem);
    }
    return Collections.unmodifiableList(results);
}
/**
 * Rewrites {@code node} by dispatching it to the internal visitor with the default-rewrite
 * flag cleared, so the visit method for the node consults the rewriter first.
 *
 * The visitor result is cast to {@code T} without a check.
 *
 * @param node expression to rewrite
 * @param context user context forwarded to the rewriter callbacks
 * @return the visitor result for {@code node}
 */
@SuppressWarnings("unchecked")
public <T extends RowExpression> T rewrite(T node, C context)
{
    Context<C> visitContext = new Context<>(context, false);
    RowExpression rewritten = node.accept(visitor, visitContext);
    return (T) rewritten;
}
/**
 * Runs the default rewrite logic for {@code node} explicitly: the rewriter callback is skipped
 * for the provided node itself, because the node is visited with the default-rewrite flag set.
 *
 * The visitor result is cast to {@code T} without a check.
 *
 * @param node expression to rewrite
 * @param context user context made available to the visit methods
 * @return the visitor result for {@code node}
 */
@SuppressWarnings("unchecked")
public <T extends RowExpression> T defaultRewrite(T node, C context)
{
    Context<C> visitContext = new Context<>(context, true);
    RowExpression rewritten = node.accept(visitor, visitContext);
    return (T) rewritten;
}
private class RewritingVisitor
implements RowExpressionVisitor<RowExpression, Context<C>>
{
/**
 * Visits an input reference. Unless a default rewrite was requested, the rewriter is asked
 * for a replacement; a non-null answer is returned as-is, otherwise the node is kept.
 */
@Override
public RowExpression visitInputReference(InputReferenceExpression input, Context<C> context)
{
    if (context.isDefaultRewrite()) {
        return input;
    }
    RowExpression replacement = rewriter.rewriteInputReference(input, context.get(), RowExpressionTreeRewriter.this);
    return replacement != null ? replacement : input;
}
/**
 * Visits a call. Unless a default rewrite was requested, a non-null replacement from the
 * rewriter wins. Otherwise the arguments are rewritten, and a new call is built only when
 * at least one argument instance changed.
 */
@Override
public RowExpression visitCall(CallExpression call, Context<C> context)
{
    if (!context.isDefaultRewrite()) {
        RowExpression replacement = rewriter.rewriteCall(call, context.get(), RowExpressionTreeRewriter.this);
        if (replacement != null) {
            return replacement;
        }
    }

    List<RowExpression> rewrittenArguments = rewrite(call.getArguments(), context);
    if (sameElements(call.getArguments(), rewrittenArguments)) {
        // Every argument is the identical instance: reuse the original node.
        return call;
    }
    return new CallExpression(call.getSourceLocation(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), rewrittenArguments);
}
/**
 * Visits a constant. Unless a default rewrite was requested, the rewriter is asked for a
 * replacement; a non-null answer is returned as-is, otherwise the constant is kept.
 */
@Override
public RowExpression visitConstant(ConstantExpression literal, Context<C> context)
{
    if (context.isDefaultRewrite()) {
        return literal;
    }
    RowExpression replacement = rewriter.rewriteConstant(literal, context.get(), RowExpressionTreeRewriter.this);
    return replacement != null ? replacement : literal;
}
/**
 * Visits a lambda definition. Unless a default rewrite was requested, a non-null replacement
 * from the rewriter wins. Otherwise the body is rewritten, and a new lambda is built only
 * when the rewritten body is a different instance.
 */
@Override
public RowExpression visitLambda(LambdaDefinitionExpression lambda, Context<C> context)
{
    if (!context.isDefaultRewrite()) {
        RowExpression replacement = rewriter.rewriteLambda(lambda, context.get(), RowExpressionTreeRewriter.this);
        if (replacement != null) {
            return replacement;
        }
    }

    RowExpression rewrittenBody = rewrite(lambda.getBody(), context.get());
    if (rewrittenBody == lambda.getBody()) {
        // Body instance unchanged: reuse the original node.
        return lambda;
    }
    return new LambdaDefinitionExpression(lambda.getSourceLocation(), lambda.getArgumentTypes(), lambda.getArguments(), rewrittenBody);
}
/**
 * Visits a variable reference. Unless a default rewrite was requested, the rewriter is asked
 * for a replacement; a non-null answer is returned as-is, otherwise the variable is kept.
 */
@Override
public RowExpression visitVariableReference(VariableReferenceExpression variable, Context<C> context)
{
    if (context.isDefaultRewrite()) {
        return variable;
    }
    RowExpression replacement = rewriter.rewriteVariableReference(variable, context.get(), RowExpressionTreeRewriter.this);
    return replacement != null ? replacement : variable;
}
/**
 * Visits a special form. Unless a default rewrite was requested, a non-null replacement from
 * the rewriter wins. Otherwise the arguments are rewritten, and a new special form is built
 * only when at least one argument instance changed.
 *
 * The rebuilt node carries the original node's source location, form and type.
 */
@Override
public RowExpression visitSpecialForm(SpecialFormExpression specialForm, Context<C> context)
{
    if (!context.isDefaultRewrite()) {
        RowExpression result = rewriter.rewriteSpecialForm(specialForm, context.get(), RowExpressionTreeRewriter.this);
        if (result != null) {
            return result;
        }
    }

    List<RowExpression> arguments = rewrite(specialForm.getArguments(), context);
    if (!sameElements(specialForm.getArguments(), arguments)) {
        // Forward the source location explicitly, as visitCall and visitLambda do when they
        // rebuild a node; previously it was not passed to the constructor here.
        // NOTE(review): relies on the SpecialFormExpression constructor overload that takes the
        // source location as its first argument - confirm against the SPI version in use.
        return new SpecialFormExpression(specialForm.getSourceLocation(), specialForm.getForm(), specialForm.getType(), arguments);
    }
    return specialForm;
}
/**
 * Visits an IN-subquery expression. Unless a default rewrite was requested, a non-null
 * replacement from the generic {@code rewriteRowExpression} callback wins. Otherwise the value
 * and subquery references are rewritten, and a new node is built only when either instance changed.
 */
@Override
public RowExpression visitInSubqueryExpression(InSubqueryExpression inSubqueryExpression, Context<C> context)
{
    if (!context.isDefaultRewrite()) {
        RowExpression replacement = rewriter.rewriteRowExpression(inSubqueryExpression, context.get(), RowExpressionTreeRewriter.this);
        if (replacement != null) {
            return replacement;
        }
    }

    // NOTE(review): both results are assigned to VariableReferenceExpression, so a rewriter that
    // maps these children to another kind of expression would fail with a ClassCastException here.
    VariableReferenceExpression rewrittenValue = rewrite(inSubqueryExpression.getValue(), context.get());
    VariableReferenceExpression rewrittenSubquery = rewrite(inSubqueryExpression.getSubquery(), context.get());
    if (inSubqueryExpression.getValue() == rewrittenValue && inSubqueryExpression.getSubquery() == rewrittenSubquery) {
        // Neither child instance changed: reuse the original node.
        return inSubqueryExpression;
    }
    return new InSubqueryExpression(inSubqueryExpression.getSourceLocation(), rewrittenValue, rewrittenSubquery);
}
/**
 * Visits a quantified comparison. Unless a default rewrite was requested, a non-null
 * replacement from the generic {@code rewriteRowExpression} callback wins. Otherwise the value
 * and subquery are rewritten, and a new node (same source location, operator and quantifier)
 * is built only when either child instance changed.
 */
@Override
public RowExpression visitQuantifiedComparisonExpression(QuantifiedComparisonExpression quantifiedComparisonExpression, Context<C> context)
{
    if (!context.isDefaultRewrite()) {
        RowExpression replacement = rewriter.rewriteRowExpression(quantifiedComparisonExpression, context.get(), RowExpressionTreeRewriter.this);
        if (replacement != null) {
            return replacement;
        }
    }

    RowExpression rewrittenValue = rewrite(quantifiedComparisonExpression.getValue(), context.get());
    RowExpression rewrittenSubquery = rewrite(quantifiedComparisonExpression.getSubquery(), context.get());
    if (quantifiedComparisonExpression.getValue() == rewrittenValue && quantifiedComparisonExpression.getSubquery() == rewrittenSubquery) {
        // Neither child instance changed: reuse the original node.
        return quantifiedComparisonExpression;
    }
    return new QuantifiedComparisonExpression(
            quantifiedComparisonExpression.getSourceLocation(),
            quantifiedComparisonExpression.getOperator(),
            quantifiedComparisonExpression.getQuantifier(),
            rewrittenValue,
            rewrittenSubquery);
}
/**
 * Visits an EXISTS expression. Unless a default rewrite was requested, a non-null replacement
 * from the generic {@code rewriteRowExpression} callback wins. Otherwise the subquery is
 * rewritten, and a new node is built only when the subquery instance changed.
 */
@Override
public RowExpression visitExistsExpression(ExistsExpression existsExpression, Context<C> context)
{
    if (!context.isDefaultRewrite()) {
        RowExpression replacement = rewriter.rewriteRowExpression(existsExpression, context.get(), RowExpressionTreeRewriter.this);
        if (replacement != null) {
            return replacement;
        }
    }

    RowExpression rewrittenSubquery = rewrite(existsExpression.getSubquery(), context.get());
    if (existsExpression.getSubquery() == rewrittenSubquery) {
        // Subquery instance unchanged: reuse the original node.
        return existsExpression;
    }
    return new ExistsExpression(existsExpression.getSourceLocation(), rewrittenSubquery);
}
/**
 * Visits an unresolved symbol. Unless a default rewrite was requested, the generic
 * {@code rewriteRowExpression} callback is asked for a replacement; a non-null answer is
 * returned as-is, otherwise the node is kept.
 */
@Override
public RowExpression visitUnresolvedSymbolExpression(UnresolvedSymbolExpression unresolvedSymbolExpression, Context<C> context)
{
    if (context.isDefaultRewrite()) {
        return unresolvedSymbolExpression;
    }
    RowExpression replacement = rewriter.rewriteRowExpression(unresolvedSymbolExpression, context.get(), RowExpressionTreeRewriter.this);
    return replacement != null ? replacement : unresolvedSymbolExpression;
}
}
/**
 * Immutable value passed to the visit methods: the caller's context object together with a
 * flag telling the visit method whether to skip the rewriter callback for the visited node.
 * Instances are created only by the enclosing class (the constructor is private).
 */
public static class Context<C>
{
    private final C userContext;
    private final boolean skipRewriter;

    private Context(C userContext, boolean skipRewriter)
    {
        this.skipRewriter = skipRewriter;
        this.userContext = userContext;
    }

    /**
     * @return the caller-supplied context object exactly as it was given (may be null)
     */
    public C get()
    {
        return this.userContext;
    }

    /**
     * @return true when the visit was started as a default rewrite
     */
    public boolean isDefaultRewrite()
    {
        return this.skipRewriter;
    }
}
/**
 * Tells whether two collections hold the very same instances in the same iteration order.
 * Elements are compared by reference ({@code ==}), not with {@code equals}.
 *
 * @return false when the sizes differ or any pair of elements is not the same instance; true otherwise
 */
@SuppressWarnings("ObjectEquality")
private static <T> boolean sameElements(Collection<? extends T> a, Collection<? extends T> b)
{
    if (a.size() == b.size()) {
        Iterator<? extends T> left = a.iterator();
        Iterator<? extends T> right = b.iterator();
        while (left.hasNext() && right.hasNext()) {
            T leftElement = left.next();
            T rightElement = right.next();
            if (leftElement != rightElement) {
                return false;
            }
        }
        return true;
    }
    return false;
}
}