DruidQueryGeneratorContext.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.druid;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.plan.PlanNodeId;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Collectors;
import static com.facebook.presto.druid.DruidErrorCode.DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION;
import static com.facebook.presto.druid.DruidErrorCode.DRUID_QUERY_GENERATOR_FAILURE;
import static com.facebook.presto.druid.DruidPushdownUtils.DRUID_COUNT_DISTINCT_FUNCTION_NAME;
import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
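/**
 * Immutable context threaded through the Druid query generator as it walks the Presto
 * plan tree. It accumulates the pieces of the final Druid SQL (DQL) query: the selected
 * expressions, the source table, an optional filter, group-by columns, the number of
 * pushed-down aggregations, and an optional limit. Each {@code with*} method returns a
 * new context rather than mutating this one.
 *
 * <p>A minimal usage sketch (the table name, selections, and node id are illustrative):
 * <pre>{@code
 * DruidQueryGeneratorContext context = new DruidQueryGeneratorContext(selections, "events", scanNodeId)
 *         .withFilter("\"country\" = 'US'")
 *         .withLimit(10);
 * DruidQueryGenerator.GeneratedDql dql = context.toQuery();
 * }</pre>
 */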
public class DruidQueryGeneratorContext
{
private final Map<VariableReferenceExpression, Selection> selections;
private final Map<VariableReferenceExpression, Selection> groupByColumns;
private final Set<VariableReferenceExpression> hiddenColumnSet;
private final Set<VariableReferenceExpression> variablesInAggregation;
private final Optional<String> from;
private final Optional<String> filter;
private final OptionalLong limit;
private final int aggregations;
private final Optional<PlanNodeId> tableScanNodeId;
@Override
public String toString()
{
return toStringHelper(this)
.add("selections", selections)
.add("groupByColumns", groupByColumns)
.add("hiddenColumnSet", hiddenColumnSet)
.add("variablesInAggregation", variablesInAggregation)
.add("from", from)
.add("filter", filter)
.add("limit", limit)
.add("aggregations", aggregations)
.add("tableScanNodeId", tableScanNodeId)
.toString();
}
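    /**
     * Creates an empty context: no selections, no source table, and no scan node id.
     */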
DruidQueryGeneratorContext()
{
this(new LinkedHashMap<>(), null, null);
}
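    /**
     * Creates the initial context for a table scan from its selections, source table
     * name, and plan node id.
     */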
DruidQueryGeneratorContext(
Map<VariableReferenceExpression, Selection> selections,
String from,
PlanNodeId planNodeId)
{
this(
selections,
Optional.ofNullable(from),
Optional.empty(),
OptionalLong.empty(),
0,
new LinkedHashMap<>(),
new HashSet<>(),
new HashSet<>(),
Optional.ofNullable(planNodeId));
}
private DruidQueryGeneratorContext(
Map<VariableReferenceExpression, Selection> selections,
Optional<String> from,
Optional<String> filter,
OptionalLong limit,
int aggregations,
Map<VariableReferenceExpression, Selection> groupByColumns,
Set<VariableReferenceExpression> variablesInAggregation,
Set<VariableReferenceExpression> hiddenColumnSet,
Optional<PlanNodeId> tableScanNodeId)
{
        this.selections = new LinkedHashMap<>(requireNonNull(selections, "selections is null"));
        this.from = requireNonNull(from, "from is null");
        this.filter = requireNonNull(filter, "filter is null");
        this.limit = requireNonNull(limit, "limit is null");
        this.aggregations = aggregations;
        this.groupByColumns = new LinkedHashMap<>(requireNonNull(groupByColumns, "groupByColumns is null (an empty map is expected when there is no group by)"));
        this.hiddenColumnSet = requireNonNull(hiddenColumnSet, "hiddenColumnSet is null");
        this.variablesInAggregation = requireNonNull(variablesInAggregation, "variablesInAggregation is null");
        this.tableScanNodeId = requireNonNull(tableScanNodeId, "tableScanNodeId is null");
}
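    /**
     * Returns a copy of this context with the given filter expression attached. A filter
     * on top of an aggregation, or a second filter under the same aggregation, is rejected.
     */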
public DruidQueryGeneratorContext withFilter(String filter)
{
if (hasAggregation()) {
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Druid does not support filter on top of AggregationNode.");
}
checkState(!hasFilter(), "Druid doesn't support filters at multiple levels under AggregationNode");
return new DruidQueryGeneratorContext(
selections,
from,
Optional.of(filter),
limit,
aggregations,
groupByColumns,
variablesInAggregation,
hiddenColumnSet,
tableScanNodeId);
}
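    /**
     * Returns a copy of this context whose selections are replaced by the given projected
     * expressions.
     */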
public DruidQueryGeneratorContext withProject(Map<VariableReferenceExpression, Selection> newSelections)
{
return new DruidQueryGeneratorContext(
newSelections,
from,
filter,
limit,
aggregations,
groupByColumns,
variablesInAggregation,
hiddenColumnSet,
tableScanNodeId);
}
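    /**
     * Returns a copy of this context with the given positive limit. Only a single limit
     * can be pushed down; a limit on top of another limit is rejected.
     */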
public DruidQueryGeneratorContext withLimit(long limit)
{
        if (limit <= 0) {
            throw new PrestoException(DRUID_QUERY_GENERATOR_FAILURE, "Invalid limit: " + limit);
        }
checkState(!hasLimit(), "Limit already exists. Druid doesn't support limit on top of another limit");
return new DruidQueryGeneratorContext(
selections,
from,
filter,
OptionalLong.of(limit),
aggregations,
groupByColumns,
variablesInAggregation,
hiddenColumnSet,
tableScanNodeId);
}
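    /**
     * Returns a copy of this context with the given aggregation pushed down. Selections
     * whose definition starts with {@code DRUID_COUNT_DISTINCT_FUNCTION_NAME} are rewritten
     * into Druid's {@code count(distinct ...)} syntax, in which case the hidden column set
     * added by the non-aggregation group-by plan is cleared. Aggregating over
     * already-aggregated data or over a limit is rejected.
     */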
public DruidQueryGeneratorContext withAggregation(
Map<VariableReferenceExpression, Selection> newSelections,
Map<VariableReferenceExpression, Selection> newGroupByColumns,
int newAggregations,
Set<VariableReferenceExpression> newHiddenColumnSet)
{
        String countDistinctPrefix = DRUID_COUNT_DISTINCT_FUNCTION_NAME.toUpperCase(Locale.ENGLISH);
        boolean pushDownDistinctCount = newSelections.values().stream()
                .anyMatch(selection -> selection.getDefinition().startsWith(countDistinctPrefix));
        Map<VariableReferenceExpression, Selection> targetSelections = newSelections;
        if (pushDownDistinctCount) {
            // We are pushing a count(distinct ...) query down to Druid; clear the hidden
            // column set introduced by the non-aggregation group-by plan.
newHiddenColumnSet = ImmutableSet.of();
ImmutableMap.Builder<VariableReferenceExpression, Selection> builder = ImmutableMap.builder();
for (Map.Entry<VariableReferenceExpression, Selection> entry : newSelections.entrySet()) {
                if (entry.getValue().getDefinition().startsWith(countDistinctPrefix)) {
String definition = entry.getValue().getDefinition();
int start = definition.indexOf("(");
int end = definition.indexOf(")");
String countDistinctClause = "count ( distinct " + escapeSqlIdentifier(definition.substring(start + 1, end)) + ")";
Selection countDistinctSelection = new Selection(countDistinctClause, entry.getValue().getOrigin());
builder.put(entry.getKey(), countDistinctSelection);
}
else {
builder.put(entry.getKey(), entry.getValue());
}
}
targetSelections = builder.build();
}
else {
checkState(!hasAggregation(), "Druid doesn't support aggregation on top of the aggregated data");
}
checkState(!hasLimit(), "Druid doesn't support aggregation on top of the limit");
checkState(newAggregations > 0, "Invalid number of aggregations");
return new DruidQueryGeneratorContext(
targetSelections,
from,
filter,
limit,
newAggregations,
newGroupByColumns,
variablesInAggregation,
newHiddenColumnSet,
tableScanNodeId);
}
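    /**
     * Wraps an identifier in double quotes for Druid SQL. Note that embedded double
     * quotes are not escaped; the identifier is assumed to contain none.
     */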
private static String escapeSqlIdentifier(String identifier)
{
return "\"" + identifier + "\"";
}
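    /**
     * Returns a copy of this context that records the given variables as appearing in
     * aggregation expressions.
     */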
public DruidQueryGeneratorContext withVariablesInAggregation(Set<VariableReferenceExpression> newVariablesInAggregation)
{
return new DruidQueryGeneratorContext(
selections,
from,
filter,
limit,
aggregations,
groupByColumns,
newVariablesInAggregation,
hiddenColumnSet,
tableScanNodeId);
}
private boolean hasLimit()
{
return limit.isPresent();
}
private boolean hasFilter()
{
return filter.isPresent();
}
private boolean hasAggregation()
{
return aggregations > 0;
}
public Map<VariableReferenceExpression, Selection> getSelections()
{
return selections;
}
public Set<VariableReferenceExpression> getHiddenColumnSet()
{
return hiddenColumnSet;
}
Set<VariableReferenceExpression> getVariablesInAggregation()
{
return variablesInAggregation;
}
public Optional<PlanNodeId> getTableScanNodeId()
{
return tableScanNodeId;
}
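    /**
     * Assembles the accumulated state into a Druid SQL string of the form
     * {@code SELECT ... FROM "table" [WHERE ...] [GROUP BY ...] [LIMIT n]}, recording
     * whether anything beyond the bare table scan (filter, group by, aggregation, or
     * limit) was pushed down.
     */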
public DruidQueryGenerator.GeneratedDql toQuery()
{
if (hasLimit() && aggregations > 1 && !groupByColumns.isEmpty()) {
            throw new PrestoException(DRUID_QUERY_GENERATOR_FAILURE, "Could not push down multiple aggregates in the presence of group by and limit");
}
        String expressions = selections.values().stream()
                .map(Selection::getEscapedDefinition)
                .collect(Collectors.joining(", "));
if (expressions.isEmpty()) {
throw new PrestoException(DRUID_QUERY_GENERATOR_FAILURE, "Empty Druid query");
}
String tableName = from.orElseThrow(() -> new PrestoException(DRUID_QUERY_GENERATOR_FAILURE, "Table name missing in Druid query"));
String query = "SELECT " + expressions + " FROM " + escapeSqlIdentifier(tableName);
boolean pushdown = false;
if (filter.isPresent()) {
            // This is a hack. Ideally we would clone the scan pipeline and create/update its filter to contain this filter,
            // and at the same time add the time column to the scan so that the query generator doesn't fail when it looks up
            // the time column in the scan output columns.
query += " WHERE " + filter.get();
pushdown = true;
}
if (!groupByColumns.isEmpty()) {
            String groupByExpression = groupByColumns.values().stream()
                    .map(Selection::getEscapedDefinition)
                    .collect(Collectors.joining(", "));
            query = query + " GROUP BY " + groupByExpression;
pushdown = true;
}
if (hasAggregation()) {
pushdown = true;
}
if (limit.isPresent()) {
query += " LIMIT " + limit.getAsLong();
pushdown = true;
}
return new DruidQueryGenerator.GeneratedDql(tableName, query, pushdown);
}
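    /**
     * Maps each non-hidden output variable to a {@code DruidColumnHandle}: a REGULAR
     * handle for direct table columns and a DERIVED handle for computed expressions.
     */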
public Map<VariableReferenceExpression, DruidColumnHandle> getAssignments()
{
Map<VariableReferenceExpression, DruidColumnHandle> result = new LinkedHashMap<>();
selections.entrySet().stream().filter(e -> !hiddenColumnSet.contains(e.getKey())).forEach(entry -> {
VariableReferenceExpression variable = entry.getKey();
Selection selection = entry.getValue();
DruidColumnHandle handle = selection.getOrigin() == Origin.TABLE_COLUMN ?
new DruidColumnHandle(selection.getDefinition(), variable.getType(), DruidColumnHandle.DruidColumnType.REGULAR) :
new DruidColumnHandle(variable, DruidColumnHandle.DruidColumnType.DERIVED);
result.put(variable, handle);
});
return result;
}
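    /**
     * Returns a copy of this context whose selections are restricted to the given output
     * columns (in their given order), with hidden columns re-appended so the generated
     * query can still reference them.
     */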
public DruidQueryGeneratorContext withOutputColumns(List<VariableReferenceExpression> outputColumns)
{
Map<VariableReferenceExpression, Selection> newSelections = new LinkedHashMap<>();
outputColumns.forEach(o -> newSelections.put(o, requireNonNull(selections.get(o), "Cannot find the selection " + o + " in the original context " + this)));
selections.entrySet().stream().filter(e -> hiddenColumnSet.contains(e.getKey())).forEach(e -> newSelections.put(e.getKey(), e.getValue()));
return new DruidQueryGeneratorContext(
newSelections,
from,
filter,
limit,
aggregations,
groupByColumns,
variablesInAggregation,
hiddenColumnSet,
tableScanNodeId);
}
public enum Origin
{
TABLE_COLUMN, // refers to direct column in table
DERIVED, // expression is derived from one or more input columns or a combination of input columns and literals
LITERAL, // derived from literal
}
    /**
     * A projected/selected column definition in the generated query.
     */
public static class Selection
{
private final String definition;
private final Origin origin;
        public Selection(String definition, Origin origin)
        {
            this.definition = requireNonNull(definition, "definition is null");
            this.origin = requireNonNull(origin, "origin is null");
        }
public String getDefinition()
{
return definition;
}
public String getEscapedDefinition()
{
if (origin == Origin.TABLE_COLUMN) {
return escapeSqlIdentifier(definition);
}
return definition;
}
public Origin getOrigin()
{
return origin;
}
@Override
public String toString()
{
return definition;
}
}
}