ScanQueryPageSource.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.elasticsearch;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.PageBuilderStatus;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.elasticsearch.client.ElasticsearchClient;
import com.facebook.presto.elasticsearch.decoders.ArrayDecoder;
import com.facebook.presto.elasticsearch.decoders.BigintDecoder;
import com.facebook.presto.elasticsearch.decoders.BooleanDecoder;
import com.facebook.presto.elasticsearch.decoders.Decoder;
import com.facebook.presto.elasticsearch.decoders.DoubleDecoder;
import com.facebook.presto.elasticsearch.decoders.IdColumnDecoder;
import com.facebook.presto.elasticsearch.decoders.IntegerDecoder;
import com.facebook.presto.elasticsearch.decoders.IpAddressDecoder;
import com.facebook.presto.elasticsearch.decoders.RealDecoder;
import com.facebook.presto.elasticsearch.decoders.RowDecoder;
import com.facebook.presto.elasticsearch.decoders.ScoreColumnDecoder;
import com.facebook.presto.elasticsearch.decoders.SmallintDecoder;
import com.facebook.presto.elasticsearch.decoders.SourceColumnDecoder;
import com.facebook.presto.elasticsearch.decoders.TimestampDecoder;
import com.facebook.presto.elasticsearch.decoders.TinyintDecoder;
import com.facebook.presto.elasticsearch.decoders.VarbinaryDecoder;
import com.facebook.presto.elasticsearch.decoders.VarcharDecoder;
import com.facebook.presto.spi.ConnectorPageSource;
import com.facebook.presto.spi.ConnectorSession;
import com.google.common.collect.AbstractIterator;
import com.google.common.collect.ImmutableList;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Supplier;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.elasticsearch.BuiltinColumns.ID;
import static com.facebook.presto.elasticsearch.BuiltinColumns.SCORE;
import static com.facebook.presto.elasticsearch.BuiltinColumns.SOURCE;
import static com.facebook.presto.elasticsearch.ElasticsearchQueryBuilder.buildSearchQuery;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;
import static java.util.function.Predicate.isEqual;
import static java.util.stream.Collectors.toList;

public class ScanQueryPageSource
        implements ConnectorPageSource
{
    private static final Logger LOG = Logger.get(ScanQueryPageSource.class);

    private final List<Decoder> decoders;

    private final SearchHitIterator iterator;
    private final BlockBuilder[] columnBuilders;
    private final List<ElasticsearchColumnHandle> columns;
    private long totalBytes;
    private long readTimeNanos;
    private long completedPositions;

    public ScanQueryPageSource(
            ElasticsearchClient client,
            ConnectorSession session,
            ElasticsearchTableHandle table,
            ElasticsearchSplit split,
            List<ElasticsearchColumnHandle> columns)
    {
        requireNonNull(client, "client is null");
        requireNonNull(columns, "columns is null");

        this.columns = ImmutableList.copyOf(columns);

        decoders = createDecoders(session, columns);

        // When the _source field is requested, we need to bypass column pruning when fetching the document
        boolean needAllFields = columns.stream()
                .map(ElasticsearchColumnHandle::getName)
                .anyMatch(isEqual(SOURCE.getName()));

        // Columns to fetch as doc_fields instead of pulling them out of the JSON source
        // This is convenient for types such as DATE, TIMESTAMP, etc, which have multiple possible
        // representations in JSON, but a single normalized representation as doc_field.
        List<String> documentFields = flattenFields(columns).entrySet().stream()
                .filter(entry -> entry.getValue().equals(TIMESTAMP))
                .map(Map.Entry::getKey)
                .collect(toImmutableList());

        columnBuilders = columns.stream()
                .map(ElasticsearchColumnHandle::getType)
                .map(type -> type.createBlockBuilder(null, 1))
                .toArray(BlockBuilder[]::new);

        List<String> requiredFields = columns.stream()
                .map(ElasticsearchColumnHandle::getName)
                .filter(name -> !BuiltinColumns.NAMES.contains(name))
                .collect(toList());

        // sorting by _doc (index order) get special treatment in Elasticsearch and is more efficient
        Optional<String> sort = Optional.of("_doc");

        if (table.getQuery().isPresent()) {
            // However, if we're using a custom Elasticsearch query, use default sorting.
            // Documents will be scored and returned based on relevance
            sort = Optional.empty();
        }

        long start = System.nanoTime();
        SearchResponse searchResponse = client.beginSearch(
                split.getIndex(),
                split.getShard(),
                buildSearchQuery(session, split.getTupleDomain().transform(ElasticsearchColumnHandle.class::cast), table.getQuery()),
                needAllFields ? Optional.empty() : Optional.of(requiredFields),
                documentFields,
                sort);
        readTimeNanos += System.nanoTime() - start;
        this.iterator = new SearchHitIterator(client, () -> searchResponse);
    }

    @Override
    public long getCompletedBytes()
    {
        return totalBytes;
    }

    @Override
    public long getCompletedPositions()
    {
        return completedPositions;
    }

    @Override
    public long getReadTimeNanos()
    {
        return readTimeNanos + iterator.getReadTimeNanos();
    }

    @Override
    public boolean isFinished()
    {
        return !iterator.hasNext();
    }

    @Override
    public long getSystemMemoryUsage()
    {
        return 0;
    }

    @Override
    public void close()
    {
        iterator.close();
    }

    @Override
    public Page getNextPage()
    {
        long size = 0;
        while (size < PageBuilderStatus.DEFAULT_MAX_PAGE_SIZE_IN_BYTES && iterator.hasNext()) {
            SearchHit hit = iterator.next();
            Map<String, Object> document = hit.getSourceAsMap();

            for (int i = 0; i < decoders.size(); i++) {
                String field = columns.get(i).getName();
                decoders.get(i).decode(hit, () -> getField(document, field), columnBuilders[i]);
            }

            if (hit.getSourceRef() != null) {
                totalBytes += hit.getSourceRef().length();
            }

            completedPositions += 1;
            size = Arrays.stream(columnBuilders)
                    .mapToLong(BlockBuilder::getSizeInBytes)
                    .sum();
        }

        Block[] blocks = new Block[columnBuilders.length];
        for (int i = 0; i < columnBuilders.length; i++) {
            blocks[i] = columnBuilders[i].build();
            columnBuilders[i] = columnBuilders[i].newBlockBuilderLike(null);
        }

        return new Page(blocks);
    }

    public static Object getField(Map<String, Object> document, String field)
    {
        Object value = document.get(field);
        if (value == null) {
            Map<String, Object> result = new HashMap<>();
            String prefix = field + ".";
            for (Map.Entry<String, Object> entry : document.entrySet()) {
                String key = entry.getKey();
                if (key.startsWith(prefix)) {
                    result.put(key.substring(prefix.length()), entry.getValue());
                }
            }

            if (!result.isEmpty()) {
                return result;
            }
        }

        return value;
    }

    private Map<String, Type> flattenFields(List<ElasticsearchColumnHandle> columns)
    {
        Map<String, Type> result = new HashMap<>();

        for (ElasticsearchColumnHandle column : columns) {
            flattenFields(result, column.getName(), column.getType());
        }

        return result;
    }

    private void flattenFields(Map<String, Type> result, String fieldName, Type type)
    {
        if (type instanceof RowType) {
            for (RowType.Field field : ((RowType) type).getFields()) {
                flattenFields(result, appendPath(fieldName, field.getName().get()), field.getType());
            }
        }
        else {
            result.put(fieldName, type);
        }
    }

    private List<Decoder> createDecoders(ConnectorSession session, List<ElasticsearchColumnHandle> columns)
    {
        return columns.stream()
                .map(column -> {
                    if (column.getName().equals(ID.getName())) {
                        return new IdColumnDecoder();
                    }

                    if (column.getName().equals(SCORE.getName())) {
                        return new ScoreColumnDecoder();
                    }

                    if (column.getName().equals(SOURCE.getName())) {
                        return new SourceColumnDecoder();
                    }

                    return createDecoder(session, column.getName(), column.getType());
                })
                .collect(toImmutableList());
    }

    private Decoder createDecoder(ConnectorSession session, String path, Type type)
    {
        if (type.equals(VARCHAR)) {
            return new VarcharDecoder(path);
        }
        else if (type.equals(VARBINARY)) {
            return new VarbinaryDecoder(path);
        }
        else if (type.equals(TIMESTAMP)) {
            return new TimestampDecoder(session, path);
        }
        else if (type.equals(BOOLEAN)) {
            return new BooleanDecoder(path);
        }
        else if (type.equals(DOUBLE)) {
            return new DoubleDecoder(path);
        }
        else if (type.equals(REAL)) {
            return new RealDecoder(path);
        }
        else if (type.equals(TINYINT)) {
            return new TinyintDecoder(path);
        }
        else if (type.equals(SMALLINT)) {
            return new SmallintDecoder(path);
        }
        else if (type.equals(INTEGER)) {
            return new IntegerDecoder(path);
        }
        else if (type.equals(BIGINT)) {
            return new BigintDecoder(path);
        }
        else if (type.getTypeSignature().getBase().equals(StandardTypes.IPADDRESS)) {
            return new IpAddressDecoder(path, type);
        }
        else if (type instanceof RowType) {
            RowType rowType = (RowType) type;

            List<Decoder> decoders = rowType.getFields().stream()
                    .map(field -> createDecoder(session, appendPath(path, field.getName().get()), field.getType()))
                    .collect(toImmutableList());

            List<String> fieldNames = rowType.getFields().stream()
                    .map(RowType.Field::getName)
                    .map(Optional::get)
                    .collect(toImmutableList());

            return new RowDecoder(path, fieldNames, decoders);
        }
        else if (type instanceof ArrayType) {
            Type elementType = ((ArrayType) type).getElementType();

            return new ArrayDecoder(path, createDecoder(session, path, elementType));
        }

        throw new UnsupportedOperationException("Type not supported: " + type);
    }

    private static String appendPath(String base, String element)
    {
        if (base.isEmpty()) {
            return element;
        }

        return base + "." + element;
    }

    private static class SearchHitIterator
            extends AbstractIterator<SearchHit>
    {
        private final ElasticsearchClient client;
        private final Supplier<SearchResponse> first;

        private SearchHits searchHits;
        private String scrollId;
        private int currentPosition;

        private long readTimeNanos;

        public SearchHitIterator(ElasticsearchClient client, Supplier<SearchResponse> first)
        {
            this.client = client;
            this.first = first;
        }

        public long getReadTimeNanos()
        {
            return readTimeNanos;
        }

        @Override
        protected SearchHit computeNext()
        {
            if (scrollId == null) {
                long start = System.nanoTime();
                SearchResponse response = first.get();
                readTimeNanos += System.nanoTime() - start;
                reset(response);
            }
            else if (currentPosition == searchHits.getHits().length) {
                long start = System.nanoTime();
                SearchResponse response = client.nextPage(scrollId);
                readTimeNanos += System.nanoTime() - start;
                reset(response);
            }

            if (currentPosition == searchHits.getHits().length) {
                return endOfData();
            }

            SearchHit hit = searchHits.getAt(currentPosition);
            currentPosition++;

            return hit;
        }

        private void reset(SearchResponse response)
        {
            scrollId = response.getScrollId();
            searchHits = response.getHits();
            currentPosition = 0;
        }

        public void close()
        {
            if (scrollId != null) {
                try {
                    client.clearScroll(scrollId);
                }
                catch (Exception e) {
                    // ignore
                    LOG.debug("Error clearing scroll", e);
                }
            }
        }
    }
}