
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc;

import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.orc.StripeReader.StripeId;
import com.facebook.presto.orc.StripeReader.StripeStreamId;
import com.facebook.presto.orc.metadata.MetadataReader;
import com.facebook.presto.orc.metadata.PostScript.HiveWriterVersion;
import com.facebook.presto.orc.metadata.RowGroupIndex;
import com.facebook.presto.orc.metadata.Stream.StreamKind;
import com.facebook.presto.orc.metadata.statistics.HiveBloomFilter;
import com.facebook.presto.orc.stream.OrcInputStream;
import com.google.common.cache.Cache;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.UncheckedExecutionException;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;

import static com.facebook.presto.common.RuntimeUnit.BYTE;
import static com.facebook.presto.common.RuntimeUnit.NONE;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.BLOOM_FILTER;
import static com.facebook.presto.orc.metadata.Stream.StreamKind.ROW_INDEX;
import static com.google.common.base.Throwables.throwIfInstanceOf;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

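/**
 * A {@link StripeMetadataSource} that wraps a delegate and caches stripe footer bytes, bloom filter
 * and row index stream bytes, and (optionally) decoded row group indices. Each cached entry carries
 * the file modification time it was read at; an entry whose modification time no longer matches is
 * invalidated and re-read from the delegate, so stale data is never served.
 *
 * <p>Illustrative construction sketch; the cache sizes and the {@code StorageStripeMetadataSource}
 * delegate below are assumptions for the example, not a prescribed configuration:
 * <pre>{@code
 * Cache<StripeId, CacheableSlice> footerCache = CacheBuilder.newBuilder()
 *         .maximumSize(1_000)
 *         .recordStats()
 *         .build();
 * Cache<StripeStreamId, CacheableSlice> streamCache = CacheBuilder.newBuilder()
 *         .maximumSize(10_000)
 *         .recordStats()
 *         .build();
 * Cache<StripeStreamId, CacheableRowGroupIndices> rowGroupIndexCache = CacheBuilder.newBuilder()
 *         .maximumSize(10_000)
 *         .recordStats()
 *         .build();
 * StripeMetadataSource source = new CachingStripeMetadataSource(
 *         new StorageStripeMetadataSource(), footerCache, streamCache, Optional.of(rowGroupIndexCache));
 * }</pre>
 */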
public class CachingStripeMetadataSource
        implements StripeMetadataSource
{
    private final StripeMetadataSource delegate;
    private final Cache<StripeId, CacheableSlice> footerSliceCache;
    private final Cache<StripeStreamId, CacheableSlice> stripeStreamCache;
    private final Optional<Cache<StripeStreamId, CacheableRowGroupIndices>> rowGroupIndexCache;

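    /**
     * @param delegate the underlying source consulted on cache misses and for non-cacheable reads
     * @param footerSliceCache cache of raw stripe footer bytes, keyed by stripe
     * @param stripeStreamCache cache of bloom filter and row index stream bytes, keyed by stripe and stream
     * @param rowGroupIndexCache optional cache of decoded row group indices; {@link Optional#empty()} disables row group index caching
     */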
    public CachingStripeMetadataSource(StripeMetadataSource delegate, Cache<StripeId, CacheableSlice> footerSliceCache, Cache<StripeStreamId, CacheableSlice> stripeStreamCache, Optional<Cache<StripeStreamId, CacheableRowGroupIndices>> rowGroupIndexCache)
    {
        this.delegate = requireNonNull(delegate, "delegate is null");
        this.footerSliceCache = requireNonNull(footerSliceCache, "footerSliceCache is null");
        this.stripeStreamCache = requireNonNull(stripeStreamCache, "stripeStreamCache is null");
        this.rowGroupIndexCache = requireNonNull(rowGroupIndexCache, "rowGroupIndexCache is null");
    }

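    /**
     * Returns the stripe footer slice, serving it from {@code footerSliceCache} when {@code cacheable}
     * is true and the cached entry's file modification time matches; otherwise the footer is read
     * through the delegate and, if cacheable, stored in the cache.
     */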
    @Override
    public Slice getStripeFooterSlice(OrcDataSource orcDataSource, StripeId stripeId, long footerOffset, int footerLength, boolean cacheable, long fileModificationTime)
            throws IOException
    {
        if (!cacheable) {
            return delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable, fileModificationTime);
        }
        try {
            CacheableSlice cacheableSlice = footerSliceCache.getIfPresent(stripeId);
            if (cacheableSlice != null) {
                if (cacheableSlice.getFileModificationTime() == fileModificationTime) {
                    return cacheableSlice.getSlice();
                }
                footerSliceCache.invalidate(stripeId);
                // This get call is to increment the miss count for invalidated entries so the stats are recorded correctly.
                footerSliceCache.getIfPresent(stripeId);
            }
            cacheableSlice = new CacheableSlice(delegate.getStripeFooterSlice(orcDataSource, stripeId, footerOffset, footerLength, cacheable, fileModificationTime), fileModificationTime);
            footerSliceCache.put(stripeId, cacheableSlice);
            return cacheableSlice.getSlice();
        }
        catch (UncheckedExecutionException e) {
            throwIfInstanceOf(e.getCause(), IOException.class);
            throw new IOException("Unexpected error in stripe footer reading after footerSliceCache miss", e.getCause());
        }
    }

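    /**
     * Returns an input per requested stream. When {@code cacheable} is true, bloom filter and row index
     * streams are served from {@code stripeStreamCache} (stale entries are invalidated); all remaining
     * ranges are read through the delegate in a single call, and the newly read cacheable streams are
     * added to the cache.
     */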
    @Override
    public Map<StreamId, OrcDataSourceInput> getInputs(OrcDataSource orcDataSource, StripeId stripeId, Map<StreamId, DiskRange> diskRanges, boolean cacheable, long fileModificationTime)
            throws IOException
    {
        if (!cacheable) {
            return delegate.getInputs(orcDataSource, stripeId, diskRanges, cacheable, fileModificationTime);
        }

        // Serve cacheable streams from the cache when fresh; collect everything else as uncached disk ranges
        ImmutableMap.Builder<StreamId, OrcDataSourceInput> inputsBuilder = ImmutableMap.builder();
        ImmutableMap.Builder<StreamId, DiskRange> uncachedDiskRangesBuilder = ImmutableMap.builder();
        for (Entry<StreamId, DiskRange> entry : diskRanges.entrySet()) {
            StripeStreamId stripeStreamId = new StripeStreamId(stripeId, entry.getKey());
            if (isCachedStream(entry.getKey().getStreamKind())) {
                CacheableSlice streamSlice = stripeStreamCache.getIfPresent(stripeStreamId);
                if (streamSlice != null && streamSlice.getFileModificationTime() == fileModificationTime) {
                    inputsBuilder.put(entry.getKey(), new OrcDataSourceInput(new BasicSliceInput(streamSlice.getSlice()), streamSlice.getSlice().length()));
                }
                else {
                    if (streamSlice != null) {
                        stripeStreamCache.invalidate(stripeStreamId);
                        // This get call is to increment the miss count for invalidated entries so the stats are recorded correctly.
                        stripeStreamCache.getIfPresent(stripeStreamId);
                    }
                    uncachedDiskRangesBuilder.put(entry);
                }
            }
            else {
                uncachedDiskRangesBuilder.put(entry);
            }
        }

        // Read the uncached ranges through the delegate and update the cache
        Map<StreamId, OrcDataSourceInput> uncachedInputs = delegate.getInputs(orcDataSource, stripeId, uncachedDiskRangesBuilder.build(), cacheable, fileModificationTime);
        for (Entry<StreamId, OrcDataSourceInput> entry : uncachedInputs.entrySet()) {
            if (isCachedStream(entry.getKey().getStreamKind())) {
                // Reading the slice consumes the delegate's input, so copy the bytes and serve a fresh input over the copy.
                Slice streamSlice = Slices.wrappedBuffer(entry.getValue().getInput().readSlice(toIntExact(entry.getValue().getInput().length())).getBytes());
                stripeStreamCache.put(new StripeStreamId(stripeId, entry.getKey()), new CacheableSlice(streamSlice, fileModificationTime));
                inputsBuilder.put(entry.getKey(), new OrcDataSourceInput(new BasicSliceInput(streamSlice), toIntExact(streamSlice.getRetainedSize())));
            }
            else {
                inputsBuilder.put(entry.getKey(), entry.getValue());
            }
        }
        return inputsBuilder.build();
    }

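    /**
     * Decodes the row group indexes for a stream. When a row group index cache is configured, fresh
     * entries are served from it and misses are decoded through the delegate and then cached. Cache
     * hits, misses, and the in-memory or storage bytes read are reported through {@code runtimeStats}.
     */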
    @Override
    public List<RowGroupIndex> getRowIndexes(
            MetadataReader metadataReader,
            HiveWriterVersion hiveWriterVersion,
            StripeId stripeId,
            StreamId streamId,
            OrcInputStream inputStream,
            List<HiveBloomFilter> bloomFilters,
            RuntimeStats runtimeStats,
            long fileModificationTime)
            throws IOException
    {
        StripeStreamId stripeStreamId = new StripeStreamId(stripeId, streamId);
        if (rowGroupIndexCache.isPresent()) {
            CacheableRowGroupIndices cacheableRowGroupIndices = rowGroupIndexCache.get().getIfPresent(stripeStreamId);
            if (cacheableRowGroupIndices != null && cacheableRowGroupIndices.getFileModificationTime() == fileModificationTime) {
                runtimeStats.addMetricValue("OrcRowGroupIndexCacheHit", NONE, 1);
                runtimeStats.addMetricValue("OrcRowGroupIndexInMemoryBytesRead", BYTE, cacheableRowGroupIndices.getRowGroupIndices().stream().mapToLong(RowGroupIndex::getRetainedSizeInBytes).sum());
                return cacheableRowGroupIndices.getRowGroupIndices();
            }
            else {
                if (cacheableRowGroupIndices != null) {
                    rowGroupIndexCache.get().invalidate(stripeStreamId);
                    // This get call is to increment the miss count for invalidated entries so the stats are recorded correctly.
                    rowGroupIndexCache.get().getIfPresent(stripeStreamId);
                }
                runtimeStats.addMetricValue("OrcRowGroupIndexCacheHit", NONE, 0);
                runtimeStats.addMetricValue("OrcRowGroupIndexStorageBytesRead", BYTE, inputStream.getRetainedSizeInBytes());
            }
        }
        List<RowGroupIndex> rowGroupIndices = delegate.getRowIndexes(metadataReader, hiveWriterVersion, stripeId, streamId, inputStream, bloomFilters, runtimeStats, fileModificationTime);
        if (rowGroupIndexCache.isPresent()) {
            rowGroupIndexCache.get().put(stripeStreamId, new CacheableRowGroupIndices(rowGroupIndices, fileModificationTime));
        }
        return rowGroupIndices;
    }

    private static boolean isCachedStream(StreamKind streamKind)
    {
        // BLOOM_FILTER and ROW_INDEX are on the critical path to generate a stripe. Other stream kinds could be lazily read.
        return streamKind == BLOOM_FILTER || streamKind == ROW_INDEX;
    }
}