FunctionArgumentCheckerForAccessControlUtils.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.analyzer;

import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.sql.tree.DereferenceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.FunctionCall;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.SubscriptExpression;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.sql.tree.StackableAstVisitor.StackableAstVisitorContext;
import static com.google.common.base.Preconditions.checkState;
import static java.lang.Math.toIntExact;
import static java.util.Collections.reverse;

public class FunctionArgumentCheckerForAccessControlUtils
{
    private static final QualifiedName TRANSFORM = QualifiedName.of("transform");
    private static final QualifiedName CARDINALITY = QualifiedName.of("cardinality");

    private FunctionArgumentCheckerForAccessControlUtils() {}

    // Returns whether function argument at `argumentIndex` for function `node` needs to be checked
    // for column level access control.
    // For e.g., consider SQL `transform(arr, col -> col.x)`
    // Here, we only need to check for access of subfield `x` in column `arr` which is of type `Array<struct>`.
    // So we can just parse lambda and ignore the first argument for access checks.
    public static boolean isUnusedArgumentForAccessControl(FunctionCall node, int argumentIndex, ExpressionAnalyzer.Context context)
    {
        if (node.getName().equals(TRANSFORM)) {
            checkState(node.getArguments().size() == 2);
            return argumentIndex == 0;
        }
        if (node.getName().equals(CARDINALITY)) {
            checkState(node.getArguments().size() == 1);
            return argumentIndex == 0;
        }
        return false;
    }

    // Parses arguments of function `node` which are a lambda expression, and returns a map
    // of their lambda arguments to resolved subfield.
    // For e.g., consider SQL `SELECT transform(arr, col -> col.x) FROM table`
    // Return value = Map('col' -> ResolvedSubfield(table.arr))
    public static Map<Identifier, ResolvedSubfield> getResolvedLambdaArguments(
            FunctionCall node,
            StackableAstVisitorContext<ExpressionAnalyzer.Context> context,
            Map<NodeRef<Expression>, Type> expressionTypes)
    {
        ImmutableMap.Builder<Identifier, ResolvedSubfield> resolvedLambdaArguments = ImmutableMap.builder();
        if (node.getName().equals(TRANSFORM)) {
            checkState(node.getArguments().size() == 2);
            if (!(node.getArguments().get(1) instanceof LambdaExpression)) {
                return ImmutableMap.of();
            }
            Expression arrayExpression = node.getArguments().get(0);
            LambdaExpression lambdaExpression = ((LambdaExpression) node.getArguments().get(1));
            Optional<ResolvedSubfield> resolvedSubfield = resolveSubfield(arrayExpression, context, expressionTypes);
            if (resolvedSubfield.isPresent()) {
                resolvedLambdaArguments.put(
                        lambdaExpression.getArguments().get(0).getName(),
                        resolvedSubfield.get());
            }
        }
        return resolvedLambdaArguments.build();
    }

    public static Optional<ResolvedSubfield> resolveSubfield(
            Expression node,
            StackableAstVisitorContext<ExpressionAnalyzer.Context> context,
            Map<NodeRef<Expression>, Type> expressionTypes)
    {
        // If expression is nested with multiple dereferences and subscripts, we only look at the topmost one.
        if (!isTopMostReference(node, context)) {
            return Optional.empty();
        }

        Scope scope = context.getContext().getScope();
        Expression childNode = node;
        List<Subfield.PathElement> columnDereferences = new ArrayList<>();
        while (true) {
            // Dereference row/array/map expressions
            if (childNode instanceof SubscriptExpression) {
                SubscriptExpression subscriptExpression = (SubscriptExpression) childNode;
                childNode = subscriptExpression.getBase();
                Type baseType = expressionTypes.get(NodeRef.of(childNode));
                if (baseType == null || !(baseType instanceof RowType)) {
                    continue;
                }
                int index = toIntExact(((LongLiteral) subscriptExpression.getIndex()).getValue());
                RowType baseRowType = (RowType) baseType;
                Optional<String> dereference = baseRowType.getFields().get(index - 1).getName();
                if (!dereference.isPresent()) {
                    break;
                }
                columnDereferences.add(new Subfield.NestedField(dereference.get()));
                continue;
            }

            QualifiedName childQualifiedName;
            // Dereference subfield expressions
            if (childNode instanceof DereferenceExpression) {
                childQualifiedName = DereferenceExpression.getQualifiedName((DereferenceExpression) childNode);
            }
            // Base case
            else if (childNode instanceof Identifier) {
                childQualifiedName = QualifiedName.of(((Identifier) childNode).getValue());
            }
            else {
                break;
            }
            // If we found the full de-referenced expression, return it as a ResolvedSubfield
            if (childQualifiedName != null) {
                Optional<ResolvedField> resolvedField = scope.tryResolveField(childNode, childQualifiedName);
                if (resolvedField.isPresent() && !resolvedField.get().getField().getOriginTable().isPresent()) {
                    // Try to resolve using lambda expressions
                    Optional<ResolvedSubfield> resolvedSubField = Optional.ofNullable(context.getContext().getResolvedLambdaArguments().get(childNode));
                    if (resolvedSubField.isPresent()) {
                        resolvedField = Optional.of(resolvedSubField.get().getResolvedField());
                        columnDereferences.addAll(Lists.reverse(resolvedSubField.get().getSubfield().getPath()));
                    }
                }
                if (resolvedField.isPresent() &&
                        resolvedField.get().getField().getOriginColumnName().isPresent() &&
                        resolvedField.get().getField().getOriginTable().isPresent()) {
                    reverse(columnDereferences);
                    return Optional.of(new ResolvedSubfield(
                            resolvedField.get(),
                            new Subfield(resolvedField.get().getField().getOriginColumnName().get(), columnDereferences)));
                }
            }
            // If we cannot resolve full de-referenced name, that means that there are
            // more dereferences to be resolved, so we continue the while loop with new childNode.
            if (childNode instanceof DereferenceExpression) {
                columnDereferences.add(new Subfield.NestedField(((DereferenceExpression) childNode).getField().getValue()));
                childNode = ((DereferenceExpression) childNode).getBase();
                continue;
            }
            break;
        }
        return Optional.empty();
    }

    public static boolean isDereferenceOrSubscript(Expression node)
    {
        return node instanceof DereferenceExpression || node instanceof SubscriptExpression;
    }

    public static boolean isTopMostReference(Expression node, StackableAstVisitorContext<ExpressionAnalyzer.Context> context)
    {
        if (!context.getPreviousNode().isPresent()) {
            return true;
        }
        return !isDereferenceOrSubscript((Expression) context.getPreviousNode().get());
    }
}