ReduceAggregationFunction.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.state.ReduceAggregationState;
import com.facebook.presto.operator.aggregation.state.ReduceAggregationStateFactory;
import com.facebook.presto.operator.aggregation.state.ReduceAggregationStateSerializer;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.facebook.presto.sql.gen.lambda.BinaryFunctionInterface;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandle;
import java.util.List;

import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.spi.StandardErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.INPUT_CHANNEL;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.util.Reflection.methodHandle;
import static java.lang.String.format;

public class ReduceAggregationFunction
        extends SqlAggregationFunction
{
    private static final String NAME = "reduce_agg";

    private static final MethodHandle INPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "input", Type.class, ReduceAggregationState.class, Object.class, Object.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle COMBINE_FUNCTION = methodHandle(ReduceAggregationFunction.class, "combine", ReduceAggregationState.class, ReduceAggregationState.class, BinaryFunctionInterface.class, BinaryFunctionInterface.class);
    private static final MethodHandle OUTPUT_FUNCTION = methodHandle(ReduceAggregationFunction.class, "write", Type.class, ReduceAggregationState.class, BlockBuilder.class);

    private final boolean supportsComplexTypes;

    public ReduceAggregationFunction(boolean supportsComplexTypes)
    {
        super(NAME,
                ImmutableList.of(typeVariable("T"), typeVariable("S")),
                ImmutableList.of(),
                parseTypeSignature("S"),
                ImmutableList.of(
                        parseTypeSignature("T"),
                        parseTypeSignature("S"),
                        parseTypeSignature("function(S,T,S)"),
                        parseTypeSignature("function(S,S,S)")));

        this.supportsComplexTypes = supportsComplexTypes;
    }

    @Override
    public boolean isDeterministic()
    {
        return false;
    }

    @Override
    public String getDescription()
    {
        return "Reduce input elements into a single value";
    }

    @Override
    public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
    {
        Type inputType = boundVariables.getTypeVariable("T");
        Type stateType = boundVariables.getTypeVariable("S");
        return generateAggregation(inputType, stateType);
    }

    private BuiltInAggregationFunctionImplementation generateAggregation(Type inputType, Type stateType)
    {
        DynamicClassLoader classLoader = new DynamicClassLoader(ReduceAggregationFunction.class.getClassLoader());

        MethodHandle inputMethodHandle;
        MethodHandle combineMethodHandle;
        MethodHandle outputMethodHandle;
        AccumulatorStateDescriptor stateDescriptor;

        if (!supportsComplexTypes && !(stateType.getJavaType() == long.class || stateType.getJavaType() == double.class || stateType.getJavaType() == boolean.class)) {
            // For large heap, State with Slice or Block may result in excessive JVM memory usage of remembered set.
            // See JDK-8017163.
            throw new PrestoException(NOT_SUPPORTED, format("State type not enabled for %s: %s", NAME, stateType.getDisplayName()));
        }

        inputMethodHandle = INPUT_FUNCTION.bindTo(inputType);
        combineMethodHandle = COMBINE_FUNCTION;
        outputMethodHandle = OUTPUT_FUNCTION.bindTo(stateType);

        stateDescriptor = new AccumulatorStateDescriptor(
                ReduceAggregationState.class,
                new ReduceAggregationStateSerializer(stateType),
                new ReduceAggregationStateFactory());

        AggregationMetadata metadata = new AggregationMetadata(
                generateAggregationName(getSignature().getNameSuffix(), inputType.getTypeSignature(), ImmutableList.of(inputType.getTypeSignature())),
                createInputParameterMetadata(inputType, stateType),
                inputMethodHandle.asType(
                        inputMethodHandle.type()
                                .changeParameterType(1, inputType.getJavaType())
                                .changeParameterType(2, stateType.getJavaType())),
                combineMethodHandle,
                outputMethodHandle,
                ImmutableList.of(stateDescriptor),
                stateType,
                ImmutableList.of(BinaryFunctionInterface.class, BinaryFunctionInterface.class));

        Class<? extends Accumulator> accumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
                Accumulator.class,
                metadata,
                classLoader);
        Class<? extends GroupedAccumulator> groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
                GroupedAccumulator.class,
                metadata,
                classLoader);
        return new BuiltInAggregationFunctionImplementation(
                getSignature().getNameSuffix(),
                ImmutableList.of(inputType),
                ImmutableList.of(stateType),
                stateType,
                true,
                false,
                metadata,
                accumulatorClass,
                groupedAccumulatorClass,
                ImmutableList.of(BinaryFunctionInterface.class, BinaryFunctionInterface.class));
    }

    private static List<ParameterMetadata> createInputParameterMetadata(Type inputType, Type stateType)
    {
        return ImmutableList.of(
                new ParameterMetadata(STATE),
                new ParameterMetadata(INPUT_CHANNEL, inputType),
                new ParameterMetadata(INPUT_CHANNEL, stateType));
    }

    public static void input(Type type, ReduceAggregationState state, Object value, Object initialStateValue, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
    {
        if (state.getValue() == null) {
            state.setValue(initialStateValue);
        }
        try {
            state.setValue(inputFunction.apply(state.getValue(), value));
        }
        catch (NullPointerException npe) {
            state.setValue(null);
        }
    }

    public static void combine(ReduceAggregationState state, ReduceAggregationState otherState, BinaryFunctionInterface inputFunction, BinaryFunctionInterface combineFunction)
    {
        if (state.getValue() == null) {
            state.setValue(otherState.getValue());
            return;
        }
        try {
            state.setValue(combineFunction.apply(state.getValue(), otherState.getValue()));
        }
        catch (NullPointerException npe) {
            state.setValue(null);
        }
    }

    public static void write(Type type, ReduceAggregationState state, BlockBuilder blockBuilder)
    {
        if (state.getValue() == null) {
            blockBuilder.appendNull();
        }
        else if (type.getJavaType() == long.class) {
            type.writeLong(blockBuilder, (long) state.getValue());
        }
        else if (type.getJavaType() == double.class) {
            type.writeDouble(blockBuilder, (double) state.getValue());
        }
        else if (type.getJavaType() == boolean.class) {
            type.writeBoolean(blockBuilder, (boolean) state.getValue());
        }
        else if (type.getJavaType() == Block.class) {
            type.writeObject(blockBuilder, state.getValue());
        }
        else {
            type.writeSlice(blockBuilder, (Slice) state.getValue());
        }
    }
}