RewriteCaseToMap.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableSet;

import java.lang.invoke.MethodHandle;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.SystemSessionProperties.REWRITE_CASE_TO_MAP_ENABLED;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.coalesce;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.constantNull;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

/**
 * Rule set that rewrites CASE expressions which map one expression to constant results into a map lookup, e.g.
 * {@code CASE x WHEN 1 THEN 'a' WHEN 2 THEN 'b' ELSE 'c' END} becomes
 * {@code COALESCE(element_at(MAP(ARRAY[1, 2], ARRAY['a', 'b']), x), IF(contains(ARRAY[1, 2], x), NULL, 'c'))}.
 * It is applied to project, filter and join expressions when the session property
 * {@code REWRITE_CASE_TO_MAP_ENABLED} is on.
 */
public class RewriteCaseToMap
        extends RowExpressionRewriteRuleSet
{
    public RewriteCaseToMap(FunctionAndTypeManager functionAndTypeManager)
    {
        super(new Rewriter(functionAndTypeManager));
    }

    /**
     * Plan-level adapter: rewrites every row expression it is given with a {@link CaseToMapRewriter}.
     */
    private static class Rewriter
            implements PlanRowExpressionRewriter
    {
        private final CaseToMapRewriter caseToMapRewriter;

        public Rewriter(FunctionAndTypeManager functionAndTypeManager)
        {
            requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
            this.caseToMapRewriter = new CaseToMapRewriter(functionAndTypeManager);
        }

        @Override
        public RowExpression rewrite(RowExpression expression, Rule.Context context)
        {
            return RowExpressionTreeRewriter.rewriteWith(caseToMapRewriter, expression);
        }
    }

    /**
     * Replaces a {@code SWITCH} special form whose branches all compare one common expression against
     * distinct non-null constants and return constants with an {@code element_at(MAP(...), expr)} lookup.
     * Any other expression is not handled here, so the tree rewriter continues into its children.
     */
    private static class CaseToMapRewriter
            extends RowExpressionRewriter<Void>
    {
        private final FunctionAndTypeManager functionAndTypeManager;
        private final FunctionResolution functionResolution;
        private final RowExpressionDeterminismEvaluator determinismEvaluator;
        private final LogicalRowExpressions logicalRowExpressions;

        private CaseToMapRewriter(FunctionAndTypeManager functionAndTypeManager)
        {
            this.functionAndTypeManager = functionAndTypeManager;
            this.functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());
            this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
            this.logicalRowExpressions = new LogicalRowExpressions(
                    determinismEvaluator,
                    functionResolution,
                    functionAndTypeManager);
        }

        /**
         * Records {@code key -> value} as one map entry. If {@code key} was already recorded, the earlier
         * entry is kept, matching CASE semantics where the first matching WHEN wins.
         *
         * @return {@code false} if the pair cannot be part of the map: the key is not a non-null constant,
         * its type differs from the first key's type, or (for a new key) the value type differs from the
         * first value's type
         */
        private boolean addKeyValue(RowExpression key, Set<RowExpression> keySet, List<RowExpression> keys, RowExpression value, List<RowExpression> values)
        {
            // matching types and non-null values only allowed
            if (!(key instanceof ConstantExpression) ||
                    ((ConstantExpression) key).getValue() == null ||
                    (keys.size() > 0 && !keys.get(0).getType().equals(key.getType()))) {
                return false;
            }

            if (keySet.add(key)) {
                // We allow all same type only
                if (values.size() > 0 && !values.get(0).getType().equals(value.getType())) {
                    return false;
                }

                keys.add(key);
                values.add(value);
            }

            return true;
        }

        @Override
        public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
        {
            if (node.getForm() == SWITCH) {
                RowExpression rewritten = tryRewriteSwitch(node);
                if (rewritten != null) {
                    return rewritten;
                }
            }
            // Not rewritten here: defer to the default handling so that children (which may contain
            // nested CASE expressions) are still visited.
            return rewriteRowExpression(node, context, treeRewriter);
        }

        /**
         * Attempts the CASE-to-map rewrite of a single SWITCH expression.
         *
         * @return the rewritten expression, or {@code null} if the SWITCH does not have a supported shape
         */
        private RowExpression tryRewriteSwitch(SpecialFormExpression node)
        {
            List<RowExpression> arguments = node.getArguments();
            int numArgs = arguments.size();

            // by construction we should have at least one WHEN and ELSE is always added if missing
            checkState(numArgs >= 2);
            RowExpression lastArg = arguments.get(numArgs - 1);
            checkState(!isWhen(lastArg));

            if (!(lastArg instanceof ConstantExpression)) {
                return null;
            }

            RowExpression firstArg = arguments.get(0);
            RowExpression checkExpr = null;
            int start = 0;
            if (!isWhen(firstArg)) {
                start = 1;
                // A searched CASE is generated as CASE (true) WHEN p1 THEN v1 ..., so a leading true is not an operand
                if (!firstArg.equals(constant(true, BOOLEAN))) {
                    checkExpr = firstArg;
                }
            }
            // In a simple CASE every WHEN operand is a value compared with the CASE operand; otherwise every WHEN
            // operand is a boolean predicate. They must be decoded differently: e.g. the predicate `WHEN false`
            // must not be mistaken for the map key `false`.
            boolean simpleCase = checkExpr != null;

            Set<RowExpression> keySet = new HashSet<>();
            List<RowExpression> whens = new ArrayList<>(numArgs);
            List<RowExpression> thens = new ArrayList<>(numArgs);

            for (int i = start; i < numArgs - 1; i++) {
                RowExpression whenClause = arguments.get(i);
                RowExpression value = whenClause.getChildren().get(1);
                if (!(value instanceof ConstantExpression)) {
                    // THEN is not a constant
                    return null;
                }

                RowExpression when = whenClause.getChildren().get(0);
                if (simpleCase) {
                    // CASE x WHEN k THEN v: only a constant k can become a map key
                    if (!(when instanceof ConstantExpression) || !addKeyValue(when, keySet, whens, value, thens)) {
                        return null;
                    }
                    continue;
                }

                // Searched CASE: the predicate must be `x = k`, `k = x` or `x IN (k1, k2, ...)`
                RowExpression curCheck;
                List<RowExpression> curKeys;
                if (logicalRowExpressions.isEqualsExpression(when)) {
                    RowExpression lhs = when.getChildren().get(0);
                    RowExpression rhs = when.getChildren().get(1);
                    if (!lhs.getType().equals(rhs.getType())) {
                        // We keep it simple
                        return null;
                    }
                    if (lhs instanceof ConstantExpression) {
                        curCheck = rhs;
                        curKeys = Collections.singletonList(lhs);
                    }
                    else if (rhs instanceof ConstantExpression) {
                        curCheck = lhs;
                        curKeys = Collections.singletonList(rhs);
                    }
                    else {
                        return null;
                    }
                }
                else if (when instanceof SpecialFormExpression && ((SpecialFormExpression) when).getForm() == IN) {
                    List<RowExpression> inArguments = ((SpecialFormExpression) when).getArguments();
                    curCheck = inArguments.get(0);
                    // For IN also we try to gather the args
                    curKeys = inArguments.subList(1, inArguments.size());
                }
                else {
                    return null;
                }

                // All predicates must test the same expression
                if (checkExpr == null) {
                    checkExpr = curCheck;
                }
                else if (!checkExpr.equals(curCheck)) {
                    return null;
                }

                for (RowExpression key : curKeys) {
                    if (!addKeyValue(key, keySet, whens, value, thens)) {
                        return null;
                    }
                }
            }

            if (checkExpr == null || whens.isEmpty()) {
                return null;
            }

            Type keyType = whens.get(0).getType();
            Type valueType = thens.get(0).getType();
            // element_at needs an index of the map's key type, and the rewrite must keep the CASE result type
            if (!keyType.equals(checkExpr.getType()) || !valueType.equals(node.getType())) {
                return null;
            }

            boolean elseIsNull = ((ConstantExpression) lastArg).getValue() == null;
            if (!elseIsNull && !lastArg.getType().equals(valueType)) {
                return null;
            }

            // The original evaluates the checked expression once per WHEN (searched CASE) or once (simple CASE);
            // the lookup evaluates it once, plus once more in `contains` when there is a non-NULL ELSE. Only a
            // simple CASE with a NULL ELSE matches that exactly; every other shape needs a deterministic expression.
            if (!(simpleCase && elseIsNull) && !determinismEvaluator.isDeterministic(checkExpr)) {
                return null;
            }

            // Here we have all values!
            RowExpression mapLookup = makeMapAndAccess(whens, thens, checkExpr);
            if (elseIsNull) {
                // A missing key already yields NULL from element_at
                return mapLookup;
            }

            // Null could be a legit value so we coalesce to the else part only if there was no key match
            RowExpression contains = call(functionAndTypeManager, "contains", BOOLEAN, makeArray(whens), checkExpr);
            return coalesce(mapLookup, specialForm(IF, valueType, contains, constantNull(valueType), lastArg));
        }

        /**
         * Returns whether {@code expression} is a WHEN special form.
         */
        private static boolean isWhen(RowExpression expression)
        {
            return expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == WHEN;
        }

        /**
         * Builds an ARRAY constructor call over {@code elements}, typed after the first element.
         * The caller guarantees {@code elements} is non-empty.
         */
        private RowExpression makeArray(List<RowExpression> elements)
        {
            List<Type> elementTypes = elements.stream().map(RowExpression::getType).collect(Collectors.toList());
            return call("ARRAY", functionResolution.arrayConstructor(elementTypes), new ArrayType(elements.get(0).getType()), elements);
        }

        /**
         * Builds {@code element_at(MAP(ARRAY[keys], ARRAY[values]), mapIndex)}.
         * The caller guarantees {@code keys} and {@code values} are non-empty, of equal size, and
         * homogeneously typed.
         */
        private RowExpression makeMapAndAccess(List<RowExpression> keys, List<RowExpression> values, RowExpression mapIndex)
        {
            Type keyType = keys.get(0).getType();
            Type valueType = values.get(0).getType();
            MethodHandle keyEquals =
                    functionAndTypeManager.getJavaScalarFunctionImplementation(
                            functionAndTypeManager.resolveOperator(OperatorType.EQUAL, fromTypes(keyType, keyType))).getMethodHandle();
            MethodHandle keyHashcode =
                    functionAndTypeManager.getJavaScalarFunctionImplementation(
                            functionAndTypeManager.resolveOperator(OperatorType.HASH_CODE, fromTypes(keyType))).getMethodHandle();
            RowExpression map = call(functionAndTypeManager, "MAP", new MapType(keyType, valueType, keyEquals, keyHashcode), makeArray(keys), makeArray(values));
            return call(functionAndTypeManager, "element_at", valueType, map, mapIndex);
        }
    }

    @Override
    public boolean isRewriterEnabled(Session session)
    {
        return session.getSystemProperty(REWRITE_CASE_TO_MAP_ENABLED, Boolean.class);
    }

    @Override
    public Set<Rule<?>> rules()
    {
        return ImmutableSet.of(
                projectRowExpressionRewriteRule(),
                filterRowExpressionRewriteRule(),
                joinRowExpressionRewriteRule());
    }
}