NoisyCountGaussianColumnAggregation.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.noisyaggregation;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.BoundVariables;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.SqlAggregationFunction;
import com.facebook.presto.operator.aggregation.AccumulatorCompiler;
import com.facebook.presto.operator.aggregation.BuiltInAggregationFunctionImplementation;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.aggregation.Accumulator;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata;
import com.facebook.presto.spi.function.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import com.facebook.presto.spi.function.aggregation.GroupedAccumulator;
import com.google.common.collect.ImmutableList;
import java.lang.invoke.MethodHandle;
import java.util.List;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
import static com.facebook.presto.operator.aggregation.AggregationUtils.generateAggregationName;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAggregationUtils.combineStates;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAggregationUtils.updateState;
import static com.facebook.presto.operator.aggregation.noisyaggregation.NoisyCountAggregationUtils.writeNoisyCountOutput;
import static com.facebook.presto.spi.function.Signature.typeVariable;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INDEX;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.BLOCK_INPUT_CHANNEL;
import static com.facebook.presto.spi.function.aggregation.AggregationMetadata.ParameterMetadata.ParameterType.STATE;
import static com.facebook.presto.util.Reflection.methodHandle;
import static com.google.common.collect.ImmutableList.toImmutableList;
/**
* Adds random Gaussian noise to the true count, where the standard deviation of the noise is given by noiseScale.
* To replace COUNT(*), use NOISY_COUNT_GAUSSIAN(1, noiseScale).
* <p>
* This function behaves similarly to COUNT:
* on empty input it returns 0 if there is no grouping,
* and returns NULL if there is grouping.
* <p>
* Optional randomSeed is used to get a fixed value of noise, often for reproducibility purposes.
* If randomSeed is omitted or 0, SecureRandom is used. If randomSeed > 0 is provided, Random is used.
* <p>
* Function signature is NOISY_COUNT_GAUSSIAN(x, noiseScale[, randomSeed])
* - x: input column/value
* - noiseScale: standard deviation of noise
* - randomSeed: (optional) random seed
*/
public class NoisyCountGaussianColumnAggregation
        extends SqlAggregationFunction
{
    /** Shared, pre-built instance of this aggregation function. */
    public static final NoisyCountGaussianColumnAggregation NOISY_COUNT_GAUSSIAN_AGGREGATION = new NoisyCountGaussianColumnAggregation();

    /** SQL-visible function name. */
    private static final String NAME = "noisy_count_gaussian";

    // Method handles onto the static input / combine / output methods declared at the bottom of this class.
    // They are handed to AggregationMetadata so the compiled accumulators can invoke them.
    private static final MethodHandle INPUT_FUNCTION = methodHandle(
            NoisyCountGaussianColumnAggregation.class,
            "input",
            NoisyCountState.class,
            Block.class,
            Block.class,
            int.class);
    private static final MethodHandle COMBINE_FUNCTION = methodHandle(
            NoisyCountGaussianColumnAggregation.class,
            "combine",
            NoisyCountState.class,
            NoisyCountState.class);
    private static final MethodHandle OUTPUT_FUNCTION = methodHandle(
            NoisyCountGaussianColumnAggregation.class,
            "output",
            NoisyCountState.class,
            BlockBuilder.class);

    /**
     * Registers the function under {@link #NAME} with a single type variable {@code T}.
     * The bigint signature and the {@code (T, double)} signatures are passed to the superclass —
     * presumably as the return type and the argument types, which matches the types used in
     * {@link #generateAggregation(Type)}.
     */
    public NoisyCountGaussianColumnAggregation()
    {
        super(
                NAME,
                ImmutableList.of(typeVariable("T")),
                ImmutableList.of(),
                parseTypeSignature(StandardTypes.BIGINT),
                ImmutableList.of(parseTypeSignature("T"), DOUBLE.getTypeSignature()));
    }

    /**
     * @return the human-readable description of this function
     */
    @Override
    public String getDescription()
    {
        return "Counts the non-null values and then add Gaussian noise to the true count. The noisy count is post-processed to be non-negative and rounded to bigint. Noise is from a secure random.";
    }

    /**
     * Builds the concrete implementation for the type bound to {@code T}.
     * Neither {@code arity} nor {@code functionAndTypeManager} is consulted.
     *
     * @param boundVariables supplies the concrete type bound to the type variable {@code T}
     * @param arity unused
     * @param functionAndTypeManager unused
     * @return the implementation specialized for the bound value type
     */
    @Override
    public BuiltInAggregationFunctionImplementation specialize(BoundVariables boundVariables, int arity, FunctionAndTypeManager functionAndTypeManager)
    {
        return generateAggregation(boundVariables.getTypeVariable("T"));
    }

    /**
     * Compiles the state serializer/factory and both accumulator classes for the given value type
     * and wraps them in a {@link BuiltInAggregationFunctionImplementation}.
     *
     * @param type concrete type of the first (value) argument
     * @return the compiled aggregation implementation producing BIGINT
     */
    private static BuiltInAggregationFunctionImplementation generateAggregation(Type type)
    {
        // All generated classes for this specialization live in one dedicated class loader.
        DynamicClassLoader classLoader = new DynamicClassLoader(NoisyCountGaussianColumnAggregation.class.getClassLoader());

        AccumulatorStateSerializer<NoisyCountState> stateSerializer = StateCompiler.generateStateSerializer(NoisyCountState.class, classLoader);
        AccumulatorStateFactory<NoisyCountState> stateFactory = StateCompiler.generateStateFactory(NoisyCountState.class, classLoader);
        AccumulatorStateDescriptor stateDescriptor = new AccumulatorStateDescriptor(
                NoisyCountState.class,
                stateSerializer,
                stateFactory);

        // Arguments are (value of type T, noiseScale as DOUBLE).
        List<Type> inputTypes = ImmutableList.of(type, DOUBLE);
        String aggregationName = generateAggregationName(
                NAME,
                BIGINT.getTypeSignature(),
                inputTypes.stream()
                        .map(Type::getTypeSignature)
                        .collect(toImmutableList()));

        AggregationMetadata metadata = new AggregationMetadata(
                aggregationName,
                createInputParameterMetadata(type),
                INPUT_FUNCTION,
                COMBINE_FUNCTION,
                OUTPUT_FUNCTION,
                ImmutableList.of(stateDescriptor),
                BIGINT);

        // The same metadata drives both the non-grouped and the grouped accumulator.
        Class<? extends Accumulator> accumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
                Accumulator.class,
                metadata,
                classLoader);
        Class<? extends GroupedAccumulator> groupedAccumulatorClass = AccumulatorCompiler.generateAccumulatorClass(
                GroupedAccumulator.class,
                metadata,
                classLoader);

        Type intermediateType = stateSerializer.getSerializedType();
        // NOTE(review): the two boolean flags are presumably "decomposable = true" and
        // "order sensitive = false" — confirm against the BuiltInAggregationFunctionImplementation constructor.
        return new BuiltInAggregationFunctionImplementation(
                NAME,
                inputTypes,
                ImmutableList.of(intermediateType),
                BIGINT,
                true,
                false,
                metadata,
                accumulatorClass,
                groupedAccumulatorClass);
    }

    /**
     * Describes the parameters of {@link #input}: the state, the value block, the noise-scale block
     * and the block position, in that order.
     *
     * @param type concrete type of the value argument
     * @return parameter metadata matching the signature of {@link #input}
     */
    private static List<ParameterMetadata> createInputParameterMetadata(Type type)
    {
        return ImmutableList.<ParameterMetadata>builder()
                .add(new ParameterMetadata(STATE))
                .add(new ParameterMetadata(BLOCK_INPUT_CHANNEL, type))
                .add(new ParameterMetadata(BLOCK_INPUT_CHANNEL, DOUBLE))
                .add(new ParameterMetadata(BLOCK_INDEX))
                .build();
    }

    /**
     * Per-row input step: reads the noise scale at {@code index} and forwards it to
     * {@code NoisyCountAggregationUtils.updateState}.
     * <p>
     * {@code valueBlock} is never read here. Presumably the generated accumulator skips positions
     * whose value is null before calling this method, since the channel is declared as
     * {@code BLOCK_INPUT_CHANNEL} — TODO confirm against AccumulatorCompiler.
     *
     * @param state aggregation state to update
     * @param valueBlock block holding the counted value (unused)
     * @param noiseScaleBlock block holding the noise scale as DOUBLE
     * @param index position within the blocks
     */
    public static void input(NoisyCountState state, Block valueBlock, Block noiseScaleBlock, int index)
    {
        // The third argument is null: this variant carries no random seed
        // (presumably null selects the unseeded path — verify in NoisyCountAggregationUtils).
        updateState(state, DOUBLE.getDouble(noiseScaleBlock, index), null);
    }

    /**
     * Merges {@code otherState} into {@code state} by delegating to
     * {@code NoisyCountAggregationUtils.combineStates}.
     *
     * @param state state that receives the merge
     * @param otherState state to merge in
     */
    public static void combine(NoisyCountState state, NoisyCountState otherState)
    {
        combineStates(state, otherState);
    }

    /**
     * Writes the final result for {@code state} to {@code out} by delegating to
     * {@code NoisyCountAggregationUtils.writeNoisyCountOutput}.
     *
     * @param state final aggregation state
     * @param out builder receiving the output value
     */
    public static void output(NoisyCountState state, BlockBuilder out)
    {
        writeNoisyCountOutput(state, out);
    }
}