TestPageFunctionCompiler.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.Work;
import com.facebook.presto.operator.project.PageFilter;
import com.facebook.presto.operator.project.PageProjection;
import com.facebook.presto.operator.project.PageProjectionWithOutputs;
import com.facebook.presto.operator.project.SelectedPositions;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.IntStream;
import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.field;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotSame;
import static org.testng.Assert.assertSame;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
public class TestPageFunctionCompiler
{
    private static final FunctionAndTypeManager FUNCTION_MANAGER = createTestMetadataManager().getFunctionAndTypeManager();

    // x + 10
    private static final CallExpression ADD_10_EXPRESSION = call(
            ADD.name(),
            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
            BIGINT,
            field(0, BIGINT),
            constant(10L, BIGINT));

    // x + y
    private static final CallExpression ADD_X_Y = call(
            ADD.name(),
            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
            BIGINT,
            field(0, BIGINT),
            field(1, BIGINT));

    // (x + y) > 2
    private static final CallExpression ADD_X_Y_GREATER_THAN_2 = call(
            GREATER_THAN.name(),
            FUNCTION_MANAGER.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
            BOOLEAN,
            ADD_X_Y,
            constant(2L, BIGINT));

    // (x + y) < 10
    private static final CallExpression ADD_X_Y_LESS_THAN_10 = call(
            LESS_THAN.name(),
            FUNCTION_MANAGER.resolveOperator(LESS_THAN, fromTypes(BIGINT, BIGINT)),
            BOOLEAN,
            ADD_X_Y,
            constant(10L, BIGINT));

    // (x + y) + z; contains the same x + y subexpression as ADD_X_Y
    private static final CallExpression ADD_X_Y_Z = call(
            ADD.name(),
            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
            BIGINT,
            call(
                    ADD.name(),
                    FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
                    BIGINT,
                    field(0, BIGINT),
                    field(1, BIGINT)),
            field(2, BIGINT));

    /**
     * Verifies that a projection which fails on one page (integer overflow) still produces
     * correct results for a page processed afterwards.
     */
    @Test
    public void testFailureDoesNotCorruptFutureResults()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);
        Supplier<PageProjection> projectionSupplier = functionCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty());
        PageProjection projection = projectionSupplier.get();

        // process good page and verify we got the expected number of result rows
        Page goodPage = createLongBlockPage(1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        Block goodResult = project(projection, goodPage, SelectedPositions.positionsRange(0, goodPage.getPositionCount())).get(0);
        assertEquals(goodResult.getPositionCount(), goodPage.getPositionCount());

        // addition will throw due to integer overflow on the last position (Long.MAX_VALUE + 10).
        // Select exactly the page's positions so the failure can only come from the overflow,
        // never from reading a position past the end of the page.
        Page badPage = createLongBlockPage(1, 0, 1, 2, 3, 4, Long.MAX_VALUE);
        try {
            project(projection, badPage, SelectedPositions.positionsRange(0, badPage.getPositionCount()));
            fail("expected exception");
        }
        catch (PrestoException e) {
            assertEquals(e.getErrorCode(), NUMERIC_VALUE_OUT_OF_RANGE.toErrorCode());
        }

        // running the good page should still work
        // if block builder in generated code was not reset properly, we could get junk results after the failure
        goodResult = project(projection, goodPage, SelectedPositions.positionsRange(0, goodPage.getPositionCount())).get(0);
        assertEquals(goodResult.getPositionCount(), goodPage.getPositionCount());
    }

    /**
     * Verifies that the class-name suffix passed to the compiler appears in the simple name of the
     * generated projection work class ('.' in the stage id is replaced with '_').
     */
    @Test
    public void testGeneratedClassName()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);

        String planNodeId = "7";
        String stageId = "20170707_223500_67496_zguwn.2";
        String classSuffix = stageId + "_" + planNodeId;
        Supplier<PageProjection> projectionSupplier = functionCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of(classSuffix));
        PageProjection projection = projectionSupplier.get();
        Work<List<Block>> work = projection.project(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), createLongBlockPage(1, 0), SelectedPositions.positionsRange(0, 1));

        // class name should look like PageProjectionWork_20170707_223500_67496_zguwn_2_7_XX
        assertTrue(work.getClass().getSimpleName().startsWith("PageProjectionWork_" + stageId.replace('.', '_') + "_" + planNodeId));
    }

    /**
     * Verifies projection caching. With a cache size of 100, compiling the same expression returns
     * the same supplier instance regardless of the class-name hint. With a cache size of 0, every
     * call returns a distinct supplier.
     */
    @Test
    public void testCache()
    {
        PageFunctionCompiler cacheCompiler = new PageFunctionCompiler(createTestMetadataManager(), 100);
        assertSame(
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty()),
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty()));
        assertSame(
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint")),
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint")));
        assertSame(
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint")),
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint2")));
        assertSame(
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty()),
                cacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint2")));

        PageFunctionCompiler noCacheCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);
        assertNotSame(
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty()),
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty()));
        assertNotSame(
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint")),
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint")));
        assertNotSame(
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint")),
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint2")));
        assertNotSame(
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.empty()),
                noCacheCompiler.compileProjection(SESSION.getSqlFunctionProperties(), ADD_10_EXPRESSION, Optional.of("hint2")));
    }

    /**
     * Verifies common-subexpression handling in projections. With CSE enabled, x + y and
     * (x + y) + z compile into a single projection with two outputs. With CSE disabled, they
     * compile into two projections. Both modes must produce identical values.
     */
    @Test
    public void testCommonSubExpressionInProjection()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);

        List<Supplier<PageProjectionWithOutputs>> pageProjectionsCSE = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), ImmutableList.of(ADD_X_Y, ADD_X_Y_Z), true, Optional.empty());
        assertEquals(pageProjectionsCSE.size(), 1);

        List<Supplier<PageProjectionWithOutputs>> pageProjectionsNoCSE = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), ImmutableList.of(ADD_X_Y, ADD_X_Y_Z), false, Optional.empty());
        assertEquals(pageProjectionsNoCSE.size(), 2);

        Page input = createLongBlockPage(3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

        List<Block> cseResult = project(pageProjectionsCSE.get(0).get().getPageProjection(), input, SelectedPositions.positionsRange(0, input.getPositionCount()));
        assertEquals(cseResult.size(), 2);

        List<Block> noCseResult1 = project(pageProjectionsNoCSE.get(0).get().getPageProjection(), input, SelectedPositions.positionsRange(0, input.getPositionCount()));
        assertEquals(noCseResult1.size(), 1);

        List<Block> noCseResult2 = project(pageProjectionsNoCSE.get(1).get().getPageProjection(), input, SelectedPositions.positionsRange(0, input.getPositionCount()));
        assertEquals(noCseResult2.size(), 1);

        checkBlockEqual(cseResult.get(0), noCseResult1.get(0));
        checkBlockEqual(cseResult.get(1), noCseResult2.get(0));
    }

    /**
     * Verifies that a projection list containing the same expression twice compiles to two
     * projections even when CSE is enabled.
     */
    @Test
    public void testCommonSubExpressionDuplicatesInProjection()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);
        List<Supplier<PageProjectionWithOutputs>> pageProjections = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), ImmutableList.of(ADD_X_Y, ADD_X_Y), true, Optional.empty());
        assertEquals(pageProjections.size(), 2);
    }

    /**
     * Verifies how a long list of projections sharing a common subexpression is split when CSE is
     * enabled. The expected counts below (5 -> 1, 10 -> 1, 11 -> 2, 20 -> 2, 101 -> 11) imply
     * groups of at most 10 projections each.
     */
    @Test
    public void testCommonSubExpressionLongProjectionList()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);

        List<Supplier<PageProjectionWithOutputs>> pageProjections = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), createIfProjectionList(5), true, Optional.empty());
        assertEquals(pageProjections.size(), 1);

        pageProjections = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), createIfProjectionList(10), true, Optional.empty());
        assertEquals(pageProjections.size(), 1);

        pageProjections = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), createIfProjectionList(11), true, Optional.empty());
        assertEquals(pageProjections.size(), 2);

        pageProjections = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), createIfProjectionList(20), true, Optional.empty());
        assertEquals(pageProjections.size(), 2);

        pageProjections = functionCompiler.compileProjections(SESSION.getSqlFunctionProperties(), createIfProjectionList(101), true, Optional.empty());
        assertEquals(pageProjections.size(), 11);
    }

    /**
     * Verifies a CSE-enabled filter {@code (x + y) > 2 AND (x + y) < 10}. Both input channels
     * hold 0..9, so x + y = 2i, and the filter keeps exactly positions 2, 3 and 4.
     */
    @Test
    public void testCommonSubExpressionInFilter()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(createTestMetadataManager(), 0);
        // AND of two boolean comparisons is itself BOOLEAN-typed
        Supplier<PageFilter> pageFilter = functionCompiler.compileFilter(SESSION.getSqlFunctionProperties(), new SpecialFormExpression(AND, BOOLEAN, ADD_X_Y_GREATER_THAN_2, ADD_X_Y_LESS_THAN_10), true, Optional.empty());

        Page input = createLongBlockPage(2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        SelectedPositions positions = filter(pageFilter.get(), input);
        assertEquals(positions.size(), 3);
        // getPositions() is only valid for list-style selections; assert that first for a clear failure message
        assertTrue(positions.isList());
        assertEquals(positions.getPositions(), new int[] {2, 3, 4});
    }

    /**
     * Asserts that two BIGINT blocks have the same position count and the same long value at every position.
     */
    private static void checkBlockEqual(Block a, Block b)
    {
        assertEquals(a.getPositionCount(), b.getPositionCount());
        for (int i = 0; i < a.getPositionCount(); i++) {
            assertEquals(a.getLong(i), b.getLong(i));
        }
    }

    /**
     * Runs {@code projection} over the selected positions of {@code page}. Asserts that the work
     * completes in a single {@code process()} call and returns the projected blocks.
     */
    private static List<Block> project(PageProjection projection, Page page, SelectedPositions selectedPositions)
    {
        Work<List<Block>> work = projection.project(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), page, selectedPositions);
        assertTrue(work.process());
        return work.getResult();
    }

    /**
     * Applies {@code filter} to the input channels of {@code page} that the filter declares it reads.
     */
    private static SelectedPositions filter(PageFilter filter, Page page)
    {
        return filter.filter(SESSION.getSqlFunctionProperties(), filter.getInputChannels().getInputChannels(page));
    }

    /**
     * Builds a page of {@code blockCount} BIGINT blocks, each containing {@code values} in order.
     */
    private static Page createLongBlockPage(int blockCount, long... values)
    {
        Block[] blocks = new Block[blockCount];
        for (int i = 0; i < blockCount; i++) {
            BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(values.length);
            for (long value : values) {
                BIGINT.writeLong(builder, value);
            }
            blocks[i] = builder.build();
        }
        return new Page(blocks);
    }

    /**
     * Creates {@code projectionCount} projections of the form {@code IF(x > 10, i, i + 1)}.
     * They differ only in their constants and share the {@code x > 10} condition.
     */
    private static List<? extends RowExpression> createIfProjectionList(int projectionCount)
    {
        return IntStream.range(0, projectionCount)
                .mapToObj(i -> new SpecialFormExpression(
                        IF,
                        BIGINT,
                        call(
                                GREATER_THAN.name(),
                                FUNCTION_MANAGER.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
                                BOOLEAN,
                                field(0, BIGINT),
                                constant(10L, BIGINT)),
                        constant((long) i, BIGINT),
                        constant((long) i + 1, BIGINT)))
                .collect(toImmutableList());
    }
}