EntropyAggregation.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.aggregation;

import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.operator.aggregation.state.EntropyState;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.AggregationFunction;
import com.facebook.presto.spi.function.AggregationState;
import com.facebook.presto.spi.function.CombineFunction;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.InputFunction;
import com.facebook.presto.spi.function.OutputFunction;
import com.facebook.presto.spi.function.SqlType;

import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;

/**
 *  Calculates the log-2 entropy of count inputs.
 *
 *  Given counts $c_1, c_2, ..., cn$ ($c_i \geq 0)$, this aggregation calculates $\sum_i [ p_i log_2(-p_i) ]$,
 *      where $p_i = {c_i \over \sum_i [ c_i ]}$. If $\sum_i [ c_i ] = 0$, the entropy is defined to be 0.
 */
@AggregationFunction("entropy")
@Description("Takes non-negative count inputs, and computes the log-2 entropy of their fractions when normalized to sum to 1.")
public final class EntropyAggregation
{
    private EntropyAggregation() {}

    /**
     * @note If count is negative, the value of the aggregation will be null; if
     * count is 0, this is a no-op (since, in the context of entropy, 0 log(0) = 0; if count is null,
     * this is a no op.
     */
    @InputFunction
    public static void input(
            @AggregationState EntropyState state,
            @SqlType(StandardTypes.BIGINT) long count)
    {
        if (count < 0) {
            throw new PrestoException(
                    INVALID_FUNCTION_ARGUMENT,
                    "Entropy count argument must be non-negative");
        }

        if (count == 0) {
            return;
        }
        state.setSumC(state.getSumC() + count);
        state.setSumCLogC(state.getSumCLogC() + count * Math.log(count));
    }

    @CombineFunction
    public static void combine(@AggregationState EntropyState state, @AggregationState EntropyState otherState)
    {
        state.setSumC(state.getSumC() + otherState.getSumC());
        state.setSumCLogC(state.getSumCLogC() + otherState.getSumCLogC());
    }

    @OutputFunction(StandardTypes.DOUBLE)
    public static void output(@AggregationState EntropyState state, BlockBuilder out)
    {
        Double entropy = 0.0;
        if (state.getSumC() > 0.0) {
            entropy = Math.max(
                    (Math.log(state.getSumC()) - state.getSumCLogC() / state.getSumC()) / Math.log(2.0),
                    0.0);
        }

        DOUBLE.writeDouble(out, entropy);
    }
}