OrcInputStream.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc.stream;

import com.facebook.presto.orc.DwrfDataEncryptor;
import com.facebook.presto.orc.OrcAggregatedMemoryContext;
import com.facebook.presto.orc.OrcCorruptionException;
import com.facebook.presto.orc.OrcDataSourceId;
import com.facebook.presto.orc.OrcDecompressor;
import com.facebook.presto.orc.OrcLocalMemoryContext;
import com.facebook.presto.orc.metadata.OrcType.OrcTypeKind;
import io.airlift.slice.ByteArrays;
import io.airlift.slice.FixedLengthSliceInput;
import org.openjdk.jol.info.ClassLayout;

import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Optional;

import static com.facebook.presto.orc.checkpoint.InputStreamCheckpoint.createInputStreamCheckpoint;
import static com.facebook.presto.orc.checkpoint.InputStreamCheckpoint.decodeCompressedBlockOffset;
import static com.facebook.presto.orc.checkpoint.InputStreamCheckpoint.decodeDecompressedOffset;
import static com.facebook.presto.orc.stream.LongDecode.zigzagDecode;
import static com.google.common.base.MoreObjects.toStringHelper;
import static io.airlift.slice.SizeOf.SIZE_OF_DOUBLE;
import static io.airlift.slice.SizeOf.SIZE_OF_FLOAT;
import static io.airlift.slice.SizeOf.SIZE_OF_INT;
import static io.airlift.slice.SizeOf.SIZE_OF_LONG;
import static io.airlift.slice.SizeOf.SIZE_OF_SHORT;
import static io.airlift.slice.SizeOf.sizeOf;
import static io.airlift.slice.Slices.EMPTY_SLICE;
import static java.lang.Math.round;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

public final class OrcInputStream
        extends InputStream
{
    private static final int INSTANCE_SIZE = ClassLayout.parseClass(OrcInputStream.class).instanceSize();

    private static final long VARINT_MASK = 0x8080_8080_8080_8080L;
    private static final int MAX_VARINT_LENGTH = 10;
    private static final double BUFFER_ALLOWED_MEMORY_WASTE_RATIO = 1.5;

    private final OrcDataSourceId orcDataSourceId;
    private final SharedBuffer sharedDecompressionBuffer;
    private final FixedLengthSliceInput compressedSliceInput;
    private final long compressedSliceInputRetainedSizeInBytes;
    private final Optional<OrcDecompressor> decompressor;
    private final Optional<DwrfDataEncryptor> dwrfDecryptor;
    private final OrcLocalMemoryContext memoryUsage;
    // Temporary memory for reading a float or double at buffer boundary.
    private final byte[] temporaryBuffer = new byte[SIZE_OF_DOUBLE];

    private int currentCompressedBlockOffset;

    private byte[] buffer;
    private int position;
    private int length;
    private int uncompressedOffset;

    public OrcInputStream(
            OrcDataSourceId orcDataSourceId,
            SharedBuffer sharedDecompressionBuffer,
            FixedLengthSliceInput sliceInput,
            Optional<OrcDecompressor> decompressor,
            Optional<DwrfDataEncryptor> dwrfDecryptor,
            OrcAggregatedMemoryContext systemMemoryContext,
            long sliceInputRetainedSizeInBytes)
    {
        this.orcDataSourceId = requireNonNull(orcDataSourceId, "orcDataSource is null");
        this.sharedDecompressionBuffer = requireNonNull(sharedDecompressionBuffer, "sharedDecompressionBuffer is null");

        requireNonNull(sliceInput, "sliceInput is null");

        this.decompressor = requireNonNull(decompressor, "decompressor is null");
        this.dwrfDecryptor = requireNonNull(dwrfDecryptor, "dwrfDecryptor is null");

        // memory reserved in the systemMemoryContext is never release and instead it is
        // expected that the context itself will be destroyed at the end of the read
        requireNonNull(systemMemoryContext, "systemMemoryContext is null");
        this.memoryUsage = systemMemoryContext.newOrcLocalMemoryContext(OrcInputStream.class.getSimpleName());

        if (!decompressor.isPresent() && !dwrfDecryptor.isPresent()) {
            // for unencrypted uncompressed input read the entire input and discard the original sliceInput
            int sliceInputPosition = toIntExact(sliceInput.position());
            int sliceInputRemaining = toIntExact(sliceInput.remaining());
            this.buffer = new byte[sliceInputRemaining];
            this.length = buffer.length;
            sliceInput.readFully(buffer, sliceInputPosition, sliceInputRemaining);
            this.compressedSliceInput = EMPTY_SLICE.getInput();
            this.compressedSliceInputRetainedSizeInBytes = compressedSliceInput.getRetainedSize();
        }
        else {
            this.compressedSliceInput = sliceInput;
            this.buffer = new byte[0];
            this.compressedSliceInputRetainedSizeInBytes = sliceInputRetainedSizeInBytes;
        }

        memoryUsage.setBytes(getRetainedSizeInBytes());
    }

    @Override
    public void close()
    {
        // close is never called, so do not add code here
    }

    @Override
    public int available()
    {
        if (buffer == null) {
            return 0;
        }
        return length - position;
    }

    @Override
    public int read()
            throws IOException
    {
        if (buffer == null) {
            return -1;
        }
        if (available() > 0) {
            return 0xff & buffer[position++];
        }

        advance();
        return read();
    }

    @Override
    public int read(byte[] b, int off, int length)
            throws IOException
    {
        if (buffer == null) {
            return -1;
        }

        if (available() == 0) {
            advance();
            if (buffer == null) {
                return -1;
            }
        }
        length = Math.min(length, available());
        System.arraycopy(buffer, position, b, off, length);
        position += length;
        return length;
    }

    public void skipFully(long length)
            throws IOException
    {
        while (length > 0) {
            long result = skip(length);
            if (result < 0) {
                throw new OrcCorruptionException(orcDataSourceId, "Unexpected end of stream");
            }
            length -= result;
        }
    }

    public void readFully(byte[] buffer, int offset, int length)
            throws IOException
    {
        while (offset < length) {
            int result = read(buffer, offset, length - offset);
            if (result < 0) {
                throw new OrcCorruptionException(orcDataSourceId, "Unexpected end of stream");
            }
            offset += result;
        }
    }

    public OrcDataSourceId getOrcDataSourceId()
    {
        return orcDataSourceId;
    }

    public long getCheckpoint()
    {
        // if the decompressed buffer is empty, return a checkpoint starting at the next block
        if (buffer == null || (position == 0 && available() == 0)) {
            return createInputStreamCheckpoint(toIntExact(compressedSliceInput.position()), 0);
        }
        // otherwise return a checkpoint at the last compressed block read and the current position in the buffer
        // If we have uncompressed data uncompressedOffset is not included in the offset.
        return createInputStreamCheckpoint(currentCompressedBlockOffset, toIntExact(position - uncompressedOffset));
    }

    public boolean seekToCheckpoint(long checkpoint)
            throws IOException
    {
        int compressedBlockOffset = decodeCompressedBlockOffset(checkpoint);
        int decompressedOffset = decodeDecompressedOffset(checkpoint);
        boolean discardedBuffer;
        if (compressedBlockOffset != currentCompressedBlockOffset) {
            if (!decompressor.isPresent() && !dwrfDecryptor.isPresent()) {
                throw new OrcCorruptionException(orcDataSourceId, "Reset stream has a block offset but stream is not compressed or encrypted");
            }
            compressedSliceInput.setPosition(compressedBlockOffset);
            buffer = new byte[0];
            memoryUsage.setBytes(getRetainedSizeInBytes());
            position = 0;
            length = 0;
            uncompressedOffset = 0;
            discardedBuffer = true;
        }
        else {
            discardedBuffer = false;
        }

        if (decompressedOffset != position - uncompressedOffset) {
            position = uncompressedOffset;
            if (available() < decompressedOffset) {
                decompressedOffset -= available();
                advance();
            }
            position += decompressedOffset;
        }
        else if (length == 0) {
            advance();
            position += decompressedOffset;
        }
        return discardedBuffer;
    }

    @Override
    public long skip(long n)
            throws IOException
    {
        if (buffer == null || n <= 0) {
            return -1;
        }

        long result = Math.min(available(), n);
        position += toIntExact(result);
        if (result != 0) {
            return result;
        }
        if (read() == -1) {
            return 0;
        }
        result = Math.min(available(), n - 1);
        position += toIntExact(result);
        return 1 + result;
    }

    public long readDwrfLong(OrcTypeKind type)
            throws IOException
    {
        switch (type) {
            case SHORT:
                return read() | (read() << 8);
            case INT:
                return read() | (read() << 8) | (read() << 16) | (read() << 24);
            case LONG:
                return ((long) read()) |
                        (((long) read()) << 8) |
                        (((long) read()) << 16) |
                        (((long) read()) << 24) |
                        (((long) read()) << 32) |
                        (((long) read()) << 40) |
                        (((long) read()) << 48) |
                        (((long) read()) << 56);
            default:
                throw new IllegalStateException();
        }
    }

    public void skipDwrfLong(OrcTypeKind type, long items)
            throws IOException
    {
        if (items == 0) {
            return;
        }
        long bytes = items;
        switch (type) {
            case SHORT:
                bytes *= SIZE_OF_SHORT;
                break;
            case INT:
                bytes *= SIZE_OF_INT;
                break;
            case LONG:
                bytes *= SIZE_OF_LONG;
                break;
            default:
                throw new IllegalStateException();
        }
        skip(bytes);
    }

    public long readVarint(boolean signed)
            throws IOException
    {
        long result = 0;
        int shift = 0;
        int available = available();
        if (available >= 2 * Long.BYTES) {
            long word = ByteArrays.getLong(buffer, position);
            int count = 1;
            boolean atEnd = false;
            result = word & 0x7f;
            if ((word & 0x80) != 0) {
                long control = word >>> 8;
                long mask = 0x7f << 7;
                while (true) {
                    word = word >>> 1;
                    result |= word & mask;
                    count++;
                    if ((control & 0x80) == 0) {
                        atEnd = true;
                        break;
                    }
                    if (mask == 0x7fL << (7 * 7)) {
                        break;
                    }
                    mask = mask << 7;
                    control = control >>> 8;
                }
                if (!atEnd) {
                    word = ByteArrays.getLong(buffer, position + 8);
                    result |= (word & 0x7f) << 56;
                    if ((word & 0x80) == 0) {
                        count++;
                    }
                    else {
                        result |= 1L << 63;
                        count += 2;
                    }
                }
            }
            position += count;
        }
        else {
            do {
                if (available == 0) {
                    advance();
                    available = available();
                    if (available == 0) {
                        throw new OrcCorruptionException(orcDataSourceId, "End of stream in RLE Integer");
                    }
                }
                available--;
                result |= (long) (buffer[position] & 0x7f) << shift;
                shift += 7;
            }
            while ((buffer[position++] & 0x80) != 0);
        }
        if (signed) {
            return zigzagDecode(result);
        }
        else {
            return result;
        }
    }

    public void skipVarints(long items)
            throws IOException
    {
        if (items == 0) {
            return;
        }

        while (items > 0) {
            items -= skipVarintsInBuffer(items);
        }
    }

    private long skipVarintsInBuffer(long items)
            throws IOException
    {
        if (available() == 0) {
            advance();
            if (available() == 0) {
                throw new OrcCorruptionException(orcDataSourceId, "Unexpected end of stream");
            }
        }
        long skipped = 0;
        // If items to skip is > SIZE_OF_LONG it is safe to skip entire longs
        while (items - skipped > SIZE_OF_LONG && available() > MAX_VARINT_LENGTH) {
            long value = ByteArrays.getLong(buffer, position);
            position += SIZE_OF_LONG;
            long mask = (value & VARINT_MASK) ^ VARINT_MASK;
            skipped += Long.bitCount(mask);
        }
        while (skipped < items && available() > 0) {
            if ((buffer[position++] & 0x80) == 0) {
                skipped++;
            }
        }
        return skipped;
    }

    public double readDouble()
            throws IOException
    {
        int readPosition = ensureContiguousBytesAndAdvance(SIZE_OF_DOUBLE);
        if (readPosition < 0) {
            return ByteArrays.getDouble(temporaryBuffer, 0);
        }
        return ByteArrays.getDouble(buffer, readPosition);
    }

    public float readFloat()
            throws IOException
    {
        int readPosition = ensureContiguousBytesAndAdvance(SIZE_OF_FLOAT);
        if (readPosition < 0) {
            return ByteArrays.getFloat(temporaryBuffer, 0);
        }
        return ByteArrays.getFloat(buffer, readPosition);
    }

    private int ensureContiguousBytesAndAdvance(int bytes)
            throws IOException
    {
        // If there are numBytes in the buffer, return the offset of the start and advance by numBytes. If not, copy numBytes
        // into temporaryBuffer, advance by numBytes and return -1.
        if (available() >= bytes) {
            int startPosition = position;
            position += bytes;
            return startPosition;
        }
        readFully(temporaryBuffer, 0, bytes);
        return -1;
    }

    // This comes from the Apache Hive ORC code
    private void advance()
            throws IOException
    {
        if (compressedSliceInput == null || compressedSliceInput.remaining() == 0) {
            buffer = null;
            position = 0;
            length = 0;
            uncompressedOffset = 0;
            memoryUsage.setBytes(getRetainedSizeInBytes());
            return;
        }

        // 3 byte header
        // NOTE: this must match BLOCK_HEADER_SIZE
        currentCompressedBlockOffset = toIntExact(compressedSliceInput.position());
        int b0 = compressedSliceInput.readUnsignedByte();
        int b1 = compressedSliceInput.readUnsignedByte();
        int b2 = compressedSliceInput.readUnsignedByte();

        boolean isUncompressed = (b0 & 0x01) == 1;
        int chunkLength = (b2 << 15) | (b1 << 7) | (b0 >>> 1);
        if (chunkLength < 0 || chunkLength > compressedSliceInput.remaining()) {
            throw new OrcCorruptionException(orcDataSourceId, "The chunkLength (%s) must not be negative or greater than remaining size (%s)", chunkLength, compressedSliceInput.remaining());
        }

        if (isUncompressed) {
            buffer = ensureCapacity(buffer, chunkLength);
            length = compressedSliceInput.read(buffer, 0, chunkLength);
            if (dwrfDecryptor.isPresent()) {
                buffer = dwrfDecryptor.get().decrypt(buffer, 0, chunkLength);
                length = buffer.length;
            }
            position = 0;
        }
        else {
            sharedDecompressionBuffer.ensureCapacity(chunkLength);
            byte[] compressedBuffer = sharedDecompressionBuffer.get();
            int readCompressed = compressedSliceInput.read(compressedBuffer, 0, chunkLength);
            if (dwrfDecryptor.isPresent()) {
                compressedBuffer = dwrfDecryptor.get().decrypt(compressedBuffer, 0, chunkLength);
                readCompressed = compressedBuffer.length;
            }

            length = decompressor.get().decompress(compressedBuffer, 0, readCompressed, createDecompressorOutputBufferAdapter());
            position = 0;
        }
        uncompressedOffset = position;
        memoryUsage.setBytes(getRetainedSizeInBytes());
    }

    public long getRetainedSizeInBytes()
    {
        return INSTANCE_SIZE +
                compressedSliceInputRetainedSizeInBytes +
                sizeOf(buffer) +
                sizeOf(temporaryBuffer);
    }

    private static byte[] ensureCapacity(byte[] buffer, int capacity)
    {
        if (buffer == null || buffer.length < capacity || buffer.length > round(capacity * BUFFER_ALLOWED_MEMORY_WASTE_RATIO)) {
            return new byte[capacity];
        }

        return buffer;
    }

    @Override
    public String toString()
    {
        return toStringHelper(this)
                .add("source", orcDataSourceId)
                .add("compressedOffset", compressedSliceInput.position())
                .add("uncompressedOffset", buffer == null ? null : position)
                .add("decompressor", decompressor.map(Object::toString).orElse("none"))
                .add("decryptor", dwrfDecryptor.map(Object::toString).orElse("none"))
                .toString();
    }

    private OrcDecompressor.OutputBuffer createDecompressorOutputBufferAdapter()
    {
        return new OrcDecompressor.OutputBuffer()
        {
            @Override
            public byte[] initialize(int size)
            {
                buffer = ensureCapacity(buffer, size);
                return buffer;
            }

            @Override
            public byte[] grow(int size)
            {
                if (size > buffer.length) {
                    buffer = Arrays.copyOfRange(buffer, 0, size);
                }
                return buffer;
            }
        };
    }
}