HivePartialAggregationPushdown.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.hive.rule;

import com.facebook.presto.common.type.Type;
import com.facebook.presto.hive.HiveColumnHandle;
import com.facebook.presto.hive.HiveStorageFormat;
import com.facebook.presto.hive.HiveTableHandle;
import com.facebook.presto.hive.HiveTableLayoutHandle;
import com.facebook.presto.hive.HiveTableProperties;
import com.facebook.presto.hive.HiveType;
import com.facebook.presto.hive.HiveTypeTranslator;
import com.facebook.presto.hive.TransactionalMetadata;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorPlanOptimizer;
import com.facebook.presto.spi.ConnectorPlanRewriter;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.ConnectorTableHandle;
import com.facebook.presto.spi.ConnectorTableMetadata;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.VariableAllocator;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.hive.BaseHiveColumnHandle.ColumnType.REGULAR;
import static com.facebook.presto.hive.HiveSessionProperties.isPartialAggregationPushdownEnabled;
import static com.facebook.presto.hive.HiveSessionProperties.isPartialAggregationPushdownForVariableLengthDatatypesEnabled;
import static com.facebook.presto.hive.HiveStorageFormat.DWRF;
import static com.facebook.presto.hive.HiveStorageFormat.ORC;
import static com.facebook.presto.hive.HiveStorageFormat.PARQUET;
import static com.facebook.presto.hive.metastore.MetastoreUtil.isArrayType;
import static com.facebook.presto.hive.metastore.MetastoreUtil.isMapType;
import static com.facebook.presto.hive.metastore.MetastoreUtil.isRowType;
import static com.facebook.presto.spi.ConnectorPlanRewriter.rewriteWith;
import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND;
import static com.facebook.presto.spi.plan.AggregationNode.Step.PARTIAL;
import static java.util.Objects.requireNonNull;

/**
 * Connector plan optimizer that folds a {@code PARTIAL} aggregation sitting directly on top of a
 * table scan into the scan itself.
 *
 * <p>When the aggregation is eligible (see {@code Rewriter#isAggregationPushdownSupported}), the
 * aggregation node and its source scan are replaced by a single {@link TableScanNode} whose
 * assignments are {@code AGGREGATED} {@link HiveColumnHandle}s, one per aggregation, and whose
 * layout handle is marked with {@code partialAggregationsPushedDown = true}.
 *
 * <p>The optimization is a no-op unless it is enabled for the session.
 */
public class HivePartialAggregationPushdown
        implements ConnectorPlanOptimizer
{
    private final StandardFunctionResolution standardFunctionResolution;
    private final Supplier<TransactionalMetadata> metadataFactory;

    // Column index assigned to the synthetic "count_star" column, which has no argument column to take an index from.
    private static final int DUMMY_AGGREGATED_COLUMN_INDEX = -20;

    /**
     * @param standardFunctionResolution used to recognize count/min/max function handles
     * @param metadataFactory supplies the metadata used to look up the table's storage format
     */
    public HivePartialAggregationPushdown(
            StandardFunctionResolution standardFunctionResolution,
            Supplier<TransactionalMetadata> metadataFactory)
    {
        this.standardFunctionResolution = requireNonNull(standardFunctionResolution, "standard function resolution is null");
        this.metadataFactory = requireNonNull(metadataFactory, "metadata factory is null");
    }

    /**
     * Returns the scan's connector handle as a {@link HiveTableHandle}, or empty when the scan has
     * no table handle or the connector handle is of a different type.
     */
    private static Optional<HiveTableHandle> getHiveTableHandle(TableScanNode tableScanNode)
    {
        TableHandle table = tableScanNode.getTable();
        if (table != null) {
            ConnectorTableHandle connectorHandle = table.getConnectorHandle();
            if (connectorHandle instanceof HiveTableHandle) {
                return Optional.of((HiveTableHandle) connectorHandle);
            }
        }
        return Optional.empty();
    }

    /**
     * Rewrites {@code maxSubplan}, pushing eligible partial aggregations into their table scans.
     *
     * @return the rewritten plan, or {@code maxSubplan} itself when the feature is disabled for the session
     */
    @Override
    public PlanNode optimize(PlanNode maxSubplan,
            ConnectorSession session,
            VariableAllocator variableAllocator,
            PlanNodeIdAllocator idAllocator)
    {
        if (!isPartialAggregationPushdownEnabled(session)) {
            return maxSubplan;
        }
        return rewriteWith(new Rewriter(session, idAllocator), maxSubplan);
    }

    /**
     * Plan rewriter that performs the actual pushdown for each aggregation node it visits.
     */
    private class Rewriter
            extends ConnectorPlanRewriter<Void>
    {
        private final PlanNodeIdAllocator idAllocator;
        private final ConnectorSession session;

        public Rewriter(ConnectorSession session, PlanNodeIdAllocator idAllocator)
        {
            this.session = session;
            this.idAllocator = idAllocator;
        }

        /**
         * Decides whether the given partial aggregation (whose source must be a {@link TableScanNode})
         * can be pushed into the scan.
         *
         * <p>Pushdown is rejected when any of the following holds:
         * <ul>
         * <li>the aggregation has a non-empty grouping set;</li>
         * <li>the table's storage format is unknown or is not ORC, DWRF or PARQUET;</li>
         * <li>the scan's layout has predicate columns;</li>
         * <li>an aggregation is not count/min/max, takes more than one argument, or is a
         * zero-argument function other than count;</li>
         * <li>an aggregation argument is not mapped to a {@code REGULAR} Hive column;</li>
         * <li>a min/max argument type is not supported for the storage format.</li>
         * </ul>
         *
         * @param partialAggregationNode aggregation whose source is a table scan
         * @param assignments variable-to-column assignments of that table scan
         */
        private boolean isAggregationPushdownSupported(AggregationNode partialAggregationNode, Map<VariableReferenceExpression, ColumnHandle> assignments)
        {
            if (partialAggregationNode.hasNonEmptyGroupingSet()) {
                return false;
            }

            TableScanNode tableScanNode = (TableScanNode) partialAggregationNode.getSource();
            ConnectorTableMetadata connectorTableMetadata = metadataFactory.get().getTableMetadata(session, tableScanNode.getTable().getConnectorHandle());
            Optional<Object> rawFormat = Optional.ofNullable(connectorTableMetadata.getProperties().get(HiveTableProperties.STORAGE_FORMAT_PROPERTY));
            if (!rawFormat.isPresent()) {
                return false;
            }

            final HiveStorageFormat hiveStorageFormat = HiveStorageFormat.valueOf(rawFormat.get().toString());
            if (hiveStorageFormat != ORC && hiveStorageFormat != PARQUET && hiveStorageFormat != DWRF) {
                return false;
            }

            if (tableScanNode.getTable().getLayout().isPresent()) {
                HiveTableLayoutHandle hiveTableLayoutHandle = (HiveTableLayoutHandle) tableScanNode.getTable().getLayout().get();
                if (!hiveTableLayoutHandle.getPredicateColumns().isEmpty()) {
                    return false;
                }
            }

            // Aggregation push downs are supported only on primitive types and supported aggregation functions are:
            // count(*), count(columnName), min(columnName), max(columnName)
            for (Aggregation aggregation : partialAggregationNode.getAggregations().values()) {
                FunctionHandle functionHandle = aggregation.getFunctionHandle();
                if (!(standardFunctionResolution.isCountFunction(functionHandle) ||
                        standardFunctionResolution.isMaxFunction(functionHandle) ||
                        standardFunctionResolution.isMinFunction(functionHandle))) {
                    return false;
                }

                if (aggregation.getArguments().isEmpty() && !standardFunctionResolution.isCountFunction(functionHandle)) {
                    return false;
                }

                List<RowExpression> arguments = aggregation.getArguments();
                if (arguments.size() > 1) {
                    return false;
                }
                else if (arguments.size() == 1) {
                    RowExpression column = aggregation.getCall().getArguments().get(0);
                    ColumnHandle columnHandle = assignments.get(column);
                    // An argument that is not a direct Hive column of the scan cannot be pushed down
                    if (!(columnHandle instanceof HiveColumnHandle)) {
                        return false;
                    }
                    // These columns get 'PREFILLED' with values in the corresponding page sources
                    if (((HiveColumnHandle) columnHandle).getColumnType() != REGULAR) {
                        return false;
                    }
                }

                if (standardFunctionResolution.isMinFunction(functionHandle) || standardFunctionResolution.isMaxFunction(functionHandle)) {
                    // Only allow supported datatypes for min/max.
                    // A min/max reaching this point has exactly one argument (zero and 2+ were rejected above).
                    Type type = arguments.get(0).getType();
                    switch (hiveStorageFormat) {
                        case ORC:
                        case DWRF:
                            if (isNotSupportedOrcTypeForMinMax(type)) {
                                return false;
                            }
                            break;
                        case PARQUET:
                            if (isNotSupportedParquetTypeForMinMax(type)) {
                                return false;
                            }
                            break;
                        default:
                            return false;
                    }
                }
            }
            return true;
        }

        /**
         * Returns true when min/max over {@code type} must not be pushed down for ORC/DWRF tables.
         */
        private boolean isNotSupportedOrcTypeForMinMax(Type type)
        {
            return BOOLEAN.equals(type) ||
                    type.getJavaType() == boolean.class ||
                    isRowType(type) ||
                    isArrayType(type) ||
                    isMapType(type) ||
                    TINYINT.equals(type) ||
                    VARBINARY.equals(type) ||
                    TIMESTAMP.equals(type) ||
                    isNotSupportedOrcTypeForVariableLengthDataType(type);
        }

        /**
         * Returns true when min/max over {@code type} must not be pushed down for PARQUET tables.
         */
        private boolean isNotSupportedParquetTypeForMinMax(Type type)
        {
            return BOOLEAN.equals(type) ||
                    type.getJavaType() == boolean.class ||
                    isRowType(type) ||
                    isArrayType(type) ||
                    isMapType(type) ||
                    isNotSupportedParquetTypeForVariableLengthDataType(type);
        }

        /**
         * For ORC/DWRF: VARCHAR is rejected unless pushdown for variable length datatypes is enabled for the session.
         */
        private boolean isNotSupportedOrcTypeForVariableLengthDataType(Type type)
        {
            return VARCHAR.equals(type) && !isPartialAggregationPushdownForVariableLengthDatatypesEnabled(session);
        }

        /**
         * For PARQUET: VARBINARY and VARCHAR are rejected unless pushdown for variable length datatypes is enabled for the session.
         */
        private boolean isNotSupportedParquetTypeForVariableLengthDataType(Type type)
        {
            boolean isVariableLengthType = VARBINARY.equals(type) || VARCHAR.equals(type);
            return isVariableLengthType && !isPartialAggregationPushdownForVariableLengthDatatypesEnabled(session);
        }

        /**
         * Attempts to replace a PARTIAL aggregation over a table scan with a single aggregated table scan.
         *
         * @param plan node to inspect
         * @return the replacement scan, or empty when {@code plan} is not a PARTIAL aggregation directly over
         * a table scan, the scan has no Hive layout handle, or the aggregation is not supported
         * @throws PrestoException with {@code NOT_FOUND} when the scanned table is not backed by a {@link HiveTableHandle}
         */
        private Optional<PlanNode> tryPartialAggregationPushdown(PlanNode plan)
        {
            if (!(plan instanceof AggregationNode
                    && ((AggregationNode) plan).getStep().equals(PARTIAL)
                    && ((AggregationNode) plan).getSource() instanceof TableScanNode)) {
                return Optional.empty();
            }

            AggregationNode partialAggregationNode = (AggregationNode) plan;

            TableScanNode oldTableScanNode = (TableScanNode) partialAggregationNode.getSource();
            TableHandle oldTableHandle = oldTableScanNode.getTable();
            HiveTableHandle hiveTableHandle = getHiveTableHandle(oldTableScanNode).orElseThrow(() -> new PrestoException(NOT_FOUND, "Hive table handle not found"));

            // The rewritten scan is built from the existing layout handle; without one there is nothing to push into.
            Optional<HiveTableLayoutHandle> oldTableLayoutHandle = oldTableHandle.getLayout()
                    .filter(HiveTableLayoutHandle.class::isInstance)
                    .map(HiveTableLayoutHandle.class::cast);
            if (!oldTableLayoutHandle.isPresent()) {
                return Optional.empty();
            }

            if (!isAggregationPushdownSupported(partialAggregationNode, oldTableScanNode.getAssignments())) {
                return Optional.empty();
            }

            HiveTypeTranslator hiveTypeTranslator = new HiveTypeTranslator();
            Map<VariableReferenceExpression, ColumnHandle> assignments = new HashMap<>();
            for (Map.Entry<VariableReferenceExpression, Aggregation> aggregationEntry : partialAggregationNode.getAggregations().entrySet()) {
                CallExpression callExpression = aggregationEntry.getValue().getCall();
                String columnName;
                int columnIndex;
                HiveType hiveType;
                if (callExpression.getArguments().isEmpty()) {
                    // count(*): no argument column, so use a synthetic name/index and the aggregation's result type
                    columnName = "count_star";
                    columnIndex = DUMMY_AGGREGATED_COLUMN_INDEX;
                    hiveType = HiveType.toHiveType(hiveTypeTranslator, callExpression.getType());
                }
                else {
                    // Single-argument aggregation: reuse the index and Hive type of the argument's column.
                    // The support check above guarantees the argument maps to a HiveColumnHandle.
                    RowExpression column = callExpression.getArguments().get(0);
                    columnName = column.toString();
                    HiveColumnHandle oldColumnHandle = (HiveColumnHandle) oldTableScanNode.getAssignments().get(column);
                    columnIndex = oldColumnHandle.getHiveColumnIndex();
                    hiveType = oldColumnHandle.getHiveType();
                }

                ColumnHandle newColumnHandle = new HiveColumnHandle(
                        columnName,
                        hiveType,
                        callExpression.getType().getTypeSignature(),
                        columnIndex,
                        HiveColumnHandle.ColumnType.AGGREGATED,
                        Optional.of("partial aggregation pushed down"),
                        Optional.of(aggregationEntry.getValue()));
                assignments.put(aggregationEntry.getKey(), newColumnHandle);
            }

            HiveTableLayoutHandle newTableLayoutHandle = oldTableLayoutHandle.get().builder().setPartialAggregationsPushedDown(true).build();

            TableHandle newTableHandle = new TableHandle(
                    oldTableHandle.getConnectorId(),
                    hiveTableHandle,
                    oldTableHandle.getTransaction(),
                    Optional.of(newTableLayoutHandle));

            // The new scan takes over the aggregation's output variables, so parents of the aggregation are unaffected
            return Optional.of(new TableScanNode(
                    oldTableScanNode.getSourceLocation(),
                    idAllocator.getNextId(),
                    newTableHandle,
                    ImmutableList.copyOf(partialAggregationNode.getOutputVariables()),
                    ImmutableMap.copyOf(assignments),
                    oldTableScanNode.getTableConstraints(),
                    oldTableScanNode.getCurrentConstraint(),
                    oldTableScanNode.getEnforcedConstraint(),
                    oldTableScanNode.getCteMaterializationInfo()));
        }

        /**
         * Pushes the aggregation into its scan when possible; otherwise falls back to the default
         * rewrite so that aggregations nested below this node are still visited.
         */
        @Override
        public PlanNode visitAggregation(AggregationNode node, RewriteContext<Void> context)
        {
            return tryPartialAggregationPushdown(node).orElseGet(() -> visitPlan(node, context));
        }
    }
}