TestDruidQueryBase.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.druid;

import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.metadata.MetadataManager;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorId;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.WarningCollector;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.Assignments;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.LimitNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.ProjectNode;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.ExpressionUtils;
import com.facebook.presto.sql.parser.ParsingOptions;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder;
import com.facebook.presto.sql.relational.FunctionResolution;
import com.facebook.presto.sql.relational.SqlToRowExpressionTranslator;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.testing.TestingSession;
import com.facebook.presto.testing.TestingTransactionHandle;
import com.google.common.collect.ImmutableMap;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.druid.DruidColumnHandle.DruidColumnType.REGULAR;
import static com.facebook.presto.druid.DruidQueryGeneratorContext.Origin.DERIVED;
import static com.facebook.presto.druid.DruidQueryGeneratorContext.Origin.TABLE_COLUMN;
import static com.facebook.presto.metadata.FunctionAndTypeManager.createTestFunctionAndTypeManager;
import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager;
import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.testing.TestingConnectorSession.SESSION;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;
import static java.util.stream.Collectors.toMap;

public class TestDruidQueryBase
{
    protected static final FunctionAndTypeManager functionAndTypeManager = createTestFunctionAndTypeManager();
    protected static final StandardFunctionResolution standardFunctionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver());

    protected static ConnectorId druidConnectorId = new ConnectorId("id");
    protected static DruidTableHandle realtimeOnlyTable = new DruidTableHandle("schema", "realtimeOnly", Optional.empty());
    protected static DruidTableHandle hybridTable = new DruidTableHandle("schema", "hybrid", Optional.empty());
    protected static DruidColumnHandle regionId = new DruidColumnHandle("region.Id", BIGINT, REGULAR);
    protected static DruidColumnHandle city = new DruidColumnHandle("city", VARCHAR, REGULAR);
    protected static final DruidColumnHandle fare = new DruidColumnHandle("fare", DOUBLE, REGULAR);
    protected static final DruidColumnHandle secondsSinceEpoch = new DruidColumnHandle("secondsSinceEpoch", BIGINT, REGULAR);
    protected static final DruidColumnHandle datetime = new DruidColumnHandle("datetime", TIMESTAMP, REGULAR);

    protected static final Metadata metadata = MetadataManager.createTestMetadataManager();

    protected final DruidConfig druidConfig = new DruidConfig();

    protected static final Map<VariableReferenceExpression, DruidQueryGeneratorContext.Selection> testInput =
            new ImmutableMap.Builder<VariableReferenceExpression, DruidQueryGeneratorContext.Selection>()
                    .put(new VariableReferenceExpression(Optional.empty(), "region.id", BIGINT), new DruidQueryGeneratorContext.Selection("region.Id", TABLE_COLUMN))
                    .put(new VariableReferenceExpression(Optional.empty(), "city", VARCHAR), new DruidQueryGeneratorContext.Selection("city", TABLE_COLUMN))
                    .put(new VariableReferenceExpression(Optional.empty(), "fare", DOUBLE), new DruidQueryGeneratorContext.Selection("fare", TABLE_COLUMN))
                    .put(new VariableReferenceExpression(Optional.empty(), "totalfare", DOUBLE), new DruidQueryGeneratorContext.Selection("(fare + trip)", DERIVED))
                    .put(new VariableReferenceExpression(Optional.empty(), "secondssinceepoch", BIGINT), new DruidQueryGeneratorContext.Selection("secondsSinceEpoch", TABLE_COLUMN))
                    .put(new VariableReferenceExpression(Optional.empty(), "datetime", TIMESTAMP), new DruidQueryGeneratorContext.Selection("datetime", TABLE_COLUMN))
                    .build();

    protected final TypeProvider typeProvider = TypeProvider.fromVariables(testInput.keySet());

    protected static class SessionHolder
    {
        private final ConnectorSession connectorSession;
        private final Session session;

        public SessionHolder()
        {
            connectorSession = SESSION;
            session = TestingSession.testSessionBuilder(createTestingSessionPropertyManager(new SystemSessionProperties().getSessionProperties())).build();
        }

        public ConnectorSession getConnectorSession()
        {
            return connectorSession;
        }

        public Session getSession()
        {
            return session;
        }
    }

    protected VariableReferenceExpression variable(String name)
    {
        return testInput.keySet().stream().filter(v -> v.getName().equals(name)).findFirst().orElseThrow(() -> new IllegalArgumentException("Cannot find variable " + name));
    }

    protected TableScanNode tableScan(PlanBuilder planBuilder, DruidTableHandle connectorTableHandle, DruidColumnHandle... columnHandles)
    {
        List<VariableReferenceExpression> variables = Arrays.stream(columnHandles).map(ch -> new VariableReferenceExpression(Optional.empty(), ch.getColumnName().toLowerCase(ENGLISH), ch.getColumnType())).collect(toImmutableList());
        ImmutableMap.Builder<VariableReferenceExpression, ColumnHandle> assignments = ImmutableMap.builder();
        for (int i = 0; i < variables.size(); ++i) {
            assignments.put(variables.get(i), columnHandles[i]);
        }
        TableHandle tableHandle = new TableHandle(
                druidConnectorId,
                connectorTableHandle,
                TestingTransactionHandle.create(),
                Optional.empty());
        return new TableScanNode(
                Optional.empty(),
                planBuilder.getIdAllocator().getNextId(),
                tableHandle,
                variables,
                assignments.build(),
                TupleDomain.all(),
                TupleDomain.all(), Optional.empty());
    }

    protected FilterNode filter(PlanBuilder planBuilder, PlanNode source, RowExpression predicate)
    {
        return planBuilder.filter(predicate, source);
    }

    protected ProjectNode project(PlanBuilder planBuilder, PlanNode source, List<String> columnNames)
    {
        Map<String, VariableReferenceExpression> incomingColumns = source.getOutputVariables().stream().collect(toMap(VariableReferenceExpression::getName, identity()));
        Assignments.Builder assignmentsBuilder = Assignments.builder();
        columnNames.forEach(columnName -> {
            VariableReferenceExpression variable = requireNonNull(incomingColumns.get(columnName), "Couldn't find the incoming column " + columnName);
            assignmentsBuilder.put(variable, variable);
        });
        return planBuilder.project(assignmentsBuilder.build(), source);
    }

    public static Expression expression(String sql)
    {
        return ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql, new ParsingOptions(ParsingOptions.DecimalLiteralTreatment.AS_DECIMAL)));
    }

    protected RowExpression toRowExpression(Expression expression, Session session)
    {
        Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(
                session,
                metadata,
                new SqlParser(),
                typeProvider,
                expression,
                ImmutableMap.of(),
                WarningCollector.NOOP);
        return SqlToRowExpressionTranslator.translate(expression, expressionTypes, ImmutableMap.of(), functionAndTypeManager, session);
    }

    protected LimitNode limit(PlanBuilder pb, long count, PlanNode source)
    {
        return new LimitNode(Optional.empty(), pb.getIdAllocator().getNextId(), source, count, FINAL);
    }

    protected RowExpression getRowExpression(String sqlExpression, SessionHolder sessionHolder)
    {
        return toRowExpression(expression(sqlExpression), sessionHolder.getSession());
    }

    protected PlanBuilder createPlanBuilder(SessionHolder sessionHolder)
    {
        return new PlanBuilder(sessionHolder.getSession(), new PlanNodeIdAllocator(), metadata);
    }
}