PushdownSubfields.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.optimizations;

import com.facebook.presto.Session;
import com.facebook.presto.common.QualifiedObjectName;
import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.Subfield.NestedField;
import com.facebook.presto.common.Subfield.PathElement;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor;
import com.facebook.presto.spi.function.LambdaArgumentDescriptor;
import com.facebook.presto.spi.function.LambdaDescriptor;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.CteProducerNode;
import com.facebook.presto.spi.plan.DeleteNode;
import com.facebook.presto.spi.plan.DistinctLimitNode;
import com.facebook.presto.spi.plan.EquiJoinClause;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.JoinNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.OrderingScheme;
import com.facebook.presto.spi.plan.OutputNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.SemiJoinNode;
import com.facebook.presto.spi.plan.SortNode;
import com.facebook.presto.spi.plan.SpatialJoinNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.TableWriterNode;
import com.facebook.presto.spi.plan.TopNNode;
import com.facebook.presto.spi.plan.UnionNode;
import com.facebook.presto.spi.plan.UnnestNode;
import com.facebook.presto.spi.plan.WindowNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.ExpressionOptimizer;
import com.facebook.presto.spi.relation.ExpressionOptimizerProvider;
import com.facebook.presto.spi.relation.LambdaDefinitionExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.plan.ApplyNode;
import com.facebook.presto.sql.planner.plan.ExplainAnalyzeNode;
import com.facebook.presto.sql.planner.plan.GroupIdNode;
import com.facebook.presto.sql.planner.plan.IndexJoinNode;
import com.facebook.presto.sql.planner.plan.RowNumberNode;
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TopNRowNumberNode;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.tree.QualifiedName;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import io.airlift.slice.Slice;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.IntStream;

import static com.facebook.presto.SystemSessionProperties.isLegacyUnnest;
import static com.facebook.presto.SystemSessionProperties.isPushSubfieldsForMapFunctionsEnabled;
import static com.facebook.presto.SystemSessionProperties.isPushdownSubfieldsEnabled;
import static com.facebook.presto.SystemSessionProperties.isPushdownSubfieldsFromArrayLambdasEnabled;
import static com.facebook.presto.common.Subfield.allSubscripts;
import static com.facebook.presto.common.Subfield.noSubfield;
import static com.facebook.presto.common.type.TypeUtils.readNativeValue;
import static com.facebook.presto.common.type.Varchars.isVarcharType;
import static com.facebook.presto.metadata.BuiltInTypeAndFunctionNamespaceManager.JAVA_BUILTIN_NAMESPACE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;

public class PushdownSubfields
        implements PlanOptimizer
{
    // Engine metadata; handed to the Rewriter created in optimize().
    private final Metadata metadata;
    // Supplies the ExpressionOptimizer that the Rewriter obtains in its constructor.
    private final ExpressionOptimizerProvider expressionOptimizerProvider;
    // When true, isEnabled() reports the optimizer as enabled without consulting the session property.
    private boolean isEnabledForTesting;

    /**
     * Creates the optimizer.
     *
     * @param metadata engine metadata; must not be null
     * @param expressionOptimizerProvider provider of the expression optimizer; must not be null
     * @throws NullPointerException if either argument is null
     */
    public PushdownSubfields(Metadata metadata, ExpressionOptimizerProvider expressionOptimizerProvider)
    {
        // Validate both arguments up front (in declaration order), then store them.
        requireNonNull(metadata, "metadata is null");
        requireNonNull(expressionOptimizerProvider, "expressionOptimizerProvider is null");
        this.metadata = metadata;
        this.expressionOptimizerProvider = expressionOptimizerProvider;
    }

    /**
     * Sets the testing override consulted by {@link #isEnabled(Session)}.
     *
     * @param isSet {@code true} to force the optimizer on regardless of the session property
     */
    @Override
    public void setEnabledForTesting(boolean isSet)
    {
        this.isEnabledForTesting = isSet;
    }

    /**
     * Reports whether this optimizer should run for the given session.
     *
     * @param session session whose subfield-pushdown property is read when the testing override is off
     * @return {@code true} if the testing override is set or the session property enables pushdown
     */
    @Override
    public boolean isEnabled(Session session)
    {
        // The testing override wins; the session is only consulted when it is not set.
        if (isEnabledForTesting) {
            return true;
        }
        return isPushdownSubfieldsEnabled(session);
    }

    /**
     * Rewrites the plan with a {@link Rewriter} when the optimizer is enabled for the session.
     *
     * @return the original plan flagged as unchanged when disabled; otherwise the rewritten plan
     *         together with the rewriter's plan-changed flag
     * @throws NullPointerException if {@code plan}, {@code session} or {@code types} is null
     */
    @Override
    public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider types, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector)
    {
        requireNonNull(plan, "plan is null");
        requireNonNull(session, "session is null");
        requireNonNull(types, "types is null");

        if (isEnabled(session)) {
            Rewriter subfieldRewriter = new Rewriter(session, metadata, expressionOptimizerProvider);
            // Each optimization pass starts from a fresh, empty context.
            PlanNode optimizedPlan = SimplePlanRewriter.rewriteWith(subfieldRewriter, plan, new Rewriter.Context());
            return PlanOptimizerResult.optimizerResult(optimizedPlan, subfieldRewriter.isPlanChanged());
        }

        return PlanOptimizerResult.optimizerResult(plan, false);
    }

    private static class Rewriter
            extends SimplePlanRewriter<Rewriter.Context>
    {
        // Query session; used to read session properties and to derive ConnectorSession instances.
        private final Session session;
        // Engine metadata; used for function metadata lookups and column name resolution.
        private final Metadata metadata;
        // Function resolution built from the metadata's function-and-type resolver.
        private final FunctionResolution functionResolution;
        // Expression optimizer obtained from the provider passed to the constructor.
        private final ExpressionOptimizer expressionOptimizer;
        // Visitor applied to expressions together with the Context; presumably records the
        // variables/subfields an expression references (class is defined elsewhere in this file).
        private final SubfieldExtractor subfieldExtractor;
        // Qualified name of arbitrary() in the Java built-in namespace; visitAggregation lets
        // subfield pruning pass through this aggregation.
        private static final QualifiedObjectName ARBITRARY_AGGREGATE_FUNCTION = QualifiedObjectName.valueOf(JAVA_BUILTIN_NAMESPACE, "arbitrary");
        // Set to true once visitTableScan attaches required subfields to a column handle.
        private boolean planChanged;

        public Rewriter(Session session, Metadata metadata, ExpressionOptimizerProvider expressionOptimizerProvider)
        {
            this.session = requireNonNull(session, "session is null");
            this.metadata = requireNonNull(metadata, "metadata is null");
            requireNonNull(expressionOptimizerProvider, "expressionOptimizerProvider is null");
            this.functionResolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver());
            this.expressionOptimizer = expressionOptimizerProvider.getExpressionOptimizer(session.toConnectorSession());
            this.subfieldExtractor = new SubfieldExtractor(
                    functionResolution,
                    expressionOptimizer,
                    session.toConnectorSession(),
                    metadata.getFunctionAndTypeManager(),
                    session);
        }

        /**
         * @return {@code true} if this rewriter has modified the plan (the flag is raised in
         *         {@code visitTableScan} when required subfields are attached to a column)
         */
        public boolean isPlanChanged()
        {
            return this.planChanged;
        }

        /**
         * Records what an aggregation node reads: grouping keys, aggregation arguments,
         * filters, ORDER BY variables and masks.
         *
         * <p>An {@code arbitrary()} aggregation is registered as an assignment from its single
         * variable argument to its output, so that subfield pruning can pass through it.
         *
         * @throws IllegalStateException if the first argument of {@code arbitrary()} is not a variable reference
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Context> context)
        {
            Context aggregationContext = context.get();
            aggregationContext.variables.addAll(node.getGroupingKeys());

            for (Map.Entry<VariableReferenceExpression, AggregationNode.Aggregation> aggregationEntry : node.getAggregations().entrySet()) {
                VariableReferenceExpression output = aggregationEntry.getKey();
                AggregationNode.Aggregation aggregation = aggregationEntry.getValue();

                QualifiedObjectName functionName = metadata.getFunctionAndTypeManager()
                        .getFunctionMetadata(aggregation.getCall().getFunctionHandle())
                        .getName();

                if (!ARBITRARY_AGGREGATE_FUNCTION.equals(functionName)) {
                    for (RowExpression argument : aggregation.getArguments()) {
                        argument.accept(subfieldExtractor, aggregationContext);
                    }
                }
                else {
                    // arbitrary(x) just forwards x, so treat it like the assignment output := x
                    RowExpression forwarded = aggregation.getArguments().get(0);
                    checkState(forwarded instanceof VariableReferenceExpression);
                    aggregationContext.addAssignment(output, (VariableReferenceExpression) forwarded);
                }

                aggregation.getFilter().ifPresent(filter -> filter.accept(subfieldExtractor, aggregationContext));
                aggregation.getOrderBy().ifPresent(ordering -> aggregationContext.variables.addAll(ordering.getOrderByVariables()));
                aggregation.getMask().ifPresent(mask -> aggregationContext.variables.add(mask));
            }

            return context.defaultRewrite(node, aggregationContext);
        }

        /**
         * Adds the apply node's correlation variables to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitApply(ApplyNode node, RewriteContext<Context> context)
        {
            Context applyContext = context.get();
            applyContext.variables.addAll(node.getCorrelation());
            return context.defaultRewrite(node, applyContext);
        }

        /**
         * Adds the node's distinct variables to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitDistinctLimit(DistinctLimitNode node, RewriteContext<Context> context)
        {
            Context distinctContext = context.get();
            distinctContext.variables.addAll(node.getDistinctVariables());
            return context.defaultRewrite(node, distinctContext);
        }

        /**
         * Adds every output variable of the node's source to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitExplainAnalyze(ExplainAnalyzeNode node, RewriteContext<Context> context)
        {
            Context explainContext = context.get();
            explainContext.variables.addAll(node.getSource().getOutputVariables());
            return context.defaultRewrite(node, explainContext);
        }

        /**
         * Runs the subfield extractor over the filter predicate, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitFilter(FilterNode node, RewriteContext<Context> context)
        {
            Context filterContext = context.get();
            node.getPredicate().accept(subfieldExtractor, filterContext);
            return context.defaultRewrite(node, filterContext);
        }

        /**
         * Registers each grouping column mapping (key, value) of the node as an assignment in the
         * context, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitGroupId(GroupIdNode node, RewriteContext<Context> context)
        {
            Context groupIdContext = context.get();
            node.getGroupingColumns().forEach(
                    (groupingOutput, groupingInput) -> groupIdContext.addAssignment(groupingOutput, groupingInput));
            return context.defaultRewrite(node, groupIdContext);
        }

        /**
         * Adds the probe-side and then the index-side variables of every join criterion to the
         * context's {@code variables}, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitIndexJoin(IndexJoinNode node, RewriteContext<Context> context)
        {
            Context joinContext = context.get();
            List<IndexJoinNode.EquiJoinClause> criteria = node.getCriteria();

            // All probe variables first, then all index variables.
            for (IndexJoinNode.EquiJoinClause clause : criteria) {
                joinContext.variables.add(clause.getProbe());
            }
            for (IndexJoinNode.EquiJoinClause clause : criteria) {
                joinContext.variables.add(clause.getIndex());
            }

            return context.defaultRewrite(node, joinContext);
        }

        /**
         * Adds the left and then the right variables of every equi-join clause to the context's
         * {@code variables}, runs the subfield extractor over the optional join filter, and
         * delegates to the default rewrite.
         */
        @Override
        public PlanNode visitJoin(JoinNode node, RewriteContext<Context> context)
        {
            Context joinContext = context.get();
            List<EquiJoinClause> criteria = node.getCriteria();

            // All left-side keys first, then all right-side keys.
            for (EquiJoinClause clause : criteria) {
                joinContext.variables.add(clause.getLeft());
            }
            for (EquiJoinClause clause : criteria) {
                joinContext.variables.add(clause.getRight());
            }

            node.getFilter().ifPresent(joinFilter -> joinFilter.accept(subfieldExtractor, joinContext));

            return context.defaultRewrite(node, joinContext);
        }

        /**
         * Adds the node's distinct variables to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext<Context> context)
        {
            Context markDistinctContext = context.get();
            markDistinctContext.variables.addAll(node.getDistinctVariables());
            return context.defaultRewrite(node, markDistinctContext);
        }

        /**
         * Adds the output node's output variables to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitOutput(OutputNode node, RewriteContext<Context> context)
        {
            Context outputContext = context.get();
            outputContext.variables.addAll(node.getOutputVariables());
            return context.defaultRewrite(node, outputContext);
        }

        /**
         * Adds every output variable of the CTE producer's source to the context's
         * {@code variables}, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitCteProducer(CteProducerNode node, RewriteContext<Context> context)
        {
            Context cteContext = context.get();
            cteContext.variables.addAll(node.getSource().getOutputVariables());
            return context.defaultRewrite(node, cteContext);
        }

        /**
         * Classifies every projection assignment:
         * <ul>
         * <li>a plain variable reference is registered as a variable-to-variable assignment;</li>
         * <li>an expression that {@code toSubfield} can convert is registered as one assignment per
         * resulting subfield;</li>
         * <li>anything else is handed to the subfield extractor.</li>
         * </ul>
         * Afterwards the default rewrite is applied.
         */
        @Override
        public PlanNode visitProject(ProjectNode node, RewriteContext<Context> context)
        {
            Context projectContext = context.get();

            for (Map.Entry<VariableReferenceExpression, RowExpression> assignment : node.getAssignments().entrySet()) {
                VariableReferenceExpression output = assignment.getKey();
                RowExpression projection = assignment.getValue();

                if (projection instanceof VariableReferenceExpression) {
                    projectContext.addAssignment(output, (VariableReferenceExpression) projection);
                }
                else {
                    Optional<List<Subfield>> converted = toSubfield(
                            projection,
                            functionResolution,
                            expressionOptimizer,
                            session.toConnectorSession(),
                            metadata.getFunctionAndTypeManager(),
                            isPushSubfieldsForMapFunctionsEnabled(session));

                    if (converted.isPresent()) {
                        for (Subfield subfield : converted.get()) {
                            projectContext.addAssignment(output, subfield);
                        }
                    }
                    else {
                        projection.accept(subfieldExtractor, projectContext);
                    }
                }
            }

            return context.defaultRewrite(node, projectContext);
        }

        /**
         * Adds the row-number variable and the partitioning variables to the context's
         * {@code variables}, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitRowNumber(RowNumberNode node, RewriteContext<Context> context)
        {
            Context rowNumberContext = context.get();
            rowNumberContext.variables.add(node.getRowNumberVariable());
            rowNumberContext.variables.addAll(node.getPartitionBy());
            return context.defaultRewrite(node, rowNumberContext);
        }

        /**
         * Adds the source join variable and the filtering-source join variable to the context's
         * {@code variables}, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext<Context> context)
        {
            Context semiJoinContext = context.get();
            semiJoinContext.variables.add(node.getSourceJoinVariable());
            semiJoinContext.variables.add(node.getFilteringSourceJoinVariable());
            return context.defaultRewrite(node, semiJoinContext);
        }

        /**
         * Adds the ORDER BY variables of the sort to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitSort(SortNode node, RewriteContext<Context> context)
        {
            Context sortContext = context.get();
            sortContext.variables.addAll(node.getOrderingScheme().getOrderByVariables());
            return context.defaultRewrite(node, sortContext);
        }

        /**
         * Runs the subfield extractor over the spatial join filter, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext<Context> context)
        {
            Context spatialJoinContext = context.get();
            node.getFilter().accept(subfieldExtractor, spatialJoinContext);
            return context.defaultRewrite(node, spatialJoinContext);
        }

        /**
         * Attaches the required subfields collected in the context to the scan's column handles.
         *
         * <p>If the context holds no subfields the node is returned unchanged. Otherwise, for every
         * scan assignment:
         * <ul>
         * <li>a variable present in the context's {@code variables} keeps its column handle as is;</li>
         * <li>any other variable must have at least one recorded subfield; those subfields are
         * pruned, re-rooted on the column name, and set on the handle via
         * {@code withRequiredSubfields}.</li>
         * </ul>
         *
         * @return the original node, or a new {@link TableScanNode} carrying the updated assignments
         * @throws com.google.common.base.VerifyException if a scan variable is neither in
         *         {@code variables} nor covered by any recorded subfield
         */
        @Override
        public PlanNode visitTableScan(TableScanNode node, RewriteContext<Context> context)
        {
            Context scanContext = context.get();
            if (scanContext.subfields.isEmpty()) {
                return node;
            }

            ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> newAssignments = ImmutableMap.builder();

            for (Map.Entry<VariableReferenceExpression, ColumnHandle> entry : node.getAssignments().entrySet()) {
                VariableReferenceExpression variable = entry.getKey();
                if (scanContext.variables.contains(variable)) {
                    newAssignments.put(entry);
                    continue;
                }

                List<Subfield> subfields = scanContext.findSubfields(variable.getName());

                // Template form: the message is only formatted if the check fails.
                verify(!subfields.isEmpty(), "Missing variable: %s", variable);

                String columnName = getColumnName(session, metadata, node.getTable(), entry.getValue());

                // Split the subfields in one pass by whether their path contains a NoSubfield element.
                ImmutableList.Builder<Subfield> withoutNoSubfieldBuilder = ImmutableList.builder();
                ImmutableList.Builder<Subfield> withNoSubfieldBuilder = ImmutableList.builder();
                for (Subfield subfield : subfields) {
                    if (containsNoSubfieldPathElement(subfield)) {
                        withNoSubfieldBuilder.add(subfield);
                    }
                    else {
                        withoutNoSubfieldBuilder.add(subfield);
                    }
                }
                List<Subfield> subfieldsWithoutNoSubfield = withoutNoSubfieldBuilder.build();
                List<Subfield> subfieldsWithNoSubfield = withNoSubfieldBuilder.build();

                // Accumulate into a builder rather than mutating a Collectors.toList() result,
                // whose mutability is not guaranteed.
                ImmutableList.Builder<Subfield> columnSubfields = ImmutableList.builder();

                // Prune subfields: if one subfield is a prefix of another subfield, keep the shortest one.
                // Example: {a.b.c, a.b} -> {a.b}
                for (Subfield subfield : subfieldsWithoutNoSubfield) {
                    if (!prefixExists(subfield, subfieldsWithoutNoSubfield)) {
                        columnSubfields.add(new Subfield(columnName, subfield.getPath()));
                    }
                }

                // A NoSubfield-marked subfield is dropped when, with the marker removed, it is a
                // prefix of one of the regular subfields.
                for (Subfield subfield : subfieldsWithNoSubfield) {
                    if (!isPrefixOf(dropNoSubfield(subfield), subfieldsWithoutNoSubfield)) {
                        columnSubfields.add(new Subfield(columnName, subfield.getPath()));
                    }
                }

                planChanged = true;
                newAssignments.put(variable, entry.getValue().withRequiredSubfields(columnSubfields.build()));
            }

            return new TableScanNode(
                    node.getSourceLocation(),
                    node.getId(),
                    node.getTable(),
                    node.getOutputVariables(),
                    newAssignments.build(),
                    node.getTableConstraints(),
                    node.getCurrentConstraint(),
                    node.getEnforcedConstraint(),
                    node.getCteMaterializationInfo());
        }

        /**
         * Adds the writer's column variables to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitTableWriter(TableWriterNode node, RewriteContext<Context> context)
        {
            Context writerContext = context.get();
            writerContext.variables.addAll(node.getColumns());
            return context.defaultRewrite(node, writerContext);
        }

        /**
         * Adds the input variables of the optional input distribution and the optional row-id
         * variable to the context's {@code variables}, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitDelete(DeleteNode node, RewriteContext<Context> context)
        {
            Context deleteContext = context.get();
            node.getInputDistribution()
                    .ifPresent(distribution -> deleteContext.variables.addAll(distribution.getInputVariables()));
            node.getRowId().ifPresent(deleteContext.variables::add);
            return context.defaultRewrite(node, deleteContext);
        }

        /**
         * Adds the ORDER BY variables of the TopN to the context's {@code variables},
         * then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitTopN(TopNNode node, RewriteContext<Context> context)
        {
            Context topNContext = context.get();
            topNContext.variables.addAll(node.getOrderingScheme().getOrderByVariables());
            return context.defaultRewrite(node, topNContext);
        }

        /**
         * Adds the row-number variable, the partitioning variables and the ORDER BY variables to
         * the context's {@code variables}, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext<Context> context)
        {
            Context topNRowNumberContext = context.get();
            topNRowNumberContext.variables.add(node.getRowNumberVariable());
            topNRowNumberContext.variables.addAll(node.getPartitionBy());
            topNRowNumberContext.variables.addAll(node.getOrderingScheme().getOrderByVariables());
            return context.defaultRewrite(node, topNRowNumberContext);
        }

        /**
         * Registers, for every union output variable, an assignment from each of its mapped input
         * variables, then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitUnion(UnionNode node, RewriteContext<Context> context)
        {
            Context unionContext = context.get();

            for (Map.Entry<VariableReferenceExpression, List<VariableReferenceExpression>> mapping : node.getVariableMapping().entrySet()) {
                VariableReferenceExpression unionOutput = mapping.getKey();
                for (VariableReferenceExpression input : mapping.getValue()) {
                    unionContext.addAssignment(unionOutput, input);
                }
            }

            return context.defaultRewrite(node, unionContext);
        }

        /**
         * Translates references to unnested outputs into references on the unnested container.
         *
         * <p>For each container and its unnested fields:
         * <ul>
         * <li>When the container is an array of rows and legacy unnest is off, a field found in the
         * context's {@code variables} yields the subfield {@code container[*].field}; recorded
         * subfields of a field are re-rooted as {@code container[*].field.<path>}.</li>
         * <li>Otherwise, a field found in {@code variables} adds the container itself to
         * {@code variables}; recorded subfields of a field are re-rooted as
         * {@code container[*].<path>}.</li>
         * <li>If no field of the container is referenced either way, the container is added to
         * {@code variables}.</li>
         * </ul>
         * Newly derived subfields are added to the context only after all containers have been
         * processed, so they are not visible to the lookups performed during the loop.
         */
        @Override
        public PlanNode visitUnnest(UnnestNode node, RewriteContext<Context> context)
        {
            Context unnestContext = context.get();
            ImmutableList.Builder<Subfield> derivedSubfields = ImmutableList.builder();

            for (Map.Entry<VariableReferenceExpression, List<VariableReferenceExpression>> unnestEntry : node.getUnnestVariables().entrySet()) {
                VariableReferenceExpression container = unnestEntry.getKey();
                // Decided once per container; the session is only consulted for array-of-row containers.
                boolean unnestsRowFields = isRowType(container) && !isLegacyUnnest(session);
                boolean referenced = false;

                for (VariableReferenceExpression field : unnestEntry.getValue()) {
                    if (unnestContext.variables.contains(field)) {
                        referenced = true;
                        if (unnestsRowFields) {
                            derivedSubfields.add(new Subfield(container.getName(), ImmutableList.of(allSubscripts(), nestedField(field.getName()))));
                        }
                        else {
                            unnestContext.variables.add(container);
                        }
                        continue;
                    }

                    List<Subfield> matchingSubfields = unnestContext.findSubfields(field.getName());
                    if (matchingSubfields.isEmpty()) {
                        continue;
                    }

                    referenced = true;
                    for (Subfield matching : matchingSubfields) {
                        ImmutableList.Builder<PathElement> rerootedPath = ImmutableList.builder();
                        rerootedPath.add(allSubscripts());
                        if (unnestsRowFields) {
                            rerootedPath.add(nestedField(field.getName()));
                        }
                        rerootedPath.addAll(matching.getPath());
                        derivedSubfields.add(new Subfield(container.getName(), rerootedPath.build()));
                    }
                }

                if (!referenced) {
                    unnestContext.variables.add(container);
                }
            }
            unnestContext.subfields.addAll(derivedSubfields.build());

            return context.defaultRewrite(node, unnestContext);
        }

        /**
         * Records what a window node reads: partitioning variables, optional ORDER BY variables,
         * the arguments of every window function call (via the subfield extractor), and the
         * optional frame start and end value variables. Then delegates to the default rewrite.
         */
        @Override
        public PlanNode visitWindow(WindowNode node, RewriteContext<Context> context)
        {
            Context windowContext = context.get();
            windowContext.variables.addAll(node.getSpecification().getPartitionBy());

            node.getSpecification().getOrderingScheme()
                    .ifPresent(ordering -> windowContext.variables.addAll(ordering.getOrderByVariables()));

            Collection<WindowNode.Function> windowFunctions = node.getWindowFunctions().values();

            // Pass 1: function call arguments
            for (WindowNode.Function windowFunction : windowFunctions) {
                for (RowExpression argument : windowFunction.getFunctionCall().getArguments()) {
                    argument.accept(subfieldExtractor, windowContext);
                }
            }

            // Pass 2: frame start values
            for (WindowNode.Function windowFunction : windowFunctions) {
                windowFunction.getFrame().getStartValue().ifPresent(windowContext.variables::add);
            }

            // Pass 3: frame end values
            for (WindowNode.Function windowFunction : windowFunctions) {
                windowFunction.getFrame().getEndValue().ifPresent(windowContext.variables::add);
            }

            return context.defaultRewrite(node, windowContext);
        }

        /**
         * @return {@code true} if the variable's type is an array whose element type is a row type
         */
        private boolean isRowType(VariableReferenceExpression variable)
        {
            Type variableType = variable.getType();
            if (!(variableType instanceof ArrayType)) {
                return false;
            }
            return ((ArrayType) variableType).getElementType() instanceof RowType;
        }

        /**
         * Returns a copy of the subfield, with the same root name, whose path omits every
         * {@code NoSubfield} element.
         */
        private static Subfield dropNoSubfield(Subfield subfield)
        {
            String rootName = subfield.getRootName();
            ImmutableList.Builder<PathElement> retained = ImmutableList.builder();
            for (PathElement element : subfield.getPath()) {
                if (element instanceof Subfield.NoSubfield) {
                    continue;
                }
                retained.add(element);
            }
            return new Subfield(rootName, retained.build());
        }

        /**
         * @return {@code true} if any element of the subfield's path is a {@code NoSubfield}
         */
        private static boolean containsNoSubfieldPathElement(Subfield subfield)
        {
            for (PathElement element : subfield.getPath()) {
                if (element instanceof Subfield.NoSubfield) {
                    return true;
                }
            }
            return false;
        }

        /**
         * @return {@code true} if some member of {@code subfieldPaths} reports itself, via
         *         {@code isPrefix}, as a prefix of {@code subfieldPath}
         */
        private static boolean prefixExists(Subfield subfieldPath, Collection<Subfield> subfieldPaths)
        {
            for (Subfield candidate : subfieldPaths) {
                if (candidate.isPrefix(subfieldPath)) {
                    return true;
                }
            }
            return false;
        }

        /**
         * @return {@code true} if {@code subfieldPath} reports itself, via {@code isPrefix}, as a
         *         prefix of some member of {@code subfieldPaths}
         */
        private static boolean isPrefixOf(Subfield subfieldPath, Collection<Subfield> subfieldPaths)
        {
            return subfieldPaths.stream()
                    .filter(subfieldPath::isPrefix)
                    .findAny()
                    .isPresent();
        }

        /**
         * Looks up the column metadata for the given handle of the given table and returns its name.
         */
        private static String getColumnName(Session session, Metadata metadata, TableHandle tableHandle, ColumnHandle columnHandle)
        {
            return metadata
                    .getColumnMetadata(session, tableHandle, columnHandle)
                    .getName();
        }

        private static Optional<List<Subfield>> toSubfield(
                RowExpression expression,
                FunctionResolution functionResolution,
                ExpressionOptimizer expressionOptimizer,
                ConnectorSession connectorSession,
                FunctionAndTypeManager functionAndTypeManager,
                boolean isPushdownSubfieldsForMapFunctionsEnabled)
        {
            ImmutableList.Builder<Subfield.PathElement> elements = ImmutableList.builder();
            while (true) {
                if (expression instanceof VariableReferenceExpression) {
                    return Optional.of(ImmutableList.of(new Subfield(((VariableReferenceExpression) expression).getName(), elements.build().reverse())));
                }
                if (expression instanceof CallExpression) {
                    ComplexTypeFunctionDescriptor functionDescriptor = functionAndTypeManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getDescriptor();
                    Optional<Integer> pushdownSubfieldArgIndex = functionDescriptor.getPushdownSubfieldArgIndex();
                    if (pushdownSubfieldArgIndex.isPresent() &&
                            ((CallExpression) expression).getArguments().size() > pushdownSubfieldArgIndex.get() &&
                            ((CallExpression) expression).getArguments().get(pushdownSubfieldArgIndex.get()).getType() instanceof RowType
                            && !elements.build().isEmpty()) { // ensures pushdown only happens when a subfield is read from a column
                        expression = ((CallExpression) expression).getArguments().get(pushdownSubfieldArgIndex.get());
                        continue;
                    }
                }

                if (expression instanceof SpecialFormExpression && ((SpecialFormExpression) expression).getForm() == DEREFERENCE) {
                    SpecialFormExpression dereference = (SpecialFormExpression) expression;
                    RowExpression base = dereference.getArguments().get(0);
                    RowType baseType = (RowType) base.getType();

                    RowExpression indexExpression = expressionOptimizer.optimize(
                            dereference.getArguments().get(1),
                            ExpressionOptimizer.Level.OPTIMIZED,
                            connectorSession);

                    if (indexExpression instanceof ConstantExpression) {
                        Object index = ((ConstantExpression) indexExpression).getValue();
                        verify(index != null, "Struct field index cannot be null");
                        if (index instanceof Number) {
                            Optional<String> fieldName = baseType.getFields().get(((Number) index).intValue()).getName();
                            if (fieldName.isPresent()) {
                                elements.add(nestedField(fieldName.get()));
                                expression = base;
                                continue;
                            }
                        }
                    }
                    return Optional.empty();
                }
                if (expression instanceof CallExpression &&
                        isSubscriptOrElementAtFunction((CallExpression) expression, functionResolution, functionAndTypeManager)) {
                    List<RowExpression> arguments = ((CallExpression) expression).getArguments();
                    RowExpression indexExpression = expressionOptimizer.optimize(
                            arguments.get(1),
                            ExpressionOptimizer.Level.OPTIMIZED,
                            connectorSession);

                    if (indexExpression instanceof ConstantExpression) {
                        Object index = ((ConstantExpression) indexExpression).getValue();
                        if (index == null) {
                            return Optional.empty();
                        }
                        if (index instanceof Number) {
                            //Fix for issue https://github.com/prestodb/presto/issues/22690
                            //Avoid negative index pushdown
                            if (((Number) index).longValue() < 0 && arguments.get(0).getType() instanceof ArrayType) {
                                return Optional.empty();
                            }

                            elements.add(new Subfield.LongSubscript(((Number) index).longValue()));
                            expression = arguments.get(0);
                            continue;
                        }

                        if (isVarcharType(indexExpression.getType())) {
                            elements.add(new Subfield.StringSubscript(((Slice) index).toStringUtf8()));
                            expression = arguments.get(0);
                            continue;
                        }
                    }
                    return Optional.empty();
                }
                // map_subset(feature, constant_array) is only accessing fields specified in feature map.
                // For example map_subset(feature, array[1, 2]) is equivalent to calling element_at(feature, 1) and element_at(feature, 2) for subfield extraction
                if (isPushdownSubfieldsForMapFunctionsEnabled && expression instanceof CallExpression && isMapSubSetWithConstantArray((CallExpression) expression, functionResolution)) {
                    CallExpression call = (CallExpression) expression;
                    ConstantExpression constantArray = (ConstantExpression) call.getArguments().get(1);
                    return extractSubfieldsFromArray(constantArray, (VariableReferenceExpression) call.getArguments().get(0));
                }

                // map_filter(feature, (k, v) -> k in (1, 2, 3)), map_filter(feature, (k, v) -> contains(array[1, 2, 3], k)), map_filter(feature, (k, v) -> k = 2) only access specified elements
                if (isPushdownSubfieldsForMapFunctionsEnabled && expression instanceof CallExpression && isMapFilterWithConstantFilterInMapKey((CallExpression) expression, functionResolution)) {
                    CallExpression call = (CallExpression) expression;
                    VariableReferenceExpression mapVariable = (VariableReferenceExpression) call.getArguments().get(0);
                    ImmutableList.Builder<Subfield> arguments = ImmutableList.builder();
                    if (((LambdaDefinitionExpression) call.getArguments().get(1)).getBody() instanceof SpecialFormExpression) {
                        List<RowExpression> mapKeys = ((SpecialFormExpression) ((LambdaDefinitionExpression) call.getArguments().get(1)).getBody()).getArguments().stream().skip(1).collect(toImmutableList());
                        for (RowExpression mapKey : mapKeys) {
                            Optional<Subfield> mapKeySubfield = extractSubfieldsFromSingleValue((ConstantExpression) mapKey, mapVariable);
                            if (!mapKeySubfield.isPresent()) {
                                return Optional.empty();
                            }
                            arguments.add(mapKeySubfield.get());
                        }
                        return Optional.of(arguments.build());
                    }
                    else if (((LambdaDefinitionExpression) call.getArguments().get(1)).getBody() instanceof CallExpression) {
                        CallExpression callExpression = (CallExpression) ((LambdaDefinitionExpression) call.getArguments().get(1)).getBody();
                        if (functionResolution.isArrayContainsFunction(callExpression.getFunctionHandle())) {
                            return extractSubfieldsFromArray((ConstantExpression) callExpression.getArguments().get(0), mapVariable);
                        }
                        else if (functionResolution.isEqualFunction(callExpression.getFunctionHandle())) {
                            ConstantExpression mapKey;
                            if (callExpression.getArguments().get(0) instanceof ConstantExpression) {
                                mapKey = (ConstantExpression) callExpression.getArguments().get(0);
                            }
                            else {
                                mapKey = (ConstantExpression) callExpression.getArguments().get(1);
                            }
                            Optional<Subfield> mapKeySubfield = extractSubfieldsFromSingleValue(mapKey, mapVariable);
                            return mapKeySubfield.map(ImmutableList::of);
                        }
                    }
                }
                return Optional.empty();
            }
        }

        /**
         * Converts the elements of a constant array of map keys into subfields of {@code mapVariable},
         * producing one single-subscript subfield per array element.
         *
         * @param constantArray constant expression expected to carry an array value
         * @param mapVariable variable whose name becomes the root of every produced subfield
         * @return one subfield per array element, or {@link Optional#empty()} when the array constant is null,
         * holds a null element, or holds a key that is neither a {@link Number} nor a varchar
         */
        private static Optional<List<Subfield>> extractSubfieldsFromArray(ConstantExpression constantArray, VariableReferenceExpression mapVariable)
        {
            if (constantArray.getValue() == null) {
                // A null array constant names no keys; report "not convertible" rather than failing the state check below.
                return Optional.empty();
            }
            checkState(constantArray.getValue() instanceof Block && constantArray.getType() instanceof ArrayType);
            Block arrayValue = (Block) constantArray.getValue();
            Type arrayElementType = ((ArrayType) constantArray.getType()).getElementType();
            ImmutableList.Builder<Subfield> arguments = ImmutableList.builder();
            for (int i = 0; i < arrayValue.getPositionCount(); ++i) {
                Object mapKey = readNativeValue(arrayElementType, arrayValue, i);
                if (mapKey == null) {
                    return Optional.empty();
                }
                if (mapKey instanceof Number) {
                    arguments.add(new Subfield(mapVariable.getName(), ImmutableList.of(new Subfield.LongSubscript(((Number) mapKey).longValue()))));
                }
                else if (isVarcharType(arrayElementType)) {
                    arguments.add(new Subfield(mapVariable.getName(), ImmutableList.of(new Subfield.StringSubscript(((Slice) mapKey).toStringUtf8()))));
                }
                else {
                    // Unsupported key type: silently skipping it would return an incomplete key list as if conversion
                    // had succeeded. Return empty instead, as extractSubfieldsFromSingleValue does for unsupported keys.
                    return Optional.empty();
                }
            }
            return Optional.of(arguments.build());
        }

        /**
         * Converts a single constant map key into a one-subscript subfield rooted at {@code mapVariable}.
         *
         * @param mapKey constant holding the key value
         * @param mapVariable variable whose name becomes the subfield root
         * @return the subfield, or {@link Optional#empty()} when the key is null or is neither a {@link Number} nor a varchar
         */
        private static Optional<Subfield> extractSubfieldsFromSingleValue(ConstantExpression mapKey, VariableReferenceExpression mapVariable)
        {
            Object keyValue = mapKey.getValue();
            if (keyValue == null) {
                return Optional.empty();
            }

            PathElement subscript;
            if (keyValue instanceof Number) {
                subscript = new Subfield.LongSubscript(((Number) keyValue).longValue());
            }
            else if (isVarcharType(mapKey.getType())) {
                subscript = new Subfield.StringSubscript(((Slice) keyValue).toStringUtf8());
            }
            else {
                return Optional.empty();
            }
            return Optional.of(new Subfield(mapVariable.getName(), ImmutableList.of(subscript)));
        }

        /**
         * Builds a nested-field path element from {@code name}, lower-cased with the English locale.
         */
        private static NestedField nestedField(String name)
        {
            String lowerCaseName = name.toLowerCase(Locale.ENGLISH);
            return new NestedField(lowerCaseName);
        }

        private static final class SubfieldExtractor
                extends DefaultRowExpressionTraversalVisitor<Context>
        {
            // Collaborators forwarded to toSubfield(...) and used to inspect function calls.
            private final FunctionResolution functionResolution;
            private final ExpressionOptimizer expressionOptimizer;
            // Stored as given; unlike the other collaborators it is not null-checked in the constructor.
            private final ConnectorSession connectorSession;
            private final FunctionAndTypeManager functionAndTypeManager;
            // Session-derived feature switches, read once at construction time.
            private final boolean isPushDownSubfieldsFromLambdasEnabled;
            private final boolean isPushdownSubfieldsForMapFunctionsEnabled;

            /**
             * Creates an extractor and captures the two session feature flags it depends on.
             *
             * @param functionResolution must not be null
             * @param expressionOptimizer must not be null
             * @param connectorSession stored without a null check
             * @param functionAndTypeManager must not be null
             * @param session must not be null; only used to read the feature flags
             */
            private SubfieldExtractor(
                    FunctionResolution functionResolution,
                    ExpressionOptimizer expressionOptimizer,
                    ConnectorSession connectorSession,
                    FunctionAndTypeManager functionAndTypeManager,
                    Session session)
            {
                this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
                this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null");
                this.connectorSession = connectorSession;
                this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
                // Carry a message like the other checks so a failure names the offending argument.
                requireNonNull(session, "session is null");
                this.isPushDownSubfieldsFromLambdasEnabled = isPushdownSubfieldsFromArrayLambdasEnabled(session);
                this.isPushdownSubfieldsForMapFunctionsEnabled = isPushSubfieldsForMapFunctionsEnabled(session);
            }

            /**
             * Visits a function call and records the subfields it requires in {@code context}.
             * <p>
             * Subscript / element_at calls, map_subset with a constant array, and map_filter with a constant key filter are
             * first offered to {@code toSubfield}; when that succeeds the resulting subfields are recorded directly, otherwise
             * the call's arguments are visited. For every other call the lambda subfields held in the context are
             * recomputed per argument before the arguments are visited, and the lambda subfields that were in the context
             * on entry are restored at the end.
             *
             * @param call the call expression being visited
             * @param context traversal state; mutated by this method
             * @return always {@code null}
             */
            @Override
            public Void visitCall(CallExpression call, Context context)
            {
                ComplexTypeFunctionDescriptor functionDescriptor = functionAndTypeManager.getFunctionMetadata(call.getFunctionHandle()).getDescriptor();
                if (isSubscriptOrElementAtFunction(call, functionResolution, functionAndTypeManager) || isMapSubSetWithConstantArray(call, functionResolution) || isMapFilterWithConstantFilterInMapKey(call, functionResolution)) {
                    Optional<List<Subfield>> subfield = toSubfield(call, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager, isPushdownSubfieldsForMapFunctionsEnabled);
                    if (subfield.isPresent()) {
                        if (context.isPruningLambdaSubfieldsPossible()) {
                            // Append the lambda subfields collected so far to each extracted subfield.
                            subfield.get().forEach(item -> addRequiredLambdaSubfields(context, item));
                        }
                        else {
                            context.subfields.addAll(subfield.get());
                        }
                    }
                    else {
                        // Could not express the call as subfields: fall back to visiting its arguments.
                        call.getArguments().forEach(argument -> argument.accept(this, context));
                    }
                    return null;
                }
                if (!isPushDownSubfieldsFromLambdasEnabled) {
                    // Lambda pushdown is off: mark everything as required and visit the arguments.
                    // Note that the previous lambda subfields are not restored on this path.
                    context.setLambdaSubfields(Context.ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE);
                    call.getArguments().forEach(argument -> argument.accept(this, context));
                    return null;
                }
                // Remember what the parent passed in so it can be restored before returning.
                Set<Subfield> lambdaSubfieldsOriginal = context.getLambdaSubfields();
                if ((functionDescriptor.isAccessingInputValues() && functionDescriptor.getLambdaDescriptors().isEmpty())) {
                    // If function internally accesses the input values we cannot prune any unaccessed lambda subfields since we do not know what subfields function accessed.
                    context.giveUpOnCollectingLambdaSubfields();
                }

                // We need to apply output to input transformation function in order to make sense of all lambda subfields accessed in outer functions w.r.t. the
                // input of the current function.
                if (functionDescriptor.getOutputToInputTransformationFunction().isPresent()) {
                    Set<Subfield> transformedLambdaSubfields =
                            functionDescriptor.getOutputToInputTransformationFunction().get().apply(context.getLambdaSubfields());
                    context.setLambdaSubfields(ImmutableSet.copyOf(transformedLambdaSubfields));
                }

                // Use the argument indices declared by the descriptor; when absent, fall back to every argument whose type is
                // an array of rows or a map with row values.
                Set<Integer> argumentIndicesContainingMapOrArray = functionDescriptor.getArgumentIndicesContainingMapOrArray()
                        .orElseGet(() -> IntStream.range(0, call.getArguments().size())
                                .filter(argIndex -> isMapOrArrayOfRowType(call.getArguments().get(argIndex)))
                                .boxed()
                                .collect(toImmutableSet()));

                // All the lambda subfields collected in outer functions relate only to the arguments of the function specified in
                // functionDescriptor.argumentIndicesContainingMapOrArray.
                Map<Integer, Set<Subfield>> lambdaSubfieldsFromOuterFunctions = argumentIndicesContainingMapOrArray.stream()
                        .collect(toImmutableMap(callArgumentIndex -> callArgumentIndex, unused -> ImmutableSet.copyOf(context.getLambdaSubfields())));

                // If the function accepts lambdas, add all the lambda subfields from each lambda.
                Map<Integer, Set<Subfield>> lambdaSubfieldsFromCurrentFunction = ImmutableMap.of();
                for (LambdaDescriptor lambdaDescriptor : functionDescriptor.getLambdaDescriptors()) {
                    Optional<Map<Integer, Set<Subfield>>> lambdaSubfields = collectLambdaSubfields(call, lambdaDescriptor);
                    if (!lambdaSubfields.isPresent()) {
                        // Subfields of this lambda could not be collected: stop pruning and visit the arguments as-is.
                        context.giveUpOnCollectingLambdaSubfields();
                        call.getArguments().forEach(argument -> argument.accept(this, context));
                        return null;
                    }
                    lambdaSubfieldsFromCurrentFunction = merge(lambdaSubfieldsFromCurrentFunction, lambdaSubfields.get());
                }

                Map<Integer, Set<Subfield>> lambdaSubfields = merge(lambdaSubfieldsFromOuterFunctions, lambdaSubfieldsFromCurrentFunction);

                // Replace empty per-argument sets (see addNoSubfieldIfNoAccessedSubfieldsFound for the exact rule).
                lambdaSubfields = addNoSubfieldIfNoAccessedSubfieldsFound(call, lambdaSubfields);

                // We need to continue visiting the function arguments and collect all lambda subfields in inner function calls as well as non-lambda subfields in all
                // function arguments. Once reached the leaf node, we will try to prune the subfields of the input field, subscript, or subfield.
                for (int callArgumentIndex = 0; callArgumentIndex < call.getArguments().size(); callArgumentIndex++) {
                    // Since context is global during the traversal of all the nodes in expression tree, we need to pass lambda subfields only to those function
                    // arguments that they relate to.
                    if (lambdaSubfields.containsKey(callArgumentIndex)) {
                        context.setLambdaSubfields(lambdaSubfields.get(callArgumentIndex));
                    }
                    else {
                        context.setLambdaSubfields(Context.ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE);
                    }
                    call.getArguments().get(callArgumentIndex).accept(this, context);
                }

                // When we are done with inner calls (child nodes) we need to restore lambda subfields we received from parent expression to handle such situations like
                // in example below
                // SELECT * FROM my_table WHERE ANY_MATCH(column1, x -> x.ds > '2023-01-01') AND ALL_MATCH(column2, x -> STRPOS(x.comment, 'Presto') > 0)
                // After we are done with ANY_MATCH, we need to restore the lambda subfields to what we received from parent node 'AND' so that it does not collide with
                // lambda subfields of ALL_MATCH function.
                context.setLambdaSubfields(lambdaSubfieldsOriginal);
                return null;
            }

            /**
             * Returns an immutable map holding every key of {@code s1} and {@code s2}; for a key present in both,
             * the value is the union of the two sets.
             */
            private static Map<Integer, Set<Subfield>> merge(Map<Integer, Set<Subfield>> s1, Map<Integer, Set<Subfield>> s2)
            {
                Map<Integer, Set<Subfield>> combined = new HashMap<>(s1);
                for (Map.Entry<Integer, Set<Subfield>> entry : s2.entrySet()) {
                    Integer callArgumentIndex = entry.getKey();
                    Set<Subfield> existing = combined.get(callArgumentIndex);
                    if (existing == null) {
                        combined.put(callArgumentIndex, entry.getValue());
                    }
                    else {
                        combined.put(
                                callArgumentIndex,
                                ImmutableSet.<Subfield>builder().addAll(existing).addAll(entry.getValue()).build());
                    }
                }
                return ImmutableMap.copyOf(combined);
            }

            /**
             * Rebuilds the per-argument lambda subfield map: non-empty sets are kept unchanged; an empty set is replaced by a
             * single {@code [*]} + no-subfield marker when the argument is an array of rows or a map with row values,
             * and dropped otherwise.
             */
            private static Map<Integer, Set<Subfield>> addNoSubfieldIfNoAccessedSubfieldsFound(CallExpression call, Map<Integer, Set<Subfield>> argumentIndexToLambdaSubfieldsMap)
            {
                ImmutableMap.Builder<Integer, Set<Subfield>> rebuilt = ImmutableMap.builder();
                for (Map.Entry<Integer, Set<Subfield>> entry : argumentIndexToLambdaSubfieldsMap.entrySet()) {
                    Integer callArgumentIndex = entry.getKey();
                    Set<Subfield> accessedSubfields = entry.getValue();
                    if (!accessedSubfields.isEmpty()) {
                        rebuilt.put(callArgumentIndex, accessedSubfields);
                        continue;
                    }
                    if (isMapOrArrayOfRowType(call.getArguments().get(callArgumentIndex))) {
                        rebuilt.put(callArgumentIndex, ImmutableSet.of(new Subfield("", ImmutableList.of(allSubscripts(), noSubfield()))));
                    }
                }
                return rebuilt.build();
            }

            /**
             * Tells whether the argument's type is an array whose element type is a row, or a map whose value type is a row.
             */
            private static boolean isMapOrArrayOfRowType(RowExpression argument)
            {
                Type argumentType = argument.getType();
                if (argumentType instanceof ArrayType) {
                    return ((ArrayType) argumentType).getElementType() instanceof RowType;
                }
                if (argumentType instanceof MapType) {
                    return ((MapType) argumentType).getValueType() instanceof RowType;
                }
                return false;
            }

            /**
             * Visits the body of the lambda identified by {@code lambdaDescriptor} with a fresh context and groups the subfields
             * accessed through each lambda argument by the call argument that lambda argument maps to.
             *
             * @return the transformed subfields keyed by call argument index, or {@link Optional#empty()} when the call argument
             * is not a lambda definition or when a lambda argument is referenced as a whole variable
             */
            private Optional<Map<Integer, Set<Subfield>>> collectLambdaSubfields(CallExpression call, LambdaDescriptor lambdaDescriptor)
            {
                RowExpression lambdaArgument = call.getArguments().get(lambdaDescriptor.getCallArgumentIndex());
                if (!(lambdaArgument instanceof LambdaDefinitionExpression)) {
                    // Without a lambda definition to inspect, nothing can be pruned: the function may access every subfield.
                    return Optional.empty();
                }
                LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) lambdaArgument;

                Context lambdaBodyContext = new Context();
                lambda.getBody().accept(this, lambdaBodyContext);

                Map<Integer, Set<Subfield>> subfieldsByCallArgument = new HashMap<>();
                for (int lambdaArgumentIndex : lambdaDescriptor.getLambdaArgumentDescriptors().keySet()) {
                    LambdaArgumentDescriptor argumentDescriptor = lambdaDescriptor.getLambdaArgumentDescriptors().get(lambdaArgumentIndex);
                    String lambdaArgumentName = lambda.getArguments().get(lambdaArgumentIndex);

                    for (VariableReferenceExpression accessedVariable : lambdaBodyContext.variables) {
                        if (accessedVariable.getName().equals(lambdaArgumentName)) {
                            // The lambda argument was used as a whole, so the entire struct is needed.
                            return Optional.empty();
                        }
                    }

                    Set<Subfield> accessedThroughArgument = lambdaBodyContext.subfields.stream()
                            .filter(accessed -> accessed.getRootName().equals(lambdaArgumentName))
                            .collect(toImmutableSet());
                    Set<Subfield> transformed = argumentDescriptor.getLambdaArgumentToInputTransformationFunction().apply(accessedThroughArgument);
                    subfieldsByCallArgument
                            .computeIfAbsent(argumentDescriptor.getCallArgumentIndex(), unused -> new HashSet<>())
                            .addAll(transformed);
                }
                return Optional.of(ImmutableMap.copyOf(subfieldsByCallArgument));
            }

            /**
             * Visits a special form. Forms other than IS_NULL and DEREFERENCE simply have their arguments visited. IS_NULL on a
             * row-typed variable records a no-subfield marker for that variable. Everything else is offered to
             * {@code toSubfield}, falling back to visiting the arguments when no subfield can be produced.
             *
             * @return always {@code null}
             */
            @Override
            public Void visitSpecialForm(SpecialFormExpression specialForm, Context context)
            {
                if (specialForm.getForm() != IS_NULL && specialForm.getForm() != DEREFERENCE) {
                    specialForm.getArguments().forEach(argument -> argument.accept(this, context));
                    return null;
                }

                if (specialForm.getForm() == IS_NULL) {
                    RowExpression operand = specialForm.getArguments().get(0);
                    if (operand instanceof VariableReferenceExpression && operand.getType() instanceof RowType) {
                        String variableName = ((VariableReferenceExpression) operand).getName();
                        context.subfields.add(new Subfield(variableName, ImmutableList.of(noSubfield())));
                        return null;
                    }
                }

                Optional<List<Subfield>> extracted = toSubfield(specialForm, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager, isPushdownSubfieldsForMapFunctionsEnabled);
                if (!extracted.isPresent()) {
                    specialForm.getArguments().forEach(argument -> argument.accept(this, context));
                    return null;
                }

                if (context.isPruningLambdaSubfieldsPossible()) {
                    for (Subfield item : extracted.get()) {
                        addRequiredLambdaSubfields(context, item);
                    }
                }
                else {
                    context.subfields.addAll(extracted.get());
                }
                return null;
            }

            /**
             * For every lambda subfield currently held by the context, records a subfield rooted at {@code input}'s root whose
             * path is {@code input}'s path followed by the lambda subfield's path. Callers in this class invoke it when a
             * field, subscript, or subfield has been resolved and pruning is possible.
             *
             * @param context - SubfieldExtractor context; its {@code subfields} set receives the new entries
             * @param input - input field, subscript, or subfield, for which lambda subfields were collected.
             */
            private void addRequiredLambdaSubfields(Context context, Subfield input)
            {
                String rootName = input.getRootName();
                List<PathElement> inputPath = input.getPath();
                context.getLambdaSubfields().stream()
                        .map(lambdaSubfield -> ImmutableList.<PathElement>builder()
                                .addAll(inputPath)
                                .addAll(lambdaSubfield.getPath())
                                .build())
                        .map(combinedPath -> new Subfield(rootName, combinedPath))
                        .forEach(context.subfields::add);
            }

            /**
             * Records a variable reference: when lambda subfields cannot be pruned the variable is added to
             * {@code context.variables}; otherwise the variable's subfields are extended with the context's lambda subfields.
             *
             * @return always {@code null}
             */
            @Override
            public Void visitVariableReference(VariableReferenceExpression reference, Context context)
            {
                if (!context.isPruningLambdaSubfieldsPossible()) {
                    context.variables.add(reference);
                    return null;
                }

                List<Subfield> referenceSubfields = toSubfield(reference, functionResolution, expressionOptimizer, connectorSession, functionAndTypeManager, isPushdownSubfieldsForMapFunctionsEnabled).get();
                for (Subfield item : referenceSubfields) {
                    addRequiredLambdaSubfields(context, item);
                }
                return null;
            }
        }

        private static final class Context
        {
            public static final Set<Subfield> ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE = ImmutableSet.of(new Subfield("", ImmutableList.of(allSubscripts())));
            // Variables whose subfields cannot be pruned
            private final Set<VariableReferenceExpression> variables = new HashSet<>();
            private final Set<Subfield> subfields = new HashSet<>();
            private Set<Subfield> lambdaSubfields = ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE;

            private void addAssignment(VariableReferenceExpression variable, VariableReferenceExpression otherVariable)
            {
                if (variables.contains(variable)) {
                    variables.add(otherVariable);
                    return;
                }

                List<Subfield> matchingSubfields = findSubfields(variable.getName());
                verify(!matchingSubfields.isEmpty(), "Missing variable: " + variable);

                matchingSubfields.stream()
                        .map(Subfield::getPath)
                        .map(path -> new Subfield(otherVariable.getName(), path))
                        .forEach(subfields::add);
            }

            private void addAssignment(VariableReferenceExpression variable, Subfield subfield)
            {
                if (variables.contains(variable)) {
                    subfields.add(subfield);
                    return;
                }

                List<Subfield> matchingSubfields = findSubfields(variable.getName());
                verify(!matchingSubfields.isEmpty(), "Missing variable: " + variable);

                matchingSubfields.stream()
                        .map(Subfield::getPath)
                        .map(path -> new Subfield(subfield.getRootName(), ImmutableList.<PathElement>builder()
                                .addAll(subfield.getPath())
                                .addAll(path)
                                .build()))
                        .forEach(subfields::add);
            }

            private List<Subfield> findSubfields(String rootName)
            {
                return subfields.stream()
                        .filter(subfield -> rootName.equals(subfield.getRootName()))
                        .collect(toImmutableList());
            }

            public void setLambdaSubfields(Set<Subfield> lambdaSubfields)
            {
                this.lambdaSubfields = lambdaSubfields;
            }

            public Set<Subfield> getLambdaSubfields()
            {
                return lambdaSubfields;
            }

            private void giveUpOnCollectingLambdaSubfields()
            {
                setLambdaSubfields(ALL_SUBFIELDS_OF_ARRAY_ELEMENT_OR_MAP_VALUE);
            }

            private boolean isPruningLambdaSubfieldsPossible()
            {
                return !getLambdaSubfields().isEmpty() &&
                        getLambdaSubfields().stream()
                                .noneMatch(
                                        subfield -> subfield.getPath().stream()
                                                .skip(subfield.getPath().size() - 1)
                                                .anyMatch(pathElement -> pathElement.equals(allSubscripts())));
            }
        }
    }

    /**
     * Tells whether the call is the subscript function or a function whose metadata name equals the qualified name of
     * {@code element_at}.
     */
    private static boolean isSubscriptOrElementAtFunction(CallExpression expression, StandardFunctionResolution functionResolution, FunctionAndTypeManager functionAndTypeManager)
    {
        if (functionResolution.isSubscriptFunction(expression.getFunctionHandle())) {
            return true;
        }
        return functionAndTypeManager.getFunctionAndTypeResolver()
                .getFunctionMetadata(expression.getFunctionHandle())
                .getName()
                .equals(functionAndTypeManager.getFunctionAndTypeResolver().qualifyObjectName(QualifiedName.of("element_at")));
    }

    /**
     * Tells whether the call is map_subset applied to a plain variable (first argument) and a constant (second argument).
     */
    private static boolean isMapSubSetWithConstantArray(CallExpression expression, FunctionResolution functionResolution)
    {
        if (!functionResolution.isMapSubSetFunction(expression.getFunctionHandle())) {
            return false;
        }
        List<RowExpression> arguments = expression.getArguments();
        if (!(arguments.get(0) instanceof VariableReferenceExpression)) {
            return false;
        }
        return arguments.get(1) instanceof ConstantExpression;
    }

    /**
     * Tells whether the call is map_filter over a plain variable whose lambda body filters the first lambda argument
     * (the map key) against constants only. Three body shapes are recognized:
     * {@code key IN (constants...)}, {@code contains(constant, key)}, and {@code key = constant} (either operand order).
     */
    private static boolean isMapFilterWithConstantFilterInMapKey(CallExpression expression, FunctionResolution functionResolution)
    {
        if (!functionResolution.isMapFilterFunction(expression.getFunctionHandle())) {
            return false;
        }
        if (!(expression.getArguments().get(0) instanceof VariableReferenceExpression)) {
            return false;
        }
        if (!(expression.getArguments().get(1) instanceof LambdaDefinitionExpression)) {
            return false;
        }

        LambdaDefinitionExpression lambda = (LambdaDefinitionExpression) expression.getArguments().get(1);
        RowExpression body = lambda.getBody();

        if (body instanceof SpecialFormExpression) {
            SpecialFormExpression inForm = (SpecialFormExpression) body;
            if (inForm.getForm().equals(IN) && isReferenceToMapKey(inForm.getArguments().get(0), lambda)) {
                return inForm.getArguments().stream().skip(1).allMatch(candidate -> candidate instanceof ConstantExpression);
            }
            return false;
        }

        if (!(body instanceof CallExpression)) {
            return false;
        }
        CallExpression bodyCall = (CallExpression) body;
        if (functionResolution.isArrayContainsFunction(bodyCall.getFunctionHandle())) {
            return bodyCall.getArguments().get(0) instanceof ConstantExpression
                    && isReferenceToMapKey(bodyCall.getArguments().get(1), lambda);
        }
        if (functionResolution.isEqualFunction(bodyCall.getFunctionHandle())) {
            boolean keyEqualsConstant = isReferenceToMapKey(bodyCall.getArguments().get(0), lambda)
                    && bodyCall.getArguments().get(1) instanceof ConstantExpression;
            if (keyEqualsConstant) {
                return true;
            }
            return isReferenceToMapKey(bodyCall.getArguments().get(1), lambda)
                    && bodyCall.getArguments().get(0) instanceof ConstantExpression;
        }
        return false;
    }

    // True when 'candidate' is a variable reference named like the first argument of 'lambda'.
    private static boolean isReferenceToMapKey(RowExpression candidate, LambdaDefinitionExpression lambda)
    {
        return candidate instanceof VariableReferenceExpression
                && ((VariableReferenceExpression) candidate).getName().equals(lambda.getArguments().get(0));
    }
}