TestStateCompiler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation;
import com.facebook.presto.bytecode.DynamicClassLoader;
import com.facebook.presto.common.array.BlockBigArray;
import com.facebook.presto.common.array.BooleanBigArray;
import com.facebook.presto.common.array.ByteBigArray;
import com.facebook.presto.common.array.DoubleBigArray;
import com.facebook.presto.common.array.IntBigArray;
import com.facebook.presto.common.array.LongBigArray;
import com.facebook.presto.common.array.ReferenceCountMap;
import com.facebook.presto.common.array.SliceBigArray;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.aggregation.state.LongState;
import com.facebook.presto.operator.aggregation.state.NullableLongState;
import com.facebook.presto.operator.aggregation.state.StateCompiler;
import com.facebook.presto.operator.aggregation.state.VarianceState;
import com.facebook.presto.spi.function.AccumulatorState;
import com.facebook.presto.spi.function.AccumulatorStateFactory;
import com.facebook.presto.spi.function.AccumulatorStateMetadata;
import com.facebook.presto.spi.function.AccumulatorStateSerializer;
import com.facebook.presto.spi.function.GroupedAccumulatorState;
import com.facebook.presto.spi.function.TypeParameter;
import com.facebook.presto.util.Reflection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slice;
import org.openjdk.jol.info.ClassLayout;
import org.testng.annotations.Test;
import java.lang.invoke.MethodHandle;
import java.lang.reflect.Field;
import java.util.Map;
import static com.facebook.presto.block.BlockAssertions.createLongsBlock;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.util.StructuralTestUtil.mapBlockOf;
import static com.facebook.presto.util.StructuralTestUtil.mapType;
import static io.airlift.slice.Slices.utf8Slice;
import static io.airlift.slice.Slices.wrappedDoubleArray;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;
/**
 * Tests for the accumulator state factories and serializers generated by {@link StateCompiler}.
 */
public class TestStateCompiler
{
    @Test
    public void testPrimitiveNullableLongSerialization()
    {
        AccumulatorStateFactory<NullableLongState> factory = StateCompiler.generateStateFactory(NullableLongState.class);
        AccumulatorStateSerializer<NullableLongState> serializer = StateCompiler.generateStateSerializer(NullableLongState.class);
        NullableLongState state = factory.createSingleState();
        NullableLongState deserializedState = factory.createSingleState();

        // position 0: a non-null value
        state.setLong(2);
        state.setNull(false);

        BlockBuilder builder = BIGINT.createBlockBuilder(null, 2);
        serializer.serialize(state, builder);

        // position 1: the same state marked null, which must serialize as a null position
        state.setNull(true);
        serializer.serialize(state, builder);

        Block block = builder.build();

        assertEquals(block.isNull(0), false);
        assertEquals(BIGINT.getLong(block, 0), state.getLong());
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getLong(), state.getLong());

        assertEquals(block.isNull(1), true);
    }

    @Test
    public void testPrimitiveLongSerialization()
    {
        AccumulatorStateFactory<LongState> factory = StateCompiler.generateStateFactory(LongState.class);
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        LongState state = factory.createSingleState();
        LongState deserializedState = factory.createSingleState();

        state.setLong(2);

        BlockBuilder builder = BIGINT.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        Block block = builder.build();

        assertEquals(BIGINT.getLong(block, 0), state.getLong());
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getLong(), state.getLong());
    }

    @Test
    public void testGetSerializedType()
    {
        AccumulatorStateSerializer<LongState> serializer = StateCompiler.generateStateSerializer(LongState.class);
        assertEquals(serializer.getSerializedType(), BIGINT);
    }

    @Test
    public void testPrimitiveBooleanSerialization()
    {
        AccumulatorStateFactory<BooleanState> factory = StateCompiler.generateStateFactory(BooleanState.class);
        AccumulatorStateSerializer<BooleanState> serializer = StateCompiler.generateStateSerializer(BooleanState.class);
        BooleanState state = factory.createSingleState();
        BooleanState deserializedState = factory.createSingleState();

        state.setBoolean(true);

        BlockBuilder builder = BOOLEAN.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        Block block = builder.build();

        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.isBoolean(), state.isBoolean());
    }

    @Test
    public void testPrimitiveByteSerialization()
    {
        AccumulatorStateFactory<ByteState> factory = StateCompiler.generateStateFactory(ByteState.class);
        AccumulatorStateSerializer<ByteState> serializer = StateCompiler.generateStateSerializer(ByteState.class);
        ByteState state = factory.createSingleState();
        ByteState deserializedState = factory.createSingleState();

        state.setByte((byte) 3);

        BlockBuilder builder = TINYINT.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        Block block = builder.build();

        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getByte(), state.getByte());
    }

    @Test
    public void testNonPrimitiveSerialization()
    {
        AccumulatorStateFactory<SliceState> factory = StateCompiler.generateStateFactory(SliceState.class);
        AccumulatorStateSerializer<SliceState> serializer = StateCompiler.generateStateSerializer(SliceState.class);
        SliceState state = factory.createSingleState();
        SliceState deserializedState = factory.createSingleState();

        // a null slice must round-trip
        state.setSlice(null);
        BlockBuilder nullBlockBuilder = VARCHAR.createBlockBuilder(null, 1);
        serializer.serialize(state, nullBlockBuilder);
        Block nullBlock = nullBlockBuilder.build();
        serializer.deserialize(nullBlock, 0, deserializedState);
        assertEquals(deserializedState.getSlice(), state.getSlice());

        // and so must a non-null slice
        state.setSlice(utf8Slice("test"));
        BlockBuilder builder = VARCHAR.createBlockBuilder(null, 1);
        serializer.serialize(state, builder);
        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);
        assertEquals(deserializedState.getSlice(), state.getSlice());
    }

    @Test
    public void testVarianceStateSerialization()
    {
        AccumulatorStateFactory<VarianceState> factory = StateCompiler.generateStateFactory(VarianceState.class);
        AccumulatorStateSerializer<VarianceState> serializer = StateCompiler.generateStateSerializer(VarianceState.class);
        VarianceState singleState = factory.createSingleState();
        VarianceState deserializedState = factory.createSingleState();

        singleState.setMean(1);
        singleState.setCount(2);
        singleState.setM2(3);

        BlockBuilder builder = RowType.anonymous(ImmutableList.of(BIGINT, DOUBLE, DOUBLE)).createBlockBuilder(null, 1);
        serializer.serialize(singleState, builder);
        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);

        assertEquals(deserializedState.getCount(), singleState.getCount());
        assertEquals(deserializedState.getMean(), singleState.getMean());
        assertEquals(deserializedState.getM2(), singleState.getM2());
    }

    @Test
    public void testComplexSerialization()
    {
        Type arrayType = new ArrayType(BIGINT);
        Type mapType = mapType(BIGINT, VARCHAR);
        Map<String, Type> fieldMap = ImmutableMap.of("Block", arrayType, "AnotherBlock", mapType);
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldMap, new DynamicClassLoader(TestComplexState.class.getClassLoader()));
        AccumulatorStateSerializer<TestComplexState> serializer = StateCompiler.generateStateSerializer(TestComplexState.class, fieldMap, new DynamicClassLoader(TestComplexState.class.getClassLoader()));
        TestComplexState singleState = factory.createSingleState();
        TestComplexState deserializedState = factory.createSingleState();

        singleState.setBoolean(true);
        singleState.setLong(1);
        singleState.setDouble(2.0);
        singleState.setByte((byte) 3);
        singleState.setInt(4);
        singleState.setSlice(utf8Slice("test"));
        singleState.setAnotherSlice(wrappedDoubleArray(1.0, 2.0, 3.0));
        singleState.setYetAnotherSlice(null);
        Block array = createLongsBlock(45);
        singleState.setBlock(array);
        singleState.setAnotherBlock(mapBlockOf(BIGINT, VARCHAR, ImmutableMap.of(123L, "testBlock")));

        BlockBuilder builder = RowType.anonymous(ImmutableList.of(BOOLEAN, TINYINT, DOUBLE, INTEGER, BIGINT, mapType, VARBINARY, arrayType, VARBINARY, VARBINARY))
                .createBlockBuilder(null, 1);
        serializer.serialize(singleState, builder);
        Block block = builder.build();
        serializer.deserialize(block, 0, deserializedState);

        assertEquals(deserializedState.getBoolean(), singleState.getBoolean());
        assertEquals(deserializedState.getLong(), singleState.getLong());
        assertEquals(deserializedState.getDouble(), singleState.getDouble());
        assertEquals(deserializedState.getByte(), singleState.getByte());
        assertEquals(deserializedState.getInt(), singleState.getInt());
        assertEquals(deserializedState.getSlice(), singleState.getSlice());
        assertEquals(deserializedState.getAnotherSlice(), singleState.getAnotherSlice());
        assertEquals(deserializedState.getYetAnotherSlice(), singleState.getYetAnotherSlice());
        assertEquals(deserializedState.getBlock().getLong(0), singleState.getBlock().getLong(0));
        // the map block stores key 123 at position 0 and the 9-byte value "testBlock" at position 1
        assertEquals(deserializedState.getAnotherBlock().getLong(0), singleState.getAnotherBlock().getLong(0));
        assertEquals(deserializedState.getAnotherBlock().getSlice(1, 0, 9), singleState.getAnotherBlock().getSlice(1, 0, 9));
    }

    /**
     * Computes the retained size a generated state is expected to report: the JOL instance size of the
     * generated class plus {@code sizeOf()} of each big-array field that class declares.
     */
    private static long getComplexStateRetainedSize(TestComplexState state)
    {
        long retainedSize = ClassLayout.parseClass(state.getClass()).instanceSize();
        // reflection is necessary because TestComplexState implementation is generated
        Field[] fields = state.getClass().getDeclaredFields();
        try {
            for (Field field : fields) {
                Class<?> type = field.getType();
                field.setAccessible(true);
                if (type == BlockBigArray.class || type == BooleanBigArray.class || type == SliceBigArray.class ||
                        type == ByteBigArray.class || type == DoubleBigArray.class || type == LongBigArray.class || type == IntBigArray.class) {
                    MethodHandle sizeOf = Reflection.methodHandle(type, "sizeOf");
                    retainedSize += (long) sizeOf.invokeWithArguments(field.get(state));
                }
            }
        }
        catch (Throwable t) {
            // MethodHandle.invokeWithArguments declares Throwable
            throw new RuntimeException(t);
        }
        return retainedSize;
    }

    /**
     * Sums {@code sizeOf()} of every {@link ReferenceCountMap} field held inside the {@link BlockBigArray}
     * and {@link SliceBigArray} fields of a generated state.
     */
    private static long getReferenceCountMapOverhead(TestComplexState state)
    {
        long overhead = 0;
        // reflection is necessary because TestComplexState implementation is generated
        Field[] stateFields = state.getClass().getDeclaredFields();
        try {
            for (Field stateField : stateFields) {
                if (stateField.getType() != BlockBigArray.class && stateField.getType() != SliceBigArray.class) {
                    continue;
                }
                stateField.setAccessible(true);
                Field[] bigArrayFields = stateField.getType().getDeclaredFields();
                for (Field bigArrayField : bigArrayFields) {
                    if (bigArrayField.getType() != ReferenceCountMap.class) {
                        continue;
                    }
                    bigArrayField.setAccessible(true);
                    MethodHandle sizeOf = Reflection.methodHandle(bigArrayField.getType(), "sizeOf");
                    overhead += (long) sizeOf.invokeWithArguments(bigArrayField.get(stateField.get(state)));
                }
            }
        }
        catch (Throwable t) {
            // MethodHandle.invokeWithArguments declares Throwable
            throw new RuntimeException(t);
        }
        return overhead;
    }

    /**
     * Writes a fixed set of values into the current group of {@code state}.
     *
     * @param state the state to populate; for a grouped state the caller selects the group first
     * @param mapType the map type used to build the {@code AnotherBlock} value
     * @return the summed retained size of the Slice and Block values written (primitive values are not counted)
     */
    private static long populateComplexState(TestComplexState state, Type mapType)
    {
        state.setBoolean(true);
        state.setLong(1);
        state.setDouble(2.0);
        state.setByte((byte) 3);
        state.setInt(4);

        long retainedSize = 0;

        Slice slice = utf8Slice("test");
        retainedSize += slice.getRetainedSize();
        state.setSlice(slice);

        Slice anotherSlice = wrappedDoubleArray(1.0, 2.0, 3.0);
        retainedSize += anotherSlice.getRetainedSize();
        state.setAnotherSlice(anotherSlice);

        state.setYetAnotherSlice(null);

        Block array = createLongsBlock(45);
        retainedSize += array.getRetainedSizeInBytes();
        state.setBlock(array);

        BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(null, 1);
        BlockBuilder singleMapBlockWriter = mapBlockBuilder.beginBlockEntry();
        BIGINT.writeLong(singleMapBlockWriter, 123L);
        VARCHAR.writeSlice(singleMapBlockWriter, utf8Slice("testBlock"));
        mapBlockBuilder.closeEntry();
        Block map = mapBlockBuilder.build();
        retainedSize += map.getRetainedSizeInBytes();
        state.setAnotherBlock(map);

        return retainedSize;
    }

    // NOTE(review): this test runs 100 times and tolerates up to 10% failed invocations; the source of
    // nondeterminism is not evident from this file — confirm before tightening.
    @Test(invocationCount = 100, successPercentage = 90)
    public void testComplexStateEstimatedSize()
    {
        Type mapType = mapType(BIGINT, VARCHAR);
        Map<String, Type> fieldMap = ImmutableMap.of("Block", new ArrayType(BIGINT), "AnotherBlock", mapType);
        AccumulatorStateFactory<TestComplexState> factory = StateCompiler.generateStateFactory(TestComplexState.class, fieldMap, new DynamicClassLoader(TestComplexState.class.getClassLoader()));
        TestComplexState groupedState = factory.createGroupedState();
        long initialRetainedSize = getComplexStateRetainedSize(groupedState);
        assertEquals(groupedState.getEstimatedSize(), initialRetainedSize);
        // BlockBigArray or SliceBigArray has an internal map that can grow in size when getting more blocks
        // need to handle the map overhead separately
        initialRetainedSize -= getReferenceCountMapOverhead(groupedState);

        int groupCount = 1000;

        // first pass: every newly populated group adds its values' retained size to the estimate
        for (int i = 0; i < groupCount; i++) {
            ((GroupedAccumulatorState) groupedState).setGroupId(i);
            long retainedSize = populateComplexState(groupedState, mapType);
            assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * (i + 1) + getReferenceCountMapOverhead(groupedState));
        }

        // second pass: overwriting each group with equally sized values must keep the estimate stable
        for (int i = 0; i < groupCount; i++) {
            ((GroupedAccumulatorState) groupedState).setGroupId(i);
            long retainedSize = populateComplexState(groupedState, mapType);
            assertEquals(groupedState.getEstimatedSize(), initialRetainedSize + retainedSize * groupCount + getReferenceCountMapOverhead(groupedState));
        }
    }

    @Test
    public void testStateSerializerConstructorsWithMetadata()
    {
        Map<String, Type> fields = ImmutableMap.of("T", BIGINT, "E", VARCHAR);

        // valid constructor shapes
        Object stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerNoType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerNoType.class.getClassLoader()));
        assertTrue(stateSerializer instanceof TestAccumulatorSerializerNoType);
        stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerMultipleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerMultipleType.class.getClassLoader()));
        assertTrue(stateSerializer instanceof TestAccumulatorSerializerMultipleType);
        stateSerializer = StateCompiler.generateStateSerializer(TestAccumulatorSerializerSingleType.class, fields, new DynamicClassLoader(TestAccumulatorSerializerSingleType.class.getClassLoader()));
        assertTrue(stateSerializer instanceof TestAccumulatorSerializerSingleType);

        // invalid or ambiguous constructor shapes must be rejected
        assertThrows(() -> StateCompiler.generateStateSerializer(
                TestAccumulatorSerializerUntyped.class,
                fields,
                new DynamicClassLoader(TestAccumulatorSerializerUntyped.class.getClassLoader())));
        assertThrows(() -> StateCompiler.generateStateSerializer(
                TestAccumulatorAmbiguousConstructor.class,
                fields,
                new DynamicClassLoader(TestAccumulatorAmbiguousConstructor.class.getClassLoader())));
    }

    /**
     * State with one property of every kind the compiler supports: primitives, slices and blocks.
     */
    public interface TestComplexState
            extends AccumulatorState
    {
        double getDouble();

        void setDouble(double value);

        boolean getBoolean();

        void setBoolean(boolean value);

        long getLong();

        void setLong(long value);

        byte getByte();

        void setByte(byte value);

        int getInt();

        void setInt(int value);

        Slice getSlice();

        void setSlice(Slice slice);

        Slice getAnotherSlice();

        void setAnotherSlice(Slice slice);

        Slice getYetAnotherSlice();

        void setYetAnotherSlice(Slice slice);

        Block getBlock();

        void setBlock(Block block);

        Block getAnotherBlock();

        void setAnotherBlock(Block block);
    }

    public interface BooleanState
            extends AccumulatorState
    {
        boolean isBoolean();

        void setBoolean(boolean value);
    }

    public interface ByteState
            extends AccumulatorState
    {
        byte getByte();

        void setByte(byte value);
    }

    public interface SliceState
            extends AccumulatorState
    {
        Slice getSlice();

        void setSlice(Slice slice);
    }

    /**
     * No-op serializer base; subclasses only differ in their constructor signatures.
     */
    private abstract static class TestAccumulatorSerializer
            implements AccumulatorStateSerializer<Object>
    {
        @Override
        public Type getSerializedType()
        {
            return null;
        }

        @Override
        public void serialize(Object state, BlockBuilder out)
        {}

        @Override
        public void deserialize(Block block, int index, Object state)
        {}
    }

    @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerSingleType.class)
    public static class TestAccumulatorSerializerSingleType
            extends TestAccumulatorSerializer
    {
        public TestAccumulatorSerializerSingleType(@TypeParameter("E") Type first)
        {}
    }

    @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerMultipleType.class)
    public static class TestAccumulatorSerializerMultipleType
            extends TestAccumulatorSerializer
    {
        public TestAccumulatorSerializerMultipleType(@TypeParameter("E") Type first, @TypeParameter("T") Type second)
        {}
    }

    @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerNoType.class)
    public static class TestAccumulatorSerializerNoType
            extends TestAccumulatorSerializer
    {
        public TestAccumulatorSerializerNoType()
        {}
    }

    // both constructors could be satisfied by the passed fields, so neither should be chosen
    @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorAmbiguousConstructor.class)
    public static class TestAccumulatorAmbiguousConstructor
            extends TestAccumulatorSerializer
    {
        public TestAccumulatorAmbiguousConstructor()
        {}

        public TestAccumulatorAmbiguousConstructor(@TypeParameter("E") Type type)
        {}
    }

    // test all invalid constructor types
    @AccumulatorStateMetadata(stateSerializerClass = TestAccumulatorSerializerUntyped.class)
    public static class TestAccumulatorSerializerUntyped
            extends TestAccumulatorSerializer
    {
        // Type parameter without a @TypeParameter annotation
        public TestAccumulatorSerializerUntyped(Type x)
        {}

        // parameter that is not a Type at all
        public TestAccumulatorSerializerUntyped(int y)
        {}

        // type parameter G is not among the passed fields (and the parameter is declared as Object)
        public TestAccumulatorSerializerUntyped(@TypeParameter("G") Object y)
        {}

        // type parameter E is among the passed fields, but G is not
        public TestAccumulatorSerializerUntyped(@TypeParameter("E") Type x, @TypeParameter("G") Long y)
        {}
    }
}