ArrayNormalizeFunction.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.DoubleType;
import com.facebook.presto.common.type.RealType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.spi.function.TypeParameterSpecialization;

import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.util.Failures.checkCondition;
import static java.lang.String.format;

@ScalarFunction("array_normalize")
@Description("Normalizes an array by dividing each element by the p-norm of the array.")
public final class ArrayNormalizeFunction
{
    private static final ValueAccessor DOUBLE_VALUE_ACCESSOR = new DoubleValueAccessor();
    private static final ValueAccessor REAL_VALUE_ACCESSOR = new RealValueAccessor();

    private ArrayNormalizeFunction() {}

    @TypeParameter("T")
    @TypeParameterSpecialization(name = "T", nativeContainerType = double.class)
    @SqlType("array(T)")
    @SqlNullable
    public static Block normalizeDoubleArray(
            @TypeParameter("T") Type elementType,
            @SqlType("array(T)") Block block,
            @SqlType("T") double p)
    {
        return normalizeArray(elementType, block, p, DOUBLE_VALUE_ACCESSOR);
    }

    @TypeParameter("T")
    @TypeParameterSpecialization(name = "T", nativeContainerType = long.class)
    @SqlType("array(T)")
    @SqlNullable
    public static Block normalizeRealArray(
            @TypeParameter("T") Type elementType,
            @SqlType("array(T)") Block block,
            @SqlType("T") long p)
    {
        return normalizeArray(elementType, block, Float.intBitsToFloat((int) p), REAL_VALUE_ACCESSOR);
    }

    private static Block normalizeArray(Type elementType, Block block, double p, ValueAccessor valueAccessor)
    {
        if (!(elementType instanceof RealType) && !(elementType instanceof DoubleType)) {
            throw new PrestoException(
                    FUNCTION_IMPLEMENTATION_MISSING,
                    format("Unsupported array element type for array_normalize function: %s", elementType.getDisplayName()));
        }
        checkCondition(p >= 0, INVALID_FUNCTION_ARGUMENT, "array_normalize only supports non-negative p: %s", p);

        if (p == 0) {
            return block;
        }

        int elementCount = block.getPositionCount();
        double pNorm = 0;
        for (int i = 0; i < elementCount; i++) {
            if (block.isNull(i)) {
                return null;
            }
            pNorm += Math.pow(Math.abs(valueAccessor.getValue(elementType, block, i)), p);
        }
        if (pNorm == 0) {
            return block;
        }
        pNorm = Math.pow(pNorm, 1.0 / p);
        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, elementCount);
        for (int i = 0; i < elementCount; i++) {
            valueAccessor.writeValue(elementType, blockBuilder, valueAccessor.getValue(elementType, block, i) / pNorm);
        }
        return blockBuilder.build();
    }

    private interface ValueAccessor
    {
        double getValue(Type elementType, Block block, int position);

        void writeValue(Type elementType, BlockBuilder blockBuilder, double value);
    }

    private static class DoubleValueAccessor
            implements ValueAccessor
    {
        @Override
        public double getValue(Type elementType, Block block, int position)
        {
            return elementType.getDouble(block, position);
        }

        @Override
        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
        {
            elementType.writeDouble(blockBuilder, value);
        }
    }

    private static class RealValueAccessor
            implements ValueAccessor
    {
        @Override
        public double getValue(Type elementType, Block block, int position)
        {
            return Float.intBitsToFloat((int) elementType.getLong(block, position));
        }

        @Override
        public void writeValue(Type elementType, BlockBuilder blockBuilder, double value)
        {
            elementType.writeLong(blockBuilder, Float.floatToIntBits((float) value));
        }
    }
}