ArrayCumSum.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.scalar;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.RunLengthEncodedBlock;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.OperatorDependency;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.SqlType;
import com.facebook.presto.spi.function.TypeParameter;
import java.lang.invoke.MethodHandle;
import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.common.type.TypeUtils.readNativeValue;
import static com.facebook.presto.common.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.util.Failures.internalError;
@Description("Get the cumulative sum array for the input array")
@ScalarFunction(value = "array_cum_sum", deterministic = true)
public final class ArrayCumSum
{
    private ArrayCumSum() {}

    /**
     * Computes the cumulative (prefix) sum of the input array using the element type's ADD operator.
     * <p>
     * Position {@code i} of the result holds {@code arrayBlock[0] + ... + arrayBlock[i]}. If the first
     * element is null, every position of the result is null. Otherwise, once a null element is
     * encountered, that position and every later position of the result are null.
     *
     * @param addFunction the ADD operator for {@code T}, invoked with two native values of {@code T}
     * @param elementType the array element type {@code T}
     * @param arrayBlock the input array elements
     * @return a block with the same number of positions as {@code arrayBlock} holding the prefix sums
     */
    @TypeParameter("T")
    @SqlType("array(T)")
    public static Block sum(
            @OperatorDependency(operator = ADD, argumentTypes = {"T", "T"}) MethodHandle addFunction,
            @TypeParameter("T") Type elementType,
            @SqlType("array(T)") Block arrayBlock)
    {
        int positionCount = arrayBlock.getPositionCount();
        BlockBuilder resultBuilder = elementType.createBlockBuilder(null, positionCount);
        if (positionCount == 0) {
            return resultBuilder.build();
        }
        if (arrayBlock.isNull(0)) {
            // A null first element makes every prefix sum null.
            return RunLengthEncodedBlock.create(elementType, null, positionCount);
        }

        // Carry the running total in a local instead of reading the previous partial sum
        // back out of the result builder on every step.
        Object runningSum = readNativeValue(elementType, arrayBlock, 0);
        elementType.appendTo(arrayBlock, 0, resultBuilder);

        // Hoisted so blocks that report no nulls skip the per-position isNull check.
        boolean checkNulls = arrayBlock.mayHaveNull();
        int position = 1;
        for (; position < positionCount; position++) {
            if (checkNulls && arrayBlock.isNull(position)) {
                break;
            }
            runningSum = add(addFunction, runningSum, readNativeValue(elementType, arrayBlock, position));
            writeNativeValue(elementType, resultBuilder, runningSum);
        }
        // A null element nulls its own prefix sum and every one after it.
        for (; position < positionCount; position++) {
            resultBuilder.appendNull();
        }
        return resultBuilder.build();
    }

    /**
     * Applies the ADD operator to two native values.
     * Any throwable raised by the operator is passed to {@code internalError} and the result is thrown.
     */
    private static Object add(MethodHandle addFunction, Object left, Object right)
    {
        try {
            return addFunction.invoke(left, right);
        }
        catch (Throwable throwable) {
            throw internalError(throwable);
        }
    }
}