CanonicalRowExpressionRewriter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.expressions;

import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;

public class CanonicalRowExpressionRewriter
        extends RowExpressionRewriter<Void>
{
    private final boolean removeConstants;

    private CanonicalRowExpressionRewriter(boolean removeConstants)
    {
        this.removeConstants = removeConstants;
    }

    public static RowExpression canonicalizeRowExpression(RowExpression expression, boolean removeConstants)
    {
        return RowExpressionTreeRewriter.rewriteWith(new CanonicalRowExpressionRewriter(removeConstants), expression, null);
    }

    @Override
    public RowExpression rewriteInputReference(InputReferenceExpression input, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        return input.canonicalize();
    }

    @Override
    public RowExpression rewriteCall(CallExpression call, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        List<RowExpression> arguments = rewrite(call.getArguments(), context, treeRewriter);

        if (!sameElements(call.getArguments(), arguments)) {
            return new CallExpression(Optional.empty(), call.getDisplayName(), call.getFunctionHandle(), call.getType(), arguments);
        }
        return call.canonicalize();
    }

    @Override
    public RowExpression rewriteConstant(ConstantExpression literal, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        if (!removeConstants) {
            return literal.canonicalize();
        }
        // We replace the constant value with null.
        return new ConstantExpression(null, literal.getType());
    }

    @Override
    public RowExpression rewriteLambda(LambdaDefinitionExpression lambda, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        RowExpression body = treeRewriter.rewrite(lambda.getBody(), context);
        if (body != lambda.getBody()) {
            return new LambdaDefinitionExpression(Optional.empty(), lambda.getArgumentTypes(), lambda.getArguments(), body);
        }

        return lambda.canonicalize();
    }

    @Override
    public RowExpression rewriteVariableReference(VariableReferenceExpression variable, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        return variable.canonicalize();
    }

    @Override
    public RowExpression rewriteSpecialForm(SpecialFormExpression specialForm, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        List<RowExpression> arguments = rewrite(specialForm.getArguments(), context, treeRewriter);

        if (!sameElements(specialForm.getArguments(), arguments)) {
            return new SpecialFormExpression(Optional.empty(), specialForm.getForm(), specialForm.getType(), arguments);
        }
        return specialForm.canonicalize();
    }

    private List<RowExpression> rewrite(List<RowExpression> items, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
    {
        List<RowExpression> rewrittenExpressions = new ArrayList<>();
        for (RowExpression expression : items) {
            rewrittenExpressions.add(treeRewriter.rewrite(expression, context));
        }
        return Collections.unmodifiableList(rewrittenExpressions);
    }

    @SuppressWarnings("ObjectEquality")
    private static <T> boolean sameElements(Collection<? extends T> a, Collection<? extends T> b)
    {
        if (a.size() != b.size()) {
            return false;
        }

        Iterator<? extends T> first = a.iterator();
        Iterator<? extends T> second = b.iterator();

        while (first.hasNext() && second.hasNext()) {
            if (first.next() != second.next()) {
                return false;
            }
        }

        return true;
    }
}