PinotAggregationProjectConverter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.pinot.query;

import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.pinot.PinotException;
import com.facebook.presto.pinot.PinotSessionProperties;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slice;

import java.time.ZoneId;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.pinot.PinotErrorCode.PINOT_UNSUPPORTED_EXPRESSION;
import static com.facebook.presto.pinot.PinotPushdownUtils.getLiteralAsString;
import static com.facebook.presto.pinot.query.PinotExpression.derived;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;

public class PinotAggregationProjectConverter
        extends PinotProjectExpressionConverter
{
    private static final String FROM_UNIXTIME = "from_unixtime";

    private static final Map<String, String> PRESTO_TO_PINOT_ARRAY_AGGREGATIONS = ImmutableMap.<String, String>builder()
            .put("array_min", "arrayMin")
            .put("array_max", "arrayMax")
            .put("array_average", "arrayAverage")
            .put("array_sum", "arraySum")
            .build();

    private final VariableReferenceExpression arrayVariableHint;

    public PinotAggregationProjectConverter(TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution standardFunctionResolution, ConnectorSession session)
    {
        this(typeManager, functionMetadataManager, standardFunctionResolution, session, null);
    }

    public PinotAggregationProjectConverter(TypeManager typeManager, FunctionMetadataManager functionMetadataManager, StandardFunctionResolution standardFunctionResolution, ConnectorSession session, VariableReferenceExpression arrayVariableHint)
    {
        super(typeManager, functionMetadataManager, standardFunctionResolution, session);
        this.arrayVariableHint = arrayVariableHint;
    }

    @Override
    public PinotExpression visitCall(
            CallExpression call,
            Map<VariableReferenceExpression, PinotQueryGeneratorContext.Selection> context)
    {
        FunctionHandle functionHandle = call.getFunctionHandle();
        if (standardFunctionResolution.isCastFunction(functionHandle)) {
            return handleCast(call, context);
        }
        if (standardFunctionResolution.isNotFunction(functionHandle) || standardFunctionResolution.isBetweenFunction(functionHandle)) {
            throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), "Unsupported function in pinot aggregation: " + functionHandle);
        }

        FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(functionHandle);
        Optional<OperatorType> operatorTypeOptional = functionMetadata.getOperatorType();
        if (operatorTypeOptional.isPresent()) {
            OperatorType operatorType = operatorTypeOptional.get();
            if (operatorType.isArithmeticOperator()) {
                return handleArithmeticExpression(call, operatorType, context);
            }
            if (operatorType.isComparisonOperator()) {
                throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), "Comparison operator not supported: " + call);
            }
        }
        return handleFunction(call, context);
    }

    @Override
    public PinotExpression visitConstant(
            ConstantExpression literal,
            Map<VariableReferenceExpression, PinotQueryGeneratorContext.Selection> context)
    {
        return new PinotExpression(getLiteralAsString(literal), PinotQueryGeneratorContext.Origin.LITERAL);
    }

    private PinotExpression handleDateTruncationViaDateTimeConvert(
            CallExpression function,
            Map<VariableReferenceExpression, PinotQueryGeneratorContext.Selection> context)
    {
        // Convert SQL standard function `DATE_TRUNC(INTERVAL, DATE/TIMESTAMP COLUMN)` to
        // Pinot's equivalent function `dateTimeConvert(columnName, inputFormat, outputFormat, outputGranularity)`
        // Pinot doesn't have a DATE/TIMESTAMP type. That means the input column (second argument) has been converted from numeric type to DATE/TIMESTAMP using one of the
        // conversion functions in SQL. First step is find the function and find its input column units (seconds, secondsSinceEpoch etc.)
        RowExpression timeInputParameter = function.getArguments().get(1);
        String inputColumn;
        String inputFormat;

        CallExpression timeConversion = getExpressionAsFunction(timeInputParameter, timeInputParameter);
        switch (timeConversion.getDisplayName().toLowerCase(ENGLISH)) {
            case FROM_UNIXTIME:
                inputColumn = timeConversion.getArguments().get(0).accept(this, context).getDefinition();
                inputFormat = "'1:SECONDS:EPOCH'";
                break;
            default:
                throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), "not supported: " + timeConversion.getDisplayName());
        }

        String outputFormat = "'1:MILLISECONDS:EPOCH'";
        String outputGranularity;

        RowExpression intervalParameter = function.getArguments().get(0);
        if (!(intervalParameter instanceof ConstantExpression)) {
            throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(),
                    "interval unit in date_trunc is not supported: " + intervalParameter);
        }

        String value = getStringFromConstant(intervalParameter);
        switch (value) {
            case "second":
                outputGranularity = "'1:SECONDS'";
                break;
            case "minute":
                outputGranularity = "'1:MINUTES'";
                break;
            case "hour":
                outputGranularity = "'1:HOURS'";
                break;
            case "day":
                outputGranularity = "'1:DAYS'";
                break;
            case "week":
                outputGranularity = "'1:WEEKS'";
                break;
            case "month":
                outputGranularity = "'1:MONTHS'";
                break;
            case "quarter":
                outputGranularity = "'1:QUARTERS'";
                break;
            case "year":
                outputGranularity = "'1:YEARS'";
                break;
            default:
                throw new PinotException(
                        PINOT_UNSUPPORTED_EXPRESSION,
                        Optional.empty(),
                        "interval in date_trunc is not supported: " + value);
        }

        return derived("dateTimeConvert(" + inputColumn + ", " + inputFormat + ", " + outputFormat + ", " + outputGranularity + ")");
    }

    private PinotExpression handleDateTruncationViaDateTruncation(
            CallExpression function,
            Map<VariableReferenceExpression, PinotQueryGeneratorContext.Selection> context)
    {
        RowExpression timeInputParameter = function.getArguments().get(1);
        String inputColumn;
        String inputTimeZone;
        String inputFormat;

        CallExpression timeConversion = getExpressionAsFunction(timeInputParameter, timeInputParameter);
        switch (timeConversion.getDisplayName().toLowerCase(ENGLISH)) {
            case FROM_UNIXTIME:
                inputColumn = timeConversion.getArguments().get(0).accept(this, context).getDefinition();
                inputTimeZone = timeConversion.getArguments().size() > 1 ? getStringFromConstant(timeConversion.getArguments().get(1)) : ZoneId.of("UTC").getId();
                inputFormat = "seconds";
                break;
            default:
                throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), "not supported: " + timeConversion.getDisplayName());
        }

        RowExpression intervalParameter = function.getArguments().get(0);
        if (!(intervalParameter instanceof ConstantExpression)) {
            throw new PinotException(
                    PINOT_UNSUPPORTED_EXPRESSION,
                    Optional.empty(),
                    "interval unit in date_trunc is not supported: " + intervalParameter);
        }

        return derived("dateTrunc(" + inputColumn + "," + inputFormat + ", " + inputTimeZone + ", " + getStringFromConstant(intervalParameter) + ")");
    }

    private PinotExpression handleFunction(
            CallExpression function,
            Map<VariableReferenceExpression, PinotQueryGeneratorContext.Selection> context)
    {
        String functionName = function.getDisplayName().toLowerCase(ENGLISH);
        switch (functionName) {
            case "date_trunc":
                boolean useDateTruncation = PinotSessionProperties.isUseDateTruncation(session);
                return useDateTruncation ?
                        handleDateTruncationViaDateTruncation(function, context) :
                        handleDateTruncationViaDateTimeConvert(function, context);
            case "array_max":
            case "array_min":
                String pinotArrayFunctionName = PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.get(functionName);
                requireNonNull(pinotArrayFunctionName, "Converted Pinot array function is null for - " + functionName);
                return derived(String.format(
                        "%s(%s)",
                        pinotArrayFunctionName,
                        function.getArguments().get(0).accept(this, context).getDefinition()));
            // array_sum and array_reduce are translated to a reduce function with lambda functions, so we pass in
            // this arrayVariableHint to help determine which array function it is.
            case "reduce":
                if (arrayVariableHint != null) {
                    String arrayFunctionName = getArrayFunctionName(arrayVariableHint);
                    if (arrayFunctionName != null) {
                        String inputColumn = function.getArguments().get(0).accept(this, context).getDefinition();
                        return derived(String.format("%s(%s)", arrayFunctionName, inputColumn));
                    }
                }
            default:
                throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), format("function %s not supported yet", function.getDisplayName()));
        }
    }

    // The array function variable names are in the format of `array_sum`, `array_average_0`, `array_sum_1`.
    // So we can parse the array function name based on variable name.
    private String getArrayFunctionName(VariableReferenceExpression variable)
    {
        String[] variableNameSplits = variable.getName().split("_");
        if (variableNameSplits.length < 2 || variableNameSplits.length > 3) {
            return null;
        }
        String arrayFunctionName = String.format("%s_%s", variableNameSplits[0], variableNameSplits[1]);
        return PRESTO_TO_PINOT_ARRAY_AGGREGATIONS.get(arrayFunctionName);
    }

    private static String getStringFromConstant(RowExpression expression)
    {
        if (expression instanceof ConstantExpression) {
            Object value = ((ConstantExpression) expression).getValue();
            if (value instanceof String) {
                return (String) value;
            }
            if (value instanceof Slice) {
                return ((Slice) value).toStringUtf8();
            }
        }
        throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), "Expected string literal but found " + expression);
    }

    private CallExpression getExpressionAsFunction(
            RowExpression originalExpression,
            RowExpression expression)
    {
        if (expression instanceof CallExpression) {
            CallExpression call = (CallExpression) expression;
            if (standardFunctionResolution.isCastFunction(call.getFunctionHandle())) {
                if (isImplicitCast(call.getArguments().get(0).getType(), call.getType())) {
                    return getExpressionAsFunction(originalExpression, call.getArguments().get(0));
                }
            }
            else {
                return call;
            }
        }
        throw new PinotException(PINOT_UNSUPPORTED_EXPRESSION, Optional.empty(), "Could not dig function out of expression: " + originalExpression + ", inside of " + expression);
    }
}