// ParquetDereferencePushDown.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.parquet.rule;

import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.expressions.RowExpressionRewriter;
import com.facebook.presto.expressions.RowExpressionTreeRewriter;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorPlanOptimizer;
import com.facebook.presto.spi.ConnectorPlanRewriter;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.ExpressionOptimizer;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionService;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import static com.facebook.presto.parquet.ParquetTypeUtils.pushdownColumnNameForSubfield;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Collections.unmodifiableList;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

public abstract class ParquetDereferencePushDown
        implements ConnectorPlanOptimizer
{
    // Supplies the ExpressionOptimizer that is applied to dereference index expressions during plan rewriting.
    private final RowExpressionService rowExpressionService;

    /**
     * Creates the optimizer.
     *
     * @param rowExpressionService service from which the rewriter obtains an {@link ExpressionOptimizer} for the session
     * @throws NullPointerException if {@code rowExpressionService} is null
     */
    public ParquetDereferencePushDown(RowExpressionService rowExpressionService)
    {
        this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null");
    }

    /**
     * Runs the dereference push-down rewriter over the given sub-plan.
     *
     * @param maxSubplan plan fragment to rewrite
     * @param session session handed to the rewriter
     * @param variableAllocator allocator the rewriter uses to create output variables for pushed-down subfields
     * @param idAllocator allocator the rewriter uses to assign ids to the plan nodes it creates
     * @return the plan returned by the rewriter
     */
    @Override
    public PlanNode optimize(PlanNode maxSubplan, ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator)
    {
        Rewriter dereferencePushDown = new Rewriter(session, variableAllocator, idAllocator);
        return ConnectorPlanRewriter.rewriteWith(dereferencePushDown, maxSubplan);
    }

    /**
     * Whether Parquet dereference pushdown is enabled for the given TableHandle.
     * When this returns {@code false} for the table of a scan, no dereference is pushed into that scan.
     *
     * @param session session the plan is being optimized for
     * @param tableHandle handle of the table read by the table scan under consideration
     * @return {@code true} if dereference expressions may be pushed into scans of the table
     */
    protected abstract boolean isParquetDereferenceEnabled(ConnectorSession session, TableHandle tableHandle);

    /**
     * ColumnHandle is an interface. Each connector implements its own version.
     * The connector specific implementation of this method returns the name of the column the given ColumnHandle refers to.
     * The returned name is used as the root name of pushed-down subfields and as a lookup key for base column handles.
     *
     * @param columnHandle connector specific column handle taken from the table scan assignments
     * @return name of the column the handle refers to
     */
    protected abstract String getColumnName(ColumnHandle columnHandle);

    /**
     * Create a connector specific ColumnHandle for the given subfield that is being pushed into the table scan.
     *
     * @param baseColumnHandle ColumnHandle for the base column that the given <i>subfield</i> is part of.
     *                         Ex. in "msg.a.b", "msg" is the top level column. "a.b" is the subfield
     *                         that is part of "msg". This ColumnHandle refers to "msg".
     * @param subfield Subfield info.
     * @param subfieldDataType Data type of the subfield.
     * @param subfieldColumnName Name of the subfield column being referred in table scan output.
     * @return ColumnHandle that is assigned to the new table scan output variable created for the subfield.
     */
    protected abstract ColumnHandle createSubfieldColumnHandle(
            ColumnHandle baseColumnHandle,
            Subfield subfield,
            Type subfieldDataType,
            String subfieldColumnName);

    /**
     * Finds the dereference expressions in {@code expressions} that can be pushed into the table scan
     * and pairs each of them with the {@link Subfield} it denotes.
     * <p>
     * Dereference chains and variable references are first collected from every expression. A collected
     * expression is then dropped when {@link #prefixExists} reports that another collected expression lies on
     * its base chain, and plain variable references are dropped as well, leaving only dereferences.
     *
     * @param baseColumnHandles column handles of the table scan, used to resolve the root of each dereference
     * @param session session passed to the expression optimizer
     * @param expressionOptimizer optimizer applied to dereference index expressions
     * @param expressions expressions to search
     * @return map from each retained dereference expression to its subfield
     */
    private Map<RowExpression, Subfield> extractDereferences(
            Map<String, ColumnHandle> baseColumnHandles,
            ConnectorSession session, ExpressionOptimizer expressionOptimizer,
            Set<RowExpression> expressions)
    {
        // The visitor only holds the session and optimizer, so one instance serves every expression.
        ExtractDereferenceAndVariables collector = new ExtractDereferenceAndVariables(session, expressionOptimizer);
        Set<RowExpression> candidates = new HashSet<>();
        for (RowExpression expression : expressions) {
            expression.accept(collector, candidates);
        }

        // keep prefix only expressions, and of those only the dereferences
        List<RowExpression> pushableDereferences = candidates.stream()
                .filter(candidate -> !prefixExists(candidate, candidates))
                .filter(candidate -> candidate instanceof SpecialFormExpression
                        && ((SpecialFormExpression) candidate).getForm() == DEREFERENCE)
                .collect(Collectors.toList());

        return pushableDereferences.stream()
                .collect(toMap(
                        identity(),
                        pushable -> createNestedColumn(baseColumnHandles, pushable, expressionOptimizer, session)));
    }

    /**
     * Tells whether more than one member of {@code allExpressions} is found while walking {@code expression}.
     * <p>
     * For a DEREFERENCE the walk checks the node itself and then continues only into its base (first argument);
     * the index argument is not visited. Variable references are checked where the walk reaches them. Any other
     * special form is traversed by the default visitor behavior.
     *
     * @param expression expression to walk
     * @param allExpressions collected dereference and variable expressions to match against
     * @return {@code true} if at least two visited nodes are contained in {@code allExpressions}
     */
    private static boolean prefixExists(RowExpression expression, Set<RowExpression> allExpressions)
    {
        DefaultRowExpressionTraversalVisitor<int[]> matchCounter = new DefaultRowExpressionTraversalVisitor<int[]>()
        {
            @Override
            public Void visitVariableReference(VariableReferenceExpression reference, int[] counter)
            {
                if (allExpressions.contains(reference)) {
                    counter[0]++;
                }
                return null;
            }

            @Override
            public Void visitSpecialForm(SpecialFormExpression specialForm, int[] counter)
            {
                if (specialForm.getForm() == DEREFERENCE) {
                    if (allExpressions.contains(specialForm)) {
                        counter[0]++;
                    }
                    // Follow only the base of the dereference; the index argument is deliberately skipped.
                    specialForm.getArguments().get(0).accept(this, counter);
                    return null;
                }
                return super.visitSpecialForm(specialForm, counter);
            }
        };

        int[] matches = new int[1];
        expression.accept(matchCounter, matches);
        return matches[0] >= 2;
    }

    /**
     * Converts a chain of DEREFERENCE expressions rooted at a variable reference into a {@link Subfield}.
     * <p>
     * The chain is walked from the outermost dereference towards the root variable. Each step must have an
     * index that optimizes to a numeric constant and must select a named field of the base row type.
     * The root variable is resolved through {@code baseColumnHandles} and the subfield is rooted at the
     * name returned by {@link #getColumnName} for that handle.
     *
     * @param baseColumnHandles column handles of the table scan, looked up by the root variable name
     * @param rowExpression outermost dereference expression of the chain
     * @param expressionOptimizer optimizer applied to each index expression
     * @param session session passed to the optimizer
     * @return subfield whose path lists the field names from the root outwards
     * @throws IllegalArgumentException if {@code rowExpression} is not a DEREFERENCE, if any step of the chain
     *         does not satisfy the conditions above, or if no column handle is known for the root variable
     */
    private Subfield createNestedColumn(
            Map<String, ColumnHandle> baseColumnHandles,
            RowExpression rowExpression,
            ExpressionOptimizer expressionOptimizer,
            ConnectorSession session)
    {
        boolean startsWithDereference = rowExpression instanceof SpecialFormExpression
                && ((SpecialFormExpression) rowExpression).getForm() == DEREFERENCE;
        if (!startsWithDereference) {
            throw new IllegalArgumentException("expecting SpecialFormExpression(DEREFERENCE), but got: " + rowExpression);
        }

        // Field names are gathered leaf-first while descending towards the root variable.
        List<Subfield.PathElement> pathLeafFirst = new ArrayList<>();
        RowExpression cursor = rowExpression;
        while (!(cursor instanceof VariableReferenceExpression)) {
            RowExpression base = null;
            Optional<String> fieldName = Optional.empty();

            if (cursor instanceof SpecialFormExpression && ((SpecialFormExpression) cursor).getForm() == DEREFERENCE) {
                SpecialFormExpression step = (SpecialFormExpression) cursor;
                base = step.getArguments().get(0);
                RowType baseType = (RowType) base.getType();

                RowExpression optimizedIndex = expressionOptimizer.optimize(
                        step.getArguments().get(1),
                        ExpressionOptimizer.Level.OPTIMIZED,
                        session);

                if (optimizedIndex instanceof ConstantExpression) {
                    Object indexValue = ((ConstantExpression) optimizedIndex).getValue();
                    if (indexValue instanceof Number) {
                        fieldName = baseType.getFields().get(((Number) indexValue).intValue()).getName();
                    }
                }
            }

            if (!fieldName.isPresent()) {
                throw new IllegalArgumentException("expecting SpecialFormExpression(DEREFERENCE) with constants for indices, but got: " + cursor);
            }

            pathLeafFirst.add(new Subfield.NestedField(fieldName.get()));
            cursor = base;
        }

        // Reached the root variable: flip the path so it reads root-to-leaf.
        Collections.reverse(pathLeafFirst);
        String variableName = ((VariableReferenceExpression) cursor).getName();
        ColumnHandle baseHandle = baseColumnHandles.get(variableName);
        checkArgument(baseHandle != null, "Missing Column handle: " + variableName);
        return new Subfield(getColumnName(baseHandle), unmodifiableList(pathLeafFirst));
    }

    /**
     * Visitor to extract all dereference expressions and variable references.
     * <p>
     * If a dereference expression contains dereference expression, inner dereference expression are not returned
     * * sub(deref(deref(x, 1), 2)) --> deref(deref(x,1), 2)
     * Variable expressions returned are the ones not referenced by the dereference expressions
     * * sub(x + 1) --> x
     * * sub(deref(x, 1)) -> deref(x,1)
     * <p>
     * A dereference is only returned as a whole when its chain of bases ends in a variable reference and every
     * step has an index that optimizes to a numeric constant selecting a named field. Otherwise the dereference
     * is traversed like any other special form.
     */
    private static class ExtractDereferenceAndVariables
            extends DefaultRowExpressionTraversalVisitor<Set<RowExpression>>
    {
        private final ConnectorSession connectorSession;
        private final ExpressionOptimizer expressionOptimizer;

        /**
         * @param connectorSession session passed to the optimizer
         * @param expressionOptimizer optimizer applied to dereference index expressions
         * @throws NullPointerException if either argument is null
         */
        public ExtractDereferenceAndVariables(ConnectorSession connectorSession, ExpressionOptimizer expressionOptimizer)
        {
            // Fail at construction, like the other constructors in this file, rather than in the middle of a traversal.
            this.connectorSession = requireNonNull(connectorSession, "connectorSession is null");
            this.expressionOptimizer = requireNonNull(expressionOptimizer, "expressionOptimizer is null");
        }

        /**
         * Adds a fully resolvable dereference chain to {@code context} without descending into it;
         * every other special form is handed to the default traversal.
         */
        @Override
        public Void visitSpecialForm(SpecialFormExpression specialForm, Set<RowExpression> context)
        {
            if (specialForm.getForm() == DEREFERENCE && isResolvableChain(specialForm)) {
                context.add(specialForm);
                return null;
            }
            return super.visitSpecialForm(specialForm, context);
        }

        /**
         * Adds the variable reference to {@code context}.
         */
        @Override
        public Void visitVariableReference(VariableReferenceExpression reference, Set<RowExpression> context)
        {
            context.add(reference);
            return null;
        }

        /**
         * Walks from the given dereference through its bases and reports whether the walk reaches a variable
         * reference with every step being a DEREFERENCE whose optimized index is a numeric constant that
         * selects a named field of the base row type.
         */
        private boolean isResolvableChain(SpecialFormExpression dereference)
        {
            RowExpression current = dereference;
            while (!(current instanceof VariableReferenceExpression)) {
                if (!(current instanceof SpecialFormExpression) || ((SpecialFormExpression) current).getForm() != DEREFERENCE) {
                    return false;
                }

                SpecialFormExpression step = (SpecialFormExpression) current;
                RowExpression base = step.getArguments().get(0);
                RowType baseType = (RowType) base.getType();

                RowExpression indexExpression = expressionOptimizer.optimize(
                        step.getArguments().get(1),
                        ExpressionOptimizer.Level.OPTIMIZED,
                        connectorSession);
                if (!(indexExpression instanceof ConstantExpression)) {
                    return false;
                }

                Object index = ((ConstantExpression) indexExpression).getValue();
                if (!(index instanceof Number)) {
                    return false;
                }

                Optional<String> fieldName = baseType.getFields().get(((Number) index).intValue()).getName();
                if (!fieldName.isPresent()) {
                    return false;
                }

                current = base;
            }
            return true;
        }
    }

    /**
     * Expression rewriter that replaces special form expressions found in the given map with the variable
     * mapped to them.
     */
    private static class DereferenceExpressionRewriter
            extends RowExpressionRewriter<Void>
    {
        private final Map<RowExpression, VariableReferenceExpression> dereferenceMap;

        /**
         * @param dereferenceMap replacement variable for each dereference expression to rewrite
         * @throws NullPointerException if {@code dereferenceMap} is null
         */
        public DereferenceExpressionRewriter(Map<RowExpression, VariableReferenceExpression> dereferenceMap)
        {
            // Fail at construction, like the other constructors in this file, rather than on the first rewrite call.
            this.dereferenceMap = requireNonNull(dereferenceMap, "dereferenceMap is null");
        }

        /**
         * Returns the variable mapped to {@code node}, or {@code null} when the map has no entry for it.
         * NOTE(review): the null result presumably tells RowExpressionTreeRewriter to fall back to its default
         * rewriting of the node's children; confirm against RowExpressionTreeRewriter.
         */
        @Override
        public RowExpression rewriteSpecialForm(SpecialFormExpression node, Void context, RowExpressionTreeRewriter<Void> treeRewriter)
        {
            return dereferenceMap.get(node);
        }
    }

    /**
     * Looks for ProjectNode -> TableScanNode patterns. Goes through the project expressions to extract out the DEREFERENCE expressions,
     * pushes the dereferences down to TableScan and creates new project expressions with the pushed down column coming from the TableScan.
     * Returned plan nodes could contain unreferenced outputs which will be pruned later in the planning process.
     * <p>
     * This is an inner (non-static) class because it calls the connector specific hooks and helpers of the enclosing instance.
     */
    private class Rewriter
            extends ConnectorPlanRewriter<Void>
    {
        private final ConnectorSession session;
        private final VariableAllocator variableAllocator;
        private final PlanNodeIdAllocator idAllocator;

        /**
         * @param session session the plan is being optimized for
         * @param variableAllocator allocator for the output variables of pushed-down subfield columns
         * @param idAllocator allocator for the ids of the replacement plan nodes
         * @throws NullPointerException if any argument is null
         */
        Rewriter(ConnectorSession session, VariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator)
        {
            this.session = requireNonNull(session, "session is null");
            this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null");
            this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
        }

        /**
         * Rewrites a projection that sits directly on a table scan so that pushable dereferences are read
         * as additional table scan outputs. Falls back to {@code visitPlan} when the source is not a table scan,
         * when pushdown is disabled for the table, or when there is nothing to push down.
         *
         * @throws IllegalArgumentException if a subfield's base column is not among the table scan assignments
         */
        @Override
        public PlanNode visitProject(ProjectNode project, RewriteContext<Void> context)
        {
            if (!(project.getSource() instanceof TableScanNode)) {
                return visitPlan(project, context);
            }

            TableScanNode tableScan = (TableScanNode) project.getSource();
            if (!isParquetDereferenceEnabled(session, tableScan.getTable())) {
                return visitPlan(project, context);
            }

            // Each handle is reachable by the scan output variable name and by the connector column name:
            // dereference roots are variables, while Subfield root names are column names.
            Map<String, ColumnHandle> baseColumnHandles = new HashMap<>();
            tableScan.getAssignments().forEach((variable, columnHandle) -> {
                baseColumnHandles.put(variable.getName(), columnHandle);
                baseColumnHandles.put(getColumnName(columnHandle), columnHandle);
            });

            Map<RowExpression, Subfield> dereferenceToNestedColumnMap = extractDereferences(
                    baseColumnHandles,
                    session,
                    rowExpressionService.getExpressionOptimizer(session),
                    new HashSet<>(project.getAssignments().getExpressions()));
            if (dereferenceToNestedColumnMap.isEmpty()) {
                return visitPlan(project, context);
            }

            // Start from the existing scan outputs; pushed-down subfield columns are appended to them.
            List<VariableReferenceExpression> newOutputVariables = new ArrayList<>(tableScan.getOutputVariables());
            Map<VariableReferenceExpression, ColumnHandle> newAssignments = new HashMap<>(tableScan.getAssignments());

            Map<RowExpression, VariableReferenceExpression> dereferenceToVariableMap = new HashMap<>();

            for (Map.Entry<RowExpression, Subfield> dereference : dereferenceToNestedColumnMap.entrySet()) {
                Subfield subfield = dereference.getValue();
                RowExpression dereferenceExpression = dereference.getKey();

                // Find the column handle of the base column the subfield belongs to
                ColumnHandle baseColumnHandle = baseColumnHandles.get(subfield.getRootName());
                if (baseColumnHandle == null) {
                    throw new IllegalArgumentException("Subfield column [" + subfield + "]'s base column " + subfield.getRootName() + " is not present in table scan output");
                }
                String subfieldColumnName = pushdownColumnNameForSubfield(subfield);

                ColumnHandle nestedColumnHandle = createSubfieldColumnHandle(
                        baseColumnHandle,
                        subfield,
                        dereferenceExpression.getType(),
                        subfieldColumnName);

                VariableReferenceExpression newOutputVariable = variableAllocator.newVariable(subfieldColumnName, dereferenceExpression.getType());
                newOutputVariables.add(newOutputVariable);
                newAssignments.put(newOutputVariable, nestedColumnHandle);

                dereferenceToVariableMap.put(dereferenceExpression, newOutputVariable);
            }

            TableScanNode newTableScan = new TableScanNode(
                    tableScan.getSourceLocation(),
                    idAllocator.getNextId(),
                    tableScan.getTable(),
                    newOutputVariables,
                    newAssignments,
                    tableScan.getTableConstraints(),
                    tableScan.getCurrentConstraint(),
                    tableScan.getEnforcedConstraint(),
                    tableScan.getCteMaterializationInfo());

            // The rewriter only reads the map, so a single instance is shared by all assignments.
            DereferenceExpressionRewriter dereferenceRewriter = new DereferenceExpressionRewriter(dereferenceToVariableMap);
            Assignments.Builder newProjectAssignmentBuilder = Assignments.builder();
            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : project.getAssignments().entrySet()) {
                RowExpression newExpression = RowExpressionTreeRewriter.rewriteWith(dereferenceRewriter, entry.getValue());
                newProjectAssignmentBuilder.put(entry.getKey(), newExpression);
            }

            // The replacement projection keeps the source location of the projection it replaces, not the scan's.
            return new ProjectNode(project.getSourceLocation(), idAllocator.getNextId(), newTableScan, newProjectAssignmentBuilder.build(), project.getLocality());
        }
    }
}