TransformCorrelatedScalarSubquery.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.common.type.BooleanType;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.MarkDistinctNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.analyzer.FunctionAndTypeResolver;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.plan.AssignUniqueId;
import com.facebook.presto.sql.planner.plan.EnforceSingleRowNode;
import com.facebook.presto.sql.planner.plan.LateralJoinNode;
import com.google.common.collect.ImmutableList;
import io.airlift.slice.Slices;

import java.util.Optional;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.matching.Pattern.nonEmpty;
import static com.facebook.presto.metadata.CastType.CAST;
import static com.facebook.presto.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static com.facebook.presto.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments;
import static com.facebook.presto.sql.planner.plan.Patterns.LateralJoin.correlation;
import static com.facebook.presto.sql.planner.plan.Patterns.lateralJoin;
import static com.facebook.presto.sql.relational.Expressions.buildSwitch;
import static com.facebook.presto.sql.relational.Expressions.call;
import static com.facebook.presto.sql.relational.Expressions.specialForm;
import static java.util.Objects.requireNonNull;

/**
 * Scalar filter scan query is something like:
 * <pre>
 *     SELECT a,b,c FROM rel WHERE a = correlated1 AND b = correlated2
 * </pre>
 * <p>
 * This optimizer can rewrite to mark distinct and filter over a left outer join:
 * <p>
 * From:
 * <pre>
 * - LateralJoin (with correlation list: [C])
 *   - (input) plan which produces symbols: [A, B, C]
 *   - (scalar subquery) Project F
 *     - Filter(D = C AND E > 5)
 *       - plan which produces symbols: [D, E, F]
 * </pre>
 * to:
 * <pre>
 * - Filter(CASE isDistinct WHEN true THEN true ELSE fail('Scalar sub-query has returned multiple rows'))
 *   - MarkDistinct(isDistinct)
 *     - LateralJoin (with correlation list: [C])
 *       - AssignUniqueId(adds symbol U)
 *         - (input) plan which produces symbols: [A, B, C]
 *       - non scalar subquery
 * </pre>
 * <p>
 * This must be run after {@link TransformCorrelatedScalarAggregationToJoin}
 */
public class TransformCorrelatedScalarSubquery
        implements Rule<LateralJoinNode>
{
    private static final Pattern<LateralJoinNode> PATTERN = lateralJoin()
            .with(nonEmpty(correlation()));

    private final FunctionAndTypeResolver functionAndTypeResolver;

    public TransformCorrelatedScalarSubquery(FunctionAndTypeManager functionAndTypeManager)
    {
        requireNonNull(functionAndTypeManager, "functionManager is null");
        this.functionAndTypeResolver = functionAndTypeManager.getFunctionAndTypeResolver();
    }

    @Override
    public Pattern getPattern()
    {
        return PATTERN;
    }

    @Override
    public Result apply(LateralJoinNode lateralJoinNode, Captures captures, Context context)
    {
        PlanNode subquery = context.getLookup().resolve(lateralJoinNode.getSubquery());

        if (!searchFrom(subquery, context.getLookup())
                .where(EnforceSingleRowNode.class::isInstance)
                .recurseOnlyWhen(ProjectNode.class::isInstance)
                .matches()) {
            return Result.empty();
        }

        PlanNode rewrittenSubquery = searchFrom(subquery, context.getLookup())
                .where(EnforceSingleRowNode.class::isInstance)
                .recurseOnlyWhen(ProjectNode.class::isInstance)
                .removeFirst();

        if (isAtMostScalar(rewrittenSubquery, context.getLookup())) {
            return Result.ofPlanNode(new LateralJoinNode(
                    lateralJoinNode.getSourceLocation(),
                    context.getIdAllocator().getNextId(),
                    lateralJoinNode.getInput(),
                    rewrittenSubquery,
                    lateralJoinNode.getCorrelation(),
                    lateralJoinNode.getType(),
                    lateralJoinNode.getOriginSubqueryError()));
        }

        VariableReferenceExpression unique = context.getVariableAllocator().newVariable("unique", BIGINT);

        LateralJoinNode rewrittenLateralJoinNode = new LateralJoinNode(
                lateralJoinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                new AssignUniqueId(
                        lateralJoinNode.getSourceLocation(),
                        context.getIdAllocator().getNextId(),
                        lateralJoinNode.getInput(),
                        unique),
                rewrittenSubquery,
                lateralJoinNode.getCorrelation(),
                lateralJoinNode.getType(),
                lateralJoinNode.getOriginSubqueryError());

        VariableReferenceExpression isDistinct = context.getVariableAllocator().newVariable("is_distinct", BooleanType.BOOLEAN);
        MarkDistinctNode markDistinctNode = new MarkDistinctNode(
                rewrittenLateralJoinNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                rewrittenLateralJoinNode,
                isDistinct,
                rewrittenLateralJoinNode.getInput().getOutputVariables(),
                Optional.empty());

        CallExpression fail = call(
                "fail",
                functionAndTypeResolver.lookupFunction("fail", fromTypes(INTEGER, VARCHAR)),
                UNKNOWN,
                new ConstantExpression((long) SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode(), INTEGER),
                new ConstantExpression(Slices.utf8Slice("Scalar sub-query has returned multiple rows"), VARCHAR));

        FilterNode filterNode = new FilterNode(
                markDistinctNode.getSourceLocation(),
                context.getIdAllocator().getNextId(),
                markDistinctNode,
                buildSwitch(
                        isDistinct,
                        ImmutableList.of(specialForm(WHEN, BOOLEAN, TRUE_CONSTANT, TRUE_CONSTANT)),
                        Optional.of(call(CAST.name(), functionAndTypeResolver.lookupCast("CAST", UNKNOWN, BOOLEAN), BOOLEAN, fail)),
                        BOOLEAN));
//                castToRowExpression(new SimpleCaseExpression(
//                        createSymbolReference(isDistinct),
//                        ImmutableList.of(
//                                new WhenClause(TRUE_LITERAL, TRUE_LITERAL)),
//                        Optional.of(new Cast(
//                                new FunctionCall(
//                                        QualifiedName.of("fail"),
//                                        ImmutableList.of(
//                                                new LongLiteral(Integer.toString(SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode())),
//                                                new StringLiteral("Scalar sub-query has returned multiple rows"))),
//                                BOOLEAN)))));

        return Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                filterNode,
                identityAssignments(lateralJoinNode.getOutputVariables())));
    }
}