QueryRewriter.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.verifier.rewrite;

import com.facebook.presto.common.block.BlockEncodingSerde;
import com.facebook.presto.common.predicate.NullableValue;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.common.type.TypeSignature;
import com.facebook.presto.common.type.TypeSignatureParameter;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.LiteralEncoder;
import com.facebook.presto.sql.tree.AllColumns;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.ComparisonExpression.Operator;
import com.facebook.presto.sql.tree.CreateTable;
import com.facebook.presto.sql.tree.CreateTableAsSelect;
import com.facebook.presto.sql.tree.CreateView;
import com.facebook.presto.sql.tree.DropTable;
import com.facebook.presto.sql.tree.DropView;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.Insert;
import com.facebook.presto.sql.tree.IsNullPredicate;
import com.facebook.presto.sql.tree.LikeClause;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.Property;
import com.facebook.presto.sql.tree.QualifiedName;
import com.facebook.presto.sql.tree.Query;
import com.facebook.presto.sql.tree.QuerySpecification;
import com.facebook.presto.sql.tree.Select;
import com.facebook.presto.sql.tree.SelectItem;
import com.facebook.presto.sql.tree.ShowCreate;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.Statement;
import com.facebook.presto.verifier.framework.ClusterType;
import com.facebook.presto.verifier.framework.Column;
import com.facebook.presto.verifier.framework.QueryConfiguration;
import com.facebook.presto.verifier.framework.QueryException;
import com.facebook.presto.verifier.framework.QueryObjectBundle;
import com.facebook.presto.verifier.prestoaction.PrestoAction;
import com.facebook.presto.verifier.prestoaction.PrestoAction.ResultSetConverter;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import org.intellij.lang.annotations.Language;
import org.joda.time.DateTimeZone;

import java.sql.ResultSetMetaData;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TimeZone;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.RowType.Field;
import static com.facebook.presto.common.type.StandardTypes.MAP;
import static com.facebook.presto.common.type.TimeType.TIME;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.hive.HiveUtil.parsePartitionValue;
import static com.facebook.presto.hive.metastore.MetastoreUtil.toPartitionNamesAndValues;
import static com.facebook.presto.sql.tree.LikeClause.PropertiesOption.INCLUDING;
import static com.facebook.presto.sql.tree.ShowCreate.Type.VIEW;
import static com.facebook.presto.verifier.framework.CreateViewVerification.SHOW_CREATE_VIEW_CONVERTER;
import static com.facebook.presto.verifier.framework.DataVerificationUtil.getColumns;
import static com.facebook.presto.verifier.framework.QueryStage.REWRITE;
import static com.facebook.presto.verifier.framework.VerifierUtil.PARSING_OPTIONS;
import static com.facebook.presto.verifier.framework.VerifierUtil.getColumnNames;
import static com.facebook.presto.verifier.framework.VerifierUtil.getColumnTypes;
import static com.facebook.presto.verifier.rewrite.FunctionCallRewriter.FunctionCallSubstitute;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static com.google.common.collect.Iterables.getOnlyElement;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Map.Entry;
import static java.util.Objects.requireNonNull;
import static java.util.UUID.randomUUID;

public class QueryRewriter
{
    private final SqlParser sqlParser;
    private final TypeManager typeManager;
    private final BlockEncodingSerde blockEncodingSerde;
    private final PrestoAction prestoAction;
    private final Map<ClusterType, QualifiedName> prefixes;
    private final Map<ClusterType, List<Property>> tableProperties;
    private final Map<ClusterType, Boolean> reuseTables;
    private final Optional<FunctionCallRewriter> functionCallRewriter;

    /**
     * Convenience constructor that delegates to the full constructor with an
     * empty multimap of function call substitutes.
     */
    public QueryRewriter(
            SqlParser sqlParser,
            TypeManager typeManager,
            BlockEncodingSerde blockEncodingSerde,
            PrestoAction prestoAction,
            Map<ClusterType, QualifiedName> tablePrefixes,
            Map<ClusterType, List<Property>> tableProperties,
            Map<ClusterType, Boolean> reuseTables)
    {
        this(
                sqlParser,
                typeManager,
                blockEncodingSerde,
                prestoAction,
                tablePrefixes,
                tableProperties,
                reuseTables,
                ImmutableMultimap.of());
    }

    /**
     * Creates a rewriter.
     * <p>
     * The three per-cluster maps are defensively copied into immutable maps, and the
     * function call rewriter is obtained from {@code FunctionCallRewriter.getInstance}
     * using {@code functionSubstitutes} and {@code typeManager}.
     *
     * @param sqlParser parser used to turn query text into statements
     * @param typeManager used to resolve column types
     * @param blockEncodingSerde used when encoding partition values as literals
     * @param prestoAction used to run auxiliary queries during rewriting
     * @param tablePrefixes temporary object name prefix per cluster type
     * @param tableProperties table property overrides per cluster type
     * @param reuseTables whether table reuse is enabled per cluster type
     * @param functionSubstitutes function call substitutes handed to the function call rewriter
     * @throws NullPointerException if any of the null-checked arguments is null
     */
    public QueryRewriter(
            SqlParser sqlParser,
            TypeManager typeManager,
            BlockEncodingSerde blockEncodingSerde,
            PrestoAction prestoAction,
            Map<ClusterType, QualifiedName> tablePrefixes,
            Map<ClusterType, List<Property>> tableProperties,
            Map<ClusterType, Boolean> reuseTables,
            Multimap<String, FunctionCallSubstitute> functionSubstitutes)
    {
        this.sqlParser = requireNonNull(sqlParser, "sqlParser is null");
        this.typeManager = requireNonNull(typeManager, "typeManager is null");
        this.blockEncodingSerde = requireNonNull(blockEncodingSerde, "blockEncodingSerde is null");
        this.prestoAction = requireNonNull(prestoAction, "prestoAction is null");
        // ImmutableMap.copyOf already rejects null; the explicit checks only add a descriptive message.
        this.prefixes = ImmutableMap.copyOf(requireNonNull(tablePrefixes, "tablePrefixes is null"));
        this.tableProperties = ImmutableMap.copyOf(requireNonNull(tableProperties, "tableProperties is null"));
        this.reuseTables = ImmutableMap.copyOf(requireNonNull(reuseTables, "reuseTables is null"));
        this.functionCallRewriter = FunctionCallRewriter.getInstance(functionSubstitutes, typeManager);
    }

    /**
     * Rewrites {@code query} for {@code clusterType} with table reuse disabled.
     * Equivalent to calling the four-argument overload with {@code reuseTable == false}.
     */
    public QueryObjectBundle rewriteQuery(@Language("SQL") String query, QueryConfiguration queryConfiguration, ClusterType clusterType)
    {
        boolean reuseTable = false;
        return rewriteQuery(query, queryConfiguration, clusterType, reuseTable);
    }

    /**
     * Parses {@code query} and rewrites it so that it writes into a temporary table or view
     * whose name is derived from the prefix configured for {@code clusterType}.
     * <p>
     * Supported statements are CREATE TABLE AS SELECT, INSERT, SELECT (a {@link Query}),
     * CREATE VIEW and CREATE TABLE. For CREATE TABLE AS SELECT and INSERT the original
     * statement is returned unchanged, together with a partition predicate and with the
     * bundle's trailing boolean set to {@code true}, when all of the following hold:
     * {@code reuseTable} is set, reuse is enabled for the cluster type, the query
     * configuration reports a reusable table, no function substitution was applied, and a
     * partition predicate could be derived.
     * <p>
     * A cluster type without an entry in the table property or reuse maps is treated as
     * having no property overrides and reuse disabled.
     *
     * @param query SQL text to rewrite
     * @param queryConfiguration configuration of the query being rewritten
     * @param clusterType cluster the rewritten query is intended for
     * @param reuseTable whether the caller allows reusing an existing output table
     * @return the rewritten statement bundled with its setup and teardown statements
     * @throws IllegalStateException if {@code clusterType} has no configured prefix,
     *         or the parsed statement is of an unsupported type
     */
    public QueryObjectBundle rewriteQuery(@Language("SQL") String query, QueryConfiguration queryConfiguration, ClusterType clusterType, boolean reuseTable)
    {
        checkState(prefixes.containsKey(clusterType), "Unsupported cluster type: %s", clusterType);
        Statement statement = sqlParser.createStatement(query, PARSING_OPTIONS);

        QualifiedName prefix = prefixes.get(clusterType);
        // Missing entries must not surface as a NullPointerException further down:
        // no configured properties means "no overrides", no configured flag means "do not reuse".
        List<Property> properties = tableProperties.getOrDefault(clusterType, ImmutableList.of());
        boolean shouldReuseTable = reuseTable && reuseTables.getOrDefault(clusterType, false) && queryConfiguration.isReusableTable();

        if (statement instanceof CreateTableAsSelect) {
            return rewriteCreateTableAsSelect((CreateTableAsSelect) statement, queryConfiguration, clusterType, prefix, properties, shouldReuseTable);
        }
        if (statement instanceof Insert) {
            return rewriteInsert((Insert) statement, queryConfiguration, clusterType, prefix, properties, shouldReuseTable);
        }
        if (statement instanceof Query) {
            return rewriteSelectQuery((Query) statement, clusterType, prefix, properties);
        }
        if (statement instanceof CreateView) {
            return rewriteCreateView((CreateView) statement, clusterType, prefix);
        }
        if (statement instanceof CreateTable) {
            return rewriteCreateTable((CreateTable) statement, clusterType, prefix, properties);
        }

        throw new IllegalStateException(format("Unsupported query type: %s", statement.getClass()));
    }

    /**
     * Rewrites a CREATE TABLE AS SELECT to target a temporary table, or returns the original
     * statement with a partition predicate when the existing table can be reused.
     */
    private QueryObjectBundle rewriteCreateTableAsSelect(
            CreateTableAsSelect createTableAsSelect,
            QueryConfiguration queryConfiguration,
            ClusterType clusterType,
            QualifiedName prefix,
            List<Property> properties,
            boolean shouldReuseTable)
    {
        Query createQuery = createTableAsSelect.getQuery();

        Optional<String> functionSubstitutions = Optional.empty();
        if (functionCallRewriter.isPresent()) {
            FunctionCallRewriter.RewriterResult rewriterResult = functionCallRewriter.get().rewrite(createQuery);
            createQuery = (Query) rewriterResult.getRewrittenNode();
            functionSubstitutions = rewriterResult.getSubstitutions();
        }
        // Reuse is only attempted when the query was not altered by function substitution.
        if (shouldReuseTable && !functionSubstitutions.isPresent()) {
            Optional<Expression> partitionsPredicate = getPartitionsPredicate(createTableAsSelect.getName(), queryConfiguration.getPartitions());
            if (partitionsPredicate.isPresent()) {
                return new QueryObjectBundle(
                        createTableAsSelect.getName(),
                        ImmutableList.of(),
                        createTableAsSelect,
                        ImmutableList.of(),
                        clusterType,
                        Optional.empty(),
                        partitionsPredicate,
                        true);
            }
        }

        QualifiedName temporaryTableName = generateTemporaryName(Optional.of(createTableAsSelect.getName()), prefix);
        return new QueryObjectBundle(
                temporaryTableName,
                ImmutableList.of(),
                new CreateTableAsSelect(
                        temporaryTableName,
                        createQuery,
                        createTableAsSelect.isNotExists(),
                        applyPropertyOverride(createTableAsSelect.getProperties(), properties),
                        createTableAsSelect.isWithData(),
                        createTableAsSelect.getColumnAliases(),
                        createTableAsSelect.getComment()),
                ImmutableList.of(new DropTable(temporaryTableName, true)),
                clusterType,
                functionSubstitutions,
                Optional.empty(),
                false);
    }

    /**
     * Rewrites an INSERT to write into a temporary table created LIKE the original target,
     * or returns the original statement with a partition predicate when the target can be reused.
     */
    private QueryObjectBundle rewriteInsert(
            Insert insert,
            QueryConfiguration queryConfiguration,
            ClusterType clusterType,
            QualifiedName prefix,
            List<Property> properties,
            boolean shouldReuseTable)
    {
        QualifiedName originalTableName = insert.getTarget();
        Query insertQuery = insert.getQuery();

        Optional<String> functionSubstitutions = Optional.empty();
        if (functionCallRewriter.isPresent()) {
            FunctionCallRewriter.RewriterResult rewriterResult = functionCallRewriter.get().rewrite(insertQuery);
            insertQuery = (Query) rewriterResult.getRewrittenNode();
            functionSubstitutions = rewriterResult.getSubstitutions();
        }
        // Reuse is only attempted when the query was not altered by function substitution.
        if (shouldReuseTable && !functionSubstitutions.isPresent()) {
            Optional<Expression> partitionsPredicate = getPartitionsPredicate(originalTableName, queryConfiguration.getPartitions());
            if (partitionsPredicate.isPresent()) {
                return new QueryObjectBundle(
                        originalTableName,
                        ImmutableList.of(),
                        insert,
                        ImmutableList.of(),
                        clusterType,
                        Optional.empty(),
                        partitionsPredicate,
                        true);
            }
        }

        QualifiedName temporaryTableName = generateTemporaryName(Optional.of(originalTableName), prefix);
        return new QueryObjectBundle(
                temporaryTableName,
                ImmutableList.of(
                        new CreateTable(
                                temporaryTableName,
                                ImmutableList.of(new LikeClause(originalTableName, Optional.of(INCLUDING))),
                                false,
                                properties,
                                Optional.empty())),
                new Insert(
                        temporaryTableName,
                        insert.getColumns(),
                        insertQuery),
                ImmutableList.of(new DropTable(temporaryTableName, true)),
                clusterType,
                functionSubstitutions,
                Optional.empty(),
                false);
    }

    /**
     * Wraps a SELECT query in a CREATE TABLE AS SELECT targeting a temporary table, using
     * generated column aliases and casting columns that have a storable type rewrite.
     */
    private QueryObjectBundle rewriteSelectQuery(Query query, ClusterType clusterType, QualifiedName prefix, List<Property> properties)
    {
        Query queryBody = query;

        Optional<String> functionSubstitutions = Optional.empty();
        if (functionCallRewriter.isPresent()) {
            FunctionCallRewriter.RewriterResult rewriterResult = functionCallRewriter.get().rewrite(queryBody);
            queryBody = (Query) rewriterResult.getRewrittenNode();
            functionSubstitutions = rewriterResult.getSubstitutions();
        }

        QualifiedName temporaryTableName = generateTemporaryName(Optional.empty(), prefix);
        ResultSetMetaData metadata = getResultMetadata(queryBody);
        List<Identifier> columnAliases = generateStorageColumnAliases(metadata);
        queryBody = rewriteNonStorableColumns(queryBody, metadata);
        return new QueryObjectBundle(
                temporaryTableName,
                ImmutableList.of(),
                new CreateTableAsSelect(
                        temporaryTableName,
                        queryBody,
                        false,
                        properties,
                        true,
                        Optional.of(columnAliases),
                        Optional.empty()),
                ImmutableList.of(new DropTable(temporaryTableName, true)),
                clusterType,
                functionSubstitutions,
                Optional.empty(),
                false);
    }

    /**
     * Rewrites a CREATE VIEW to target a temporary view. When SHOW CREATE VIEW succeeds for
     * the original view name, a setup statement pre-creates the temporary view with that
     * existing definition.
     */
    private QueryObjectBundle rewriteCreateView(CreateView createView, ClusterType clusterType, QualifiedName prefix)
    {
        QualifiedName temporaryViewName = generateTemporaryName(Optional.empty(), prefix);
        ImmutableList.Builder<Statement> setupQueries = ImmutableList.builder();

        // Check to see if there is an existing view with the specified view name.
        // If the view exists, create a temporary view that has the same definition as the existing view.
        // Otherwise, do not pre-create the temporary view.
        try {
            String createExistingViewQuery = getOnlyElement(prestoAction.execute(
                    new ShowCreate(VIEW, createView.getName()),
                    REWRITE,
                    SHOW_CREATE_VIEW_CONVERTER).getResults());
            CreateView createExistingView = (CreateView) sqlParser.createStatement(createExistingViewQuery, PARSING_OPTIONS);
            setupQueries.add(new CreateView(
                    temporaryViewName,
                    createExistingView.getQuery(),
                    false,
                    createExistingView.getSecurity()));
        }
        catch (QueryException ignored) {
            // Deliberately ignored: a failed SHOW CREATE VIEW is taken to mean there is
            // no existing view to mirror, so no setup statement is added.
        }
        return new QueryObjectBundle(
                temporaryViewName,
                setupQueries.build(),
                new CreateView(
                        temporaryViewName,
                        createView.getQuery(),
                        createView.isReplace(),
                        createView.getSecurity()),
                ImmutableList.of(new DropView(temporaryViewName, true)),
                clusterType,
                Optional.empty(),
                Optional.empty(),
                false);
    }

    /**
     * Rewrites a CREATE TABLE to target a temporary table, applying the property overrides.
     */
    private QueryObjectBundle rewriteCreateTable(CreateTable createTable, ClusterType clusterType, QualifiedName prefix, List<Property> properties)
    {
        QualifiedName temporaryTableName = generateTemporaryName(Optional.empty(), prefix);
        return new QueryObjectBundle(
                temporaryTableName,
                ImmutableList.of(),
                new CreateTable(
                        temporaryTableName,
                        createTable.getElements(),
                        createTable.isNotExists(),
                        applyPropertyOverride(createTable.getProperties(), properties),
                        createTable.getComment()),
                ImmutableList.of(new DropTable(temporaryTableName, true)),
                clusterType,
                Optional.empty(),
                Optional.empty(),
                false);
    }

    /**
     * Builds a unique temporary object name from {@code prefix}.
     * <p>
     * If {@code originalName} has more parts than {@code prefix}, its leading surplus
     * parts are kept in front of the prefix parts. The last part of the result is the
     * prefix's suffix followed by an underscore and a random UUID with dashes removed.
     */
    private QualifiedName generateTemporaryName(Optional<QualifiedName> originalName, QualifiedName prefix)
    {
        List<String> prefixParts = prefix.getOriginalParts();
        List<String> nameParts = new ArrayList<>();

        if (originalName.isPresent()) {
            List<String> originalParts = originalName.get().getOriginalParts();
            int surplus = originalParts.size() - prefixParts.size();
            if (surplus > 0) {
                nameParts.addAll(originalParts.subList(0, surplus));
            }
        }

        nameParts.addAll(prefixParts);
        String uniqueSuffix = prefix.getOriginalSuffix() + "_" + randomUUID().toString().replace("-", "");
        nameParts.set(nameParts.size() - 1, uniqueSuffix);
        return QualifiedName.of(nameParts);
    }

    /**
     * Produces one delimited identifier per result column, in column order.
     * <p>
     * Each column name is sanitized first; if the sanitized name is already taken, a
     * {@code __<n>} postfix (n = 1, 2, ...) is appended until the alias is unused.
     */
    private List<Identifier> generateStorageColumnAliases(ResultSetMetaData metadata)
    {
        List<Identifier> aliases = new ArrayList<>();
        Set<String> takenAliases = new HashSet<>();

        for (String rawName : getColumnNames(metadata)) {
            String baseName = sanitizeColumnName(rawName);
            String candidate = baseName;
            // Set.add returns false when the candidate is already taken, so keep trying postfixes.
            for (int postfix = 1; !takenAliases.add(candidate); postfix++) {
                candidate = format("%s__%s", baseName, postfix);
            }
            aliases.add(new Identifier(candidate, true));
        }
        return ImmutableList.copyOf(aliases);
    }

    /**
     * Executes a variant of {@code query} limited to zero rows and returns the result set metadata.
     */
    private ResultSetMetaData getResultMetadata(Query query)
    {
        Query zeroRowQuery = toZeroRowQuery(query);
        return prestoAction.execute(zeroRowQuery, REWRITE, ResultSetConverter.DEFAULT).getMetadata();
    }

    /**
     * Returns a copy of {@code query} carrying a limit of "0". When the body is a
     * {@link QuerySpecification} the limit is placed on the specification itself;
     * otherwise it is placed on the enclosing query.
     */
    private static Query toZeroRowQuery(Query query)
    {
        Optional<String> zeroLimit = Optional.of("0");

        if (!(query.getQueryBody() instanceof QuerySpecification)) {
            return new Query(query.getWith(), query.getQueryBody(), Optional.empty(), Optional.empty(), zeroLimit);
        }

        QuerySpecification specification = (QuerySpecification) query.getQueryBody();
        QuerySpecification zeroRowSpecification = new QuerySpecification(
                specification.getSelect(),
                specification.getFrom(),
                specification.getWhere(),
                specification.getGroupBy(),
                specification.getHaving(),
                specification.getOrderBy(),
                specification.getOffset(),
                zeroLimit);
        return new Query(
                query.getWith(),
                zeroRowSpecification,
                Optional.empty(),
                Optional.empty(),
                Optional.empty());
    }

    /**
     * Wraps every select item whose column type has a storable rewrite in a CAST to that type.
     * <p>
     * The query is returned unchanged when no column needs a rewrite, when the query body is
     * not a {@link QuerySpecification}, or when the select list contains {@code SELECT *}.
     * All other clauses, including OFFSET on both the specification and the enclosing
     * query, are carried over unchanged so the rewritten query selects the same rows.
     *
     * @param query query to rewrite
     * @param metadata result set metadata describing the columns produced by {@code query}
     * @return the rewritten query, or {@code query} itself when no rewrite is applied
     * @throws IllegalStateException if the number of select items differs from the number of columns
     */
    private Query rewriteNonStorableColumns(Query query, ResultSetMetaData metadata)
    {
        // Skip if all columns are storable
        List<Type> columnTypes = getColumnTypes(typeManager, metadata);
        if (columnTypes.stream().noneMatch(type -> getColumnTypeRewrite(type).isPresent())) {
            return query;
        }

        // Cannot handle SELECT query with top-level SetOperation
        if (!(query.getQueryBody() instanceof QuerySpecification)) {
            return query;
        }

        QuerySpecification querySpecification = (QuerySpecification) query.getQueryBody();
        List<SelectItem> selectItems = querySpecification.getSelect().getSelectItems();
        // Cannot handle SELECT *
        if (selectItems.stream().anyMatch(AllColumns.class::isInstance)) {
            return query;
        }

        checkState(selectItems.size() == columnTypes.size(), "SelectItem count (%s) mismatches column count (%s)", selectItems.size(), columnTypes.size());
        List<SelectItem> newItems = new ArrayList<>();
        for (int i = 0; i < selectItems.size(); i++) {
            // AllColumns items were excluded above, so each remaining item is cast to SingleColumn.
            SingleColumn singleColumn = (SingleColumn) selectItems.get(i);
            Optional<Type> columnTypeRewrite = getColumnTypeRewrite(columnTypes.get(i));
            if (columnTypeRewrite.isPresent()) {
                newItems.add(new SingleColumn(new Cast(singleColumn.getExpression(), columnTypeRewrite.get().getTypeSignature().toString()), singleColumn.getAlias()));
            }
            else {
                newItems.add(singleColumn);
            }
        }

        return new Query(
                query.getWith(),
                new QuerySpecification(
                        new Select(querySpecification.getSelect().isDistinct(), newItems),
                        querySpecification.getFrom(),
                        querySpecification.getWhere(),
                        querySpecification.getGroupBy(),
                        querySpecification.getHaving(),
                        querySpecification.getOrderBy(),
                        // Preserve OFFSET: dropping it would change which rows are selected.
                        querySpecification.getOffset(),
                        querySpecification.getLimit()),
                query.getOrderBy(),
                query.getOffset(),
                query.getLimit());
    }

    /**
     * Returns the type a column of {@code type} should be cast to before being stored,
     * or empty when no cast is applied.
     * <p>
     * DATE and TIME map to TIMESTAMP, TIMESTAMP WITH TIME ZONE to VARCHAR, UNKNOWN to BIGINT
     * and any decimal type to DOUBLE. Array, map and row types are rewritten recursively and
     * only yield a result when at least one nested type is rewritten.
     */
    private Optional<Type> getColumnTypeRewrite(Type type)
    {
        if (type.equals(DATE) || type.equals(TIME)) {
            return Optional.of(TIMESTAMP);
        }
        else if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
            return Optional.of(VARCHAR);
        }
        else if (type.equals(UNKNOWN)) {
            return Optional.of(BIGINT);
        }
        else if (type instanceof DecimalType) {
            return Optional.of(DOUBLE);
        }
        else if (type instanceof ArrayType) {
            Optional<Type> elementRewrite = getColumnTypeRewrite(((ArrayType) type).getElementType());
            if (!elementRewrite.isPresent()) {
                return Optional.empty();
            }
            return Optional.of(new ArrayType(elementRewrite.get()));
        }
        else if (type instanceof MapType) {
            return getMapTypeRewrite((MapType) type);
        }
        else if (type instanceof RowType) {
            return getRowTypeRewrite((RowType) type);
        }
        return Optional.empty();
    }

    /**
     * Rewrites a map type when its key type or value type has a rewrite.
     */
    private Optional<Type> getMapTypeRewrite(MapType mapType)
    {
        Type originalKey = mapType.getKeyType();
        Type originalValue = mapType.getValueType();
        Optional<Type> rewrittenKey = getColumnTypeRewrite(originalKey);
        Optional<Type> rewrittenValue = getColumnTypeRewrite(originalValue);

        if (!rewrittenKey.isPresent() && !rewrittenValue.isPresent()) {
            return Optional.empty();
        }

        TypeSignature mapSignature = new TypeSignature(
                MAP,
                TypeSignatureParameter.of(rewrittenKey.orElse(originalKey).getTypeSignature()),
                TypeSignatureParameter.of(rewrittenValue.orElse(originalValue).getTypeSignature()));
        return Optional.of(typeManager.getType(mapSignature));
    }

    /**
     * Rewrites a row type when at least one of its field types has a rewrite; field names are kept.
     */
    private Optional<Type> getRowTypeRewrite(RowType rowType)
    {
        List<Field> rewrittenFields = new ArrayList<>();
        boolean anyFieldRewritten = false;
        for (Field field : rowType.getFields()) {
            Optional<Type> fieldRewrite = getColumnTypeRewrite(field.getType());
            anyFieldRewritten |= fieldRewrite.isPresent();
            rewrittenFields.add(new Field(field.getName(), fieldRewrite.orElse(field.getType())));
        }

        if (!anyFieldRewritten) {
            return Optional.empty();
        }
        return Optional.of(RowType.from(rewrittenFields));
    }

    /**
     * Replaces every character other than ASCII letters, digits and underscore with an
     * underscore, then lower-cases the result using the English locale.
     */
    private static String sanitizeColumnName(String columnName)
    {
        // Replace first, then lower-case: the order matters for characters whose lower-case form changes length.
        String replaced = columnName.replaceAll("[^a-zA-Z0-9_]", "_");
        return replaced.toLowerCase(ENGLISH);
    }

    /**
     * Merges {@code overrides} into {@code properties}, keyed by lower-cased property name.
     * When both lists define the same name, the value from {@code overrides} wins.
     * Each merged entry becomes a property whose identifier is the lower-cased name.
     */
    private static List<Property> applyPropertyOverride(List<Property> properties, List<Property> overrides)
    {
        Map<String, Expression> baseValues = toLowerCaseNameMap(properties);
        Map<String, Expression> overrideValues = toLowerCaseNameMap(overrides);

        // Base entries are streamed first so that the merge function sees override values last.
        Map<String, Expression> merged = Stream.of(baseValues, overrideValues)
                .flatMap(values -> values.entrySet().stream())
                .collect(Collectors.toMap(Entry::getKey, Entry::getValue, (original, override) -> override));

        ImmutableList.Builder<Property> result = ImmutableList.builder();
        for (Entry<String, Expression> entry : merged.entrySet()) {
            result.add(new Property(new Identifier(entry.getKey()), entry.getValue()));
        }
        return result.build();
    }

    /**
     * Indexes property values by lower-cased property name.
     */
    private static Map<String, Expression> toLowerCaseNameMap(List<Property> properties)
    {
        return properties.stream()
                .collect(toImmutableMap(property -> property.getName().getValueLowerCase(), Property::getValue));
    }

    /**
     * Builds a predicate matching any of the given partitions: one conjunct per partition,
     * combined with OR in list order.
     * <p>
     * Returns empty when {@code partitions} is empty, or as soon as a conjunct cannot be
     * built for some partition (either because none was produced or because building it threw).
     */
    private Optional<Expression> getPartitionsPredicate(QualifiedName tableName, List<String> partitions)
    {
        if (partitions.isEmpty()) {
            return Optional.empty();
        }

        List<Column> columns = getColumns(prestoAction, typeManager, tableName);

        Expression combined = null;
        for (String partition : partitions) {
            Optional<Expression> partitionConjunct;
            try {
                partitionConjunct = getPartitionConjunct(partition, columns);
            }
            catch (Exception ignored) {
                // Best effort: any failure while building the conjunct means no predicate is produced.
                return Optional.empty();
            }
            if (!partitionConjunct.isPresent()) {
                return Optional.empty();
            }

            if (combined == null) {
                combined = partitionConjunct.get();
            }
            else {
                combined = new LogicalBinaryExpression(LogicalBinaryExpression.Operator.OR, combined, partitionConjunct.get());
            }
        }

        return Optional.ofNullable(combined);
    }

    /**
     * Builds a predicate matching a single partition: one term per partition key, combined
     * with AND. A key whose parsed value is null yields an IS NULL term; any other key yields
     * an equality comparison against the literal-encoded value.
     * <p>
     * Returns empty when a partition key has no column of the same name (or that column has
     * no type), or when the partition yields no keys at all.
     */
    private Optional<Expression> getPartitionConjunct(String partition, List<Column> columns)
    {
        Map<String, String> rawKeyValues = toPartitionNamesAndValues(partition);

        Expression combined = null;
        for (Entry<String, String> rawKeyValue : rawKeyValues.entrySet()) {
            String key = rawKeyValue.getKey();

            Optional<Type> columnType = findColumnType(columns, key);
            if (!columnType.isPresent()) {
                return Optional.empty();
            }

            NullableValue value = parsePartitionValue(key, rawKeyValue.getValue(), columnType.get(), DateTimeZone.forTimeZone(TimeZone.getDefault()));

            Expression term;
            if (!value.isNull()) {
                LiteralEncoder literalEncoder = new LiteralEncoder(blockEncodingSerde);
                term = new ComparisonExpression(
                        Operator.EQUAL,
                        new Identifier(key),
                        literalEncoder.toExpression(value.getValue(), value.getType(), false));
            }
            else {
                term = new IsNullPredicate(new Identifier(key));
            }

            if (combined == null) {
                combined = term;
            }
            else {
                combined = new LogicalBinaryExpression(LogicalBinaryExpression.Operator.AND, combined, term);
            }
        }

        return Optional.ofNullable(combined);
    }

    /**
     * Returns the type of the first column whose name equals {@code name} exactly, if any.
     */
    private static Optional<Type> findColumnType(List<Column> columns, String name)
    {
        for (Column column : columns) {
            if (column.getName().equals(name)) {
                return Optional.ofNullable(column.getType());
            }
        }
        return Optional.empty();
    }
}