TestCursorProcessorCompiler.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.ClassDefinition;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.operator.DriverYieldSignal;
import com.facebook.presto.operator.index.PageRecordSet;
import com.facebook.presto.operator.project.CursorProcessor;
import com.facebook.presto.spi.RecordSet;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static com.facebook.presto.bytecode.Access.FINAL;
import static com.facebook.presto.bytecode.Access.PUBLIC;
import static com.facebook.presto.bytecode.Access.a;
import static com.facebook.presto.bytecode.ParameterizedType.type;
import static com.facebook.presto.common.function.OperatorType.ADD;
import static com.facebook.presto.common.function.OperatorType.GREATER_THAN;
import static com.facebook.presto.common.function.OperatorType.LESS_THAN;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.AND;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IF;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.CommonSubExpressionFields.declareCommonSubExpressionFields;
import static com.facebook.presto.sql.gen.CommonSubExpressionRewriter.collectCSEByLevel;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.constant;
import static com.facebook.presto.sql.relational.Expressions.field;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static com.facebook.presto.util.CompilerUtils.makeClassName;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.util.Collections.emptyMap;
import static java.util.stream.Collectors.toList;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.expectThrows;

/**
 * Tests for cursor processor code generation: common sub-expression (CSE) rewriting,
 * equivalence of processors compiled with and without CSE, compilation of large
 * projection lists, and the failure mode for projection lists carrying very many
 * distinct constants.
 */
public class TestCursorProcessorCompiler
{
    private static final Metadata METADATA = createTestMetadataManager();
    private static final FunctionAndTypeManager FUNCTION_MANAGER = METADATA.getFunctionAndTypeManager();

    // Number of projections, each carrying its own long constant, used by the constant pool stress test
    private static final int CONSTANT_POOL_STRESS_PROJECTION_COUNT = 8000;

    // field(0) + field(1)
    private static final CallExpression ADD_X_Y = call(
            ADD.name(),
            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
            BIGINT,
            field(0, BIGINT),
            field(1, BIGINT));

    // (field(0) + field(1)) > 2
    private static final CallExpression ADD_X_Y_GREATER_THAN_2 = call(
            GREATER_THAN.name(),
            FUNCTION_MANAGER.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
            BOOLEAN,
            ADD_X_Y,
            constant(2L, BIGINT));

    // (field(0) + field(1)) < 10
    private static final CallExpression ADD_X_Y_LESS_THAN_10 = call(
            LESS_THAN.name(),
            FUNCTION_MANAGER.resolveOperator(LESS_THAN, fromTypes(BIGINT, BIGINT)),
            BOOLEAN,
            ADD_X_Y,
            constant(10L, BIGINT));

    // (field(0) + field(1)) + field(2); the inner sum is built as a separate instance
    // with the same structure as ADD_X_Y rather than reusing that constant
    private static final CallExpression ADD_X_Y_Z = call(
            ADD.name(),
            FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
            BIGINT,
            call(
                    ADD.name(),
                    FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
                    BIGINT,
                    field(0, BIGINT),
                    field(1, BIGINT)),
            field(2, BIGINT));

    /**
     * Checks that {@code X + Y}, shared by the projection {@code X + Y + Z} and the filter
     * {@code X + Y > 2}, is collected as the single common sub-expression and that both
     * rewritten expressions reference the CSE variable in its place.
     */
    @Test
    public void testRewriteRowExpressionWithCSE()
    {
        CursorProcessorCompiler cseCursorCompiler = new CursorProcessorCompiler(METADATA, true, emptyMap());

        ClassDefinition cursorProcessorClassDefinition = new ClassDefinition(
                a(PUBLIC, FINAL),
                makeClassName(CursorProcessor.class.getSimpleName()),
                type(Object.class),
                type(CursorProcessor.class));

        // A conjunction yields a boolean, so the special form is typed BOOLEAN
        RowExpression filter = new SpecialFormExpression(AND, BOOLEAN, ADD_X_Y_GREATER_THAN_2);
        List<RowExpression> projections = ImmutableList.of(ADD_X_Y_Z);
        List<RowExpression> rowExpressions = ImmutableList.<RowExpression>builder()
                .addAll(projections)
                .add(filter)
                .build();
        Map<Integer, Map<RowExpression, VariableReferenceExpression>> commonSubExpressionsByLevel = collectCSEByLevel(rowExpressions);

        Map<VariableReferenceExpression, CommonSubExpressionRewriter.CommonSubExpressionFields> cseFields = declareCommonSubExpressionFields(cursorProcessorClassDefinition, commonSubExpressionsByLevel);
        // Flatten the per-level maps into a single expression -> variable lookup
        Map<RowExpression, VariableReferenceExpression> commonSubExpressions = commonSubExpressionsByLevel.values().stream()
                .flatMap(level -> level.entrySet().stream())
                .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue));

        // X+Y is the only CSE (TestNG argument order is actual, expected)
        assertEquals(cseFields.size(), 1);
        VariableReferenceExpression cseVariable = cseFields.keySet().iterator().next();

        RowExpression rewrittenFilter = cseCursorCompiler.rewriteRowExpressionsWithCSE(ImmutableList.of(filter), commonSubExpressions).get(0);
        List<RowExpression> rewrittenProjections = cseCursorCompiler.rewriteRowExpressionsWithCSE(projections, commonSubExpressions);

        // X+Y+Z contains CSE X+Y
        CallExpression rewrittenProjection = (CallExpression) rewrittenProjections.get(0);
        assertTrue(rewrittenProjection.getArguments().contains(cseVariable));

        // X+Y > 2 contains CSE X+Y
        CallExpression rewrittenComparison = (CallExpression) ((SpecialFormExpression) rewrittenFilter).getArguments().get(0);
        assertTrue(rewrittenComparison.getArguments().contains(cseVariable));
    }

    /**
     * Compiles the same filter and projections with and without common sub-expression
     * optimization, runs both processors over the same input, and checks that they
     * produce identical, non-empty output.
     */
    @Test
    public void testCompilerWithCSE()
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(METADATA, 0);
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, functionCompiler);

        // A conjunction yields a boolean, so the special form is typed BOOLEAN
        RowExpression filter = new SpecialFormExpression(AND, BOOLEAN, ADD_X_Y_GREATER_THAN_2, ADD_X_Y_LESS_THAN_10);
        List<? extends RowExpression> projections = createIfProjectionList(5);

        Supplier<CursorProcessor> cseCursorProcessorSupplier = expressionCompiler.compileCursorProcessor(SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", true);
        Supplier<CursorProcessor> noCseCursorProcessorSupplier = expressionCompiler.compileCursorProcessor(SESSION.getSqlFunctionProperties(), Optional.of(filter), projections, "key", false);

        Page input = createLongBlockPage(2, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9);

        List<Type> types = ImmutableList.of(BIGINT, BIGINT);
        PageBuilder pageBuilder = new PageBuilder(projections.stream().map(RowExpression::getType).collect(toList()));
        RecordSet recordSet = new PageRecordSet(types, input);

        cseCursorProcessorSupplier.get().process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
        Page pageFromCSE = pageBuilder.build();

        // Reuse the same builder for the second run
        pageBuilder.reset();

        noCseCursorProcessorSupplier.get().process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
        Page pageFromNoCSE = pageBuilder.build();

        // Both channels hold 0..9, so x + y is 0, 2, ..., 18 and only 4, 6 and 8 satisfy
        // 2 < x + y < 10. Pinning the count keeps the comparison below from passing
        // vacuously on two empty pages.
        assertEquals(pageFromCSE.getPositionCount(), 3, "Unexpected number of rows passing the filter");

        checkPageEqual(pageFromCSE, pageFromNoCSE);
    }

    /**
     * Supplies the projection list sizes exercised by {@link #testProjectionBatching(int)}.
     */
    @DataProvider(name = "projectionCounts")
    public Object[][] projectionCounts()
    {
        return new Object[][] {
                {1},
                {10},
                {1000},
                {1500},
                {5000},
                {6000}
        };
    }

    /**
     * Compiles a cursor processor with {@code projectionCount} plain field references and
     * checks that processing a small page yields one output channel per projection and
     * one output row per input row.
     */
    @Test(dataProvider = "projectionCounts")
    public void testProjectionBatching(int projectionCount)
    {
        PageFunctionCompiler functionCompiler = new PageFunctionCompiler(METADATA, 0);
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, functionCompiler);

        // Alternate between the two input channels
        List<RowExpression> projections = IntStream.range(0, projectionCount)
                .mapToObj(i -> field(i % 2, BIGINT))
                .collect(toImmutableList());

        Supplier<CursorProcessor> cursorProcessorSupplier = expressionCompiler.compileCursorProcessor(
                SESSION.getSqlFunctionProperties(),
                Optional.empty(),
                projections,
                "testProjectionBatching_" + projectionCount,
                false);

        CursorProcessor processor = cursorProcessorSupplier.get();
        assertNotNull(processor, "CursorProcessor should be created successfully for projectionCount = " + projectionCount);

        Page input = createLongBlockPage(2, 1L, 2L, 3L, 4L, 5L);
        List<Type> types = ImmutableList.of(BIGINT, BIGINT);
        PageBuilder pageBuilder = new PageBuilder(projections.stream().map(RowExpression::getType).collect(toList()));
        RecordSet recordSet = new PageRecordSet(types, input);

        processor.process(SESSION.getSqlFunctionProperties(), new DriverYieldSignal(), recordSet.cursor(), pageBuilder);
        Page result = pageBuilder.build();

        assertEquals(result.getChannelCount(), projectionCount, "Mismatch in projected column count");
        assertEquals(result.getPositionCount(), input.getPositionCount(), "Mismatch in row count");
    }

    /**
     * Negative test: compiling {@link #CONSTANT_POOL_STRESS_PROJECTION_COUNT} projections
     * that each add a distinct long constant is expected to fail with a
     * {@link RuntimeException}.
     *
     * <p>The intent is to document that splitting projections into batches addresses
     * oversized generated methods, while the JVM class-file constant pool limit remains
     * a constraint when projections carry many unique constants.
     */
    @Test
    public void testConstantPoolLimitStillConstrainsLargeProjections()
    {
        ExpressionCompiler expressionCompiler = new ExpressionCompiler(METADATA, new PageFunctionCompiler(METADATA, 0));

        List<RowExpression> projectionsWithRandomConstants = createProjectionsWithRandomConstants(CONSTANT_POOL_STRESS_PROJECTION_COUNT);

        expectThrows(RuntimeException.class, () -> expressionCompiler.compileCursorProcessor(
                SESSION.getSqlFunctionProperties(),
                Optional.empty(),
                projectionsWithRandomConstants,
                "testConstantPoolLimit",
                false));
    }

    /**
     * Builds {@code count} projections of the form {@code field(0) + c}, where each
     * {@code c} is a pseudorandom long drawn from a fixed-seed generator so that the
     * expressions are reproducible across runs.
     *
     * @param count number of projections to create
     * @return the generated projections, in generation order
     */
    private static List<RowExpression> createProjectionsWithRandomConstants(int count)
    {
        // Random.longs(n) yields the same values as n successive nextLong() calls,
        // without mutating captured state from inside a lambda
        return new Random(42).longs(count)
                .mapToObj(randomConstant -> (RowExpression) call(
                        ADD.name(),
                        FUNCTION_MANAGER.resolveOperator(ADD, fromTypes(BIGINT, BIGINT)),
                        BIGINT,
                        field(0, BIGINT),
                        constant(randomConstant, BIGINT)))
                .collect(toImmutableList());
    }

    /**
     * Creates a page of {@code blockCount} BIGINT blocks, each containing {@code values}
     * in order.
     *
     * @param blockCount number of identical channels in the page
     * @param values values written to every channel
     * @return the assembled page
     */
    private static Page createLongBlockPage(int blockCount, long... values)
    {
        Block[] blocks = new Block[blockCount];
        for (int channel = 0; channel < blockCount; channel++) {
            BlockBuilder builder = BIGINT.createFixedSizeBlockBuilder(values.length);
            for (long value : values) {
                BIGINT.writeLong(builder, value);
            }
            blocks[channel] = builder.build();
        }
        return new Page(blocks);
    }

    /**
     * Creates {@code projectionCount} projections of the form
     * {@code IF(field(0) + field(1) > 8, i, i + 1)}, all sharing the sub-expression
     * {@code field(0) + field(1)}.
     *
     * @param projectionCount number of projections to create
     * @return the generated projections, indexed by {@code i}
     */
    private static List<? extends RowExpression> createIfProjectionList(int projectionCount)
    {
        return IntStream.range(0, projectionCount)
                .mapToObj(i -> new SpecialFormExpression(
                        IF,
                        BIGINT,
                        call(
                                GREATER_THAN.name(),
                                FUNCTION_MANAGER.resolveOperator(GREATER_THAN, fromTypes(BIGINT, BIGINT)),
                                BOOLEAN,
                                ADD_X_Y,
                                constant(8L, BIGINT)),
                        constant((long) i, BIGINT),
                        constant((long) i + 1, BIGINT)))
                .collect(toImmutableList());
    }

    /**
     * Asserts that two blocks have the same number of positions, the same null flags,
     * and the same long value at every non-null position.
     */
    private static void checkBlockEqual(Block actual, Block expected)
    {
        assertEquals(actual.getPositionCount(), expected.getPositionCount());
        for (int position = 0; position < actual.getPositionCount(); position++) {
            assertEquals(actual.isNull(position), expected.isNull(position), "Null flag mismatch at position " + position);
            // Only read the value where one is present
            if (!actual.isNull(position)) {
                assertEquals(actual.getLong(position), expected.getLong(position));
            }
        }
    }

    /**
     * Asserts that two pages have the same dimensions and equal blocks channel by channel.
     */
    private static void checkPageEqual(Page actual, Page expected)
    {
        assertEquals(actual.getPositionCount(), expected.getPositionCount());
        assertEquals(actual.getChannelCount(), expected.getChannelCount());
        for (int channel = 0; channel < actual.getChannelCount(); channel++) {
            checkBlockEqual(actual.getBlock(channel), expected.getBlock(channel));
        }
    }
}