TestingOrcPredicate.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.orc;

import com.facebook.presto.common.type.CharType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.SqlDate;
import com.facebook.presto.common.type.SqlDecimal;
import com.facebook.presto.common.type.SqlTimestamp;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarbinaryType;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.orc.OrcTester.Format;
import com.facebook.presto.orc.metadata.statistics.ColumnStatistics;
import com.facebook.presto.orc.metadata.statistics.HiveBloomFilter;
import com.google.common.collect.Iterables;
import com.google.common.collect.Ordering;
import io.airlift.slice.Slice;
import io.airlift.slice.Slices;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.IntStream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.StandardTypes.ARRAY;
import static com.facebook.presto.common.type.StandardTypes.MAP;
import static com.facebook.presto.common.type.StandardTypes.ROW;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP_MICROSECONDS;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.orc.OrcTester.Format.DWRF;
import static com.google.common.base.Predicates.equalTo;
import static com.google.common.base.Predicates.notNull;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.Iterables.filter;
import static com.google.common.collect.Lists.newArrayList;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;

public final class TestingOrcPredicate
{
    // Row count that BasicOrcPredicate.matches treats as one full stripe-sized section.
    // NOTE(review): presumably must agree with the stripe size the test writer is configured with - confirm.
    public static final int ORC_STRIPE_SIZE = 30_000;
    // Row count that BasicOrcPredicate.matches treats as one full row-group-sized section.
    // NOTE(review): presumably must agree with the row group size the test writer is configured with - confirm.
    public static final int ORC_ROW_GROUP_SIZE = 10_000;

    // Private constructor: this class only exposes static members and nested classes.
    private TestingOrcPredicate()
    {
    }

    /**
     * Builds a predicate that covers every column: one per-column predicate is created for each
     * index {@code i} from {@code types.get(i)} and {@code values.get(i)}, and the combined
     * predicate matches only when all of them match.
     *
     * @param types column types, indexed by column
     * @param values expected values per column, indexed the same way as {@code types}
     * @param format file format, forwarded to the per-column factory
     * @param isHiveWriter forwarded to the per-column factory
     * @return a predicate combining all per-column predicates
     */
    public static OrcPredicate createOrcPredicate(List<Type> types, List<List<?>> values, Format format, boolean isHiveWriter)
    {
        return new MultiOrcPredicate(
                IntStream.range(0, types.size())
                        .mapToObj(column -> createOrcPredicate(column, types.get(column), values.get(column), format, isHiveWriter))
                        .collect(toImmutableList()));
    }

    /**
     * Creates the statistics-checking predicate for a single column, selected by the column type.
     * <p>
     * Expected values are first copied into a list and, where needed, converted to the representation
     * the chosen predicate compares against (longs for integer, timestamp and date columns, doubles for
     * floating point columns). {@code format} and {@code isHiveWriter} are only used for varchar columns.
     *
     * @param columnIndex index of the column whose statistics the predicate inspects
     * @param type column type
     * @param values expected values of the column
     * @param format file format
     * @param isHiveWriter passed through to the varchar predicate
     * @return the predicate for the column
     * @throws IllegalArgumentException if {@code type} is not one of the handled types
     */
    public static OrcPredicate createOrcPredicate(int columnIndex, Type type, Iterable<?> values, Format format, boolean isHiveWriter)
    {
        List<Object> expectedValues = newArrayList(values);

        if (BOOLEAN.equals(type)) {
            return new BooleanOrcPredicate(columnIndex, expectedValues, false);
        }

        if (TINYINT.equals(type) || SMALLINT.equals(type) || INTEGER.equals(type) || BIGINT.equals(type)) {
            List<Long> longs = expectedValues.stream()
                    .map(value -> value == null ? null : ((Number) value).longValue())
                    .collect(toList());
            return new LongOrcPredicate(true, columnIndex, longs, false);
        }

        if (TIMESTAMP.equals(type)) {
            // compared as UTC milliseconds; the bloom filter is not tested for timestamps
            List<Long> millis = expectedValues.stream()
                    .map(value -> value == null ? null : ((SqlTimestamp) value).getMillisUtc())
                    .collect(toList());
            return new LongOrcPredicate(false, columnIndex, millis, false);
        }

        if (TIMESTAMP_MICROSECONDS.equals(type)) {
            // compared as UTC microseconds; the bloom filter is not tested for timestamps
            List<Long> micros = expectedValues.stream()
                    .map(value -> value == null ? null : ((SqlTimestamp) value).getMicrosUtc())
                    .collect(toList());
            return new LongOrcPredicate(false, columnIndex, micros, false);
        }

        if (DATE.equals(type)) {
            List<Long> days = expectedValues.stream()
                    .map(value -> value == null ? null : (long) ((SqlDate) value).getDays())
                    .collect(toList());
            return new DateOrcPredicate(columnIndex, days, false);
        }

        if (REAL.equals(type) || DOUBLE.equals(type)) {
            List<Double> doubles = expectedValues.stream()
                    .map(value -> value == null ? null : ((Number) value).doubleValue())
                    .collect(toList());
            return new DoubleOrcPredicate(columnIndex, doubles, false);
        }

        if (type instanceof VarbinaryType) {
            // binary does not have stats
            return new BasicOrcPredicate<>(columnIndex, expectedValues, Object.class, false);
        }
        if (type instanceof VarcharType) {
            return new StringOrcPredicate(columnIndex, expectedValues, format, isHiveWriter);
        }
        if (type instanceof CharType) {
            return new CharOrcPredicate(columnIndex, expectedValues, false);
        }
        if (type instanceof DecimalType) {
            return new DecimalOrcPredicate(columnIndex, expectedValues, false);
        }

        // structural types only get the basic (value count) check
        String baseType = type.getTypeSignature().getBase();
        boolean structural = ARRAY.equals(baseType) || MAP.equals(baseType) || ROW.equals(baseType);
        if (!structural) {
            throw new IllegalArgumentException("Unsupported type " + type);
        }
        return new BasicOrcPredicate<>(columnIndex, expectedValues, Object.class, false);
    }

    /**
     * Predicate that delegates to a list of predicates and matches only when every one of them matches.
     */
    private static class MultiOrcPredicate
            implements OrcPredicate
    {
        private final List<OrcPredicate> orcPredicates;

        public MultiOrcPredicate(List<OrcPredicate> orcPredicates)
        {
            this.orcPredicates = requireNonNull(orcPredicates, "orcPredicates is null");
        }

        @Override
        public boolean matches(long numberOfRows, Map<Integer, ColumnStatistics> statisticsByColumnIndex)
        {
            // evaluate in list order and stop at the first predicate that does not match
            for (OrcPredicate delegate : orcPredicates) {
                if (!delegate.matches(numberOfRows, statisticsByColumnIndex)) {
                    return false;
                }
            }
            return true;
        }
    }

    /**
     * Base predicate that validates the statistics reported for one column against the expected values.
     * <p>
     * {@code matches} uses the reported row count to decide which part of the expected values the
     * statistics describe (the whole file, a full row group or stripe somewhere in the file, or the tail)
     * and verifies that part with {@code chunkMatchesStats}. This class only compares the non-null value
     * count; subclasses override {@code chunkMatchesStats} to add type-specific checks.
     *
     * @param <T> element type the expected values are cast to
     */
    public static class BasicOrcPredicate<T>
            implements OrcPredicate
    {
        private final int columnIndex;
        private final List<T> expectedValues;
        // when true, statistics covering the whole file are expected to be absent
        private final boolean noFileStats;

        /**
         * @param columnIndex index used to look up the column's statistics
         * @param expectedValues expected column values; each one is cast to {@code type}
         * @param type class the expected values are cast to
         * @param noFileStats whether whole-file statistics are expected to be missing
         * @throws ClassCastException if an expected value is not an instance of {@code type}
         */
        public BasicOrcPredicate(int columnIndex, Iterable<?> expectedValues, Class<T> type, boolean noFileStats)
        {
            List<T> values = new ArrayList<>();
            for (Object expectedValue : expectedValues) {
                values.add(type.cast(expectedValue));
            }
            this.columnIndex = columnIndex;
            this.expectedValues = Collections.unmodifiableList(values);
            this.noFileStats = noFileStats;
        }

        /**
         * Verifies the statistics of this predicate's column. Mismatches are reported through test
         * assertions, so this method either returns {@code true} or throws.
         *
         * @param numberOfRows number of rows the statistics cover
         * @param statisticsByColumnIndex statistics keyed by column index
         * @return always {@code true}
         */
        @Override
        public boolean matches(long numberOfRows, Map<Integer, ColumnStatistics> statisticsByColumnIndex)
        {
            ColumnStatistics columnStatistics = statisticsByColumnIndex.get(columnIndex);

            // Handle the "no file stats" case before touching columnStatistics: in that case the
            // statistics are expected to be null, so dereferencing them first would throw a
            // NullPointerException instead of passing.
            if (noFileStats && numberOfRows == expectedValues.size()) {
                assertNull(columnStatistics);
                return true;
            }
            assertTrue(columnStatistics.hasNumberOfValues());

            if (numberOfRows == expectedValues.size()) {
                // whole file
                assertChunkStats(expectedValues, columnStatistics);
            }
            else if (numberOfRows == ORC_ROW_GROUP_SIZE) {
                // middle section (full row group)
                matchMiddleSection(columnStatistics, ORC_ROW_GROUP_SIZE);
            }
            else if (numberOfRows == ORC_STRIPE_SIZE) {
                // middle section (full stripe)
                matchMiddleSection(columnStatistics, ORC_STRIPE_SIZE);
            }
            else if (numberOfRows == expectedValues.size() % ORC_ROW_GROUP_SIZE || numberOfRows == expectedValues.size() % ORC_STRIPE_SIZE) {
                // tail section: the last numberOfRows expected values
                List<T> chunk = expectedValues.subList((int) (expectedValues.size() - numberOfRows), expectedValues.size());
                assertChunkStats(chunk, columnStatistics);
            }
            else {
                fail("Unexpected number of rows: " + numberOfRows);
            }
            return true;
        }

        /**
         * The position of a full-size section is not known, so every consecutive chunk of {@code size}
         * values (the last one may be shorter) is tried; fails the test if none matches the statistics.
         */
        private void matchMiddleSection(ColumnStatistics columnStatistics, int size)
        {
            int length;
            for (int offset = 0; offset < expectedValues.size(); offset += length) {
                length = Math.min(size, expectedValues.size() - offset);
                if (chunkMatchesStats(expectedValues.subList(offset, offset + length), columnStatistics)) {
                    return;
                }
            }
            fail("match not found for middle section");
        }

        /**
         * Fails the test unless the statistics match the given chunk.
         */
        private void assertChunkStats(List<T> chunk, ColumnStatistics columnStatistics)
        {
            assertTrue(chunkMatchesStats(chunk, columnStatistics));
        }

        /**
         * Checks whether the statistics are consistent with the given chunk of expected values.
         * The base implementation only compares the reported number of values with the number of
         * non-null entries in the chunk.
         *
         * @return {@code true} if the statistics are consistent with the chunk
         */
        protected boolean chunkMatchesStats(List<T> chunk, ColumnStatistics columnStatistics)
        {
            // verify non null count
            if (columnStatistics.getNumberOfValues() != Iterables.size(filter(chunk, notNull()))) {
                return false;
            }

            return true;
        }
    }

    /**
     * Predicate for boolean columns: on top of the basic value-count check it compares the reported
     * true-value count with the number of {@code true} entries in the chunk.
     */
    public static class BooleanOrcPredicate
            extends BasicOrcPredicate<Boolean>
    {
        public BooleanOrcPredicate(int columnIndex, Iterable<?> expectedValues, boolean noFileStats)
        {
            super(columnIndex, expectedValues, Boolean.class, noFileStats);
        }

        @Override
        protected boolean chunkMatchesStats(List<Boolean> chunk, ColumnStatistics columnStatistics)
        {
            // a boolean column must not carry statistics of any other kind
            assertNull(columnStatistics.getIntegerStatistics());
            assertNull(columnStatistics.getDoubleStatistics());
            assertNull(columnStatistics.getStringStatistics());
            assertNull(columnStatistics.getDateStatistics());

            // basic statistics first
            boolean basicStatsMatch = super.chunkMatchesStats(chunk, columnStatistics);
            if (!basicStatsMatch) {
                return false;
            }

            // statistics can be missing for any reason; nothing more to verify then
            if (columnStatistics.getBooleanStatistics() == null) {
                return true;
            }

            int expectedTrueCount = Iterables.size(filter(chunk, equalTo(Boolean.TRUE)));
            return columnStatistics.getBooleanStatistics().getTrueValueCount() == expectedTrueCount;
        }
    }

    /**
     * Predicate for floating point columns: checks the bloom filter (when present) and compares the
     * reported min/max with the chunk's min/max using a tolerance of 0.001.
     */
    public static class DoubleOrcPredicate
            extends BasicOrcPredicate<Double>
    {
        public DoubleOrcPredicate(int columnIndex, Iterable<?> expectedValues, boolean noFileStats)
        {
            super(columnIndex, expectedValues, Double.class, noFileStats);
        }

        @Override
        protected boolean chunkMatchesStats(List<Double> chunk, ColumnStatistics columnStatistics)
        {
            // a double column must not carry statistics of any other kind
            assertNull(columnStatistics.getBooleanStatistics());
            assertNull(columnStatistics.getIntegerStatistics());
            assertNull(columnStatistics.getStringStatistics());
            assertNull(columnStatistics.getDateStatistics());

            // basic statistics first
            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            // every non-null value must be reported as present by the bloom filter
            HiveBloomFilter bloomFilter = columnStatistics.getBloomFilter();
            if (bloomFilter != null) {
                boolean valueMissing = chunk.stream()
                        .filter(Objects::nonNull)
                        .anyMatch(value -> !bloomFilter.testDouble(value));
                if (valueMissing) {
                    return false;
                }
            }

            // statistics can be missing for any reason; nothing more to verify then
            if (columnStatistics.getDoubleStatistics() == null) {
                return true;
            }

            boolean allNull = chunk.stream().allMatch(Objects::isNull);
            if (allNull) {
                // no values, so neither bound may be set
                return columnStatistics.getDoubleStatistics().getMin() == null
                        && columnStatistics.getDoubleStatistics().getMax() == null;
            }

            // verify min
            Double chunkMin = Ordering.natural().nullsLast().min(chunk);
            if (Math.abs(columnStatistics.getDoubleStatistics().getMin() - chunkMin) > 0.001) {
                return false;
            }

            // verify max
            Double chunkMax = Ordering.natural().nullsFirst().max(chunk);
            if (Math.abs(columnStatistics.getDoubleStatistics().getMax() - chunkMax) > 0.001) {
                return false;
            }
            return true;
        }
    }

    /**
     * Predicate for decimal columns. It does not override {@code chunkMatchesStats}, so only the
     * checks inherited from {@code BasicOrcPredicate} are applied; the constructor casts the expected
     * values to {@code SqlDecimal}.
     */
    private static class DecimalOrcPredicate
            extends BasicOrcPredicate<SqlDecimal>
    {
        public DecimalOrcPredicate(int columnIndex, Iterable<?> expectedValues, boolean noFileStats)
        {
            super(columnIndex, expectedValues, SqlDecimal.class, noFileStats);
        }
    }

    /**
     * Predicate for columns whose expected values are compared as longs. On top of the basic
     * value-count check it verifies min, max and sum of the integer statistics (when present) and,
     * if enabled, that the bloom filter reports every non-null value as present.
     */
    public static class LongOrcPredicate
            extends BasicOrcPredicate<Long>
    {
        // when false, the bloom filter is never consulted
        private final boolean testBloomFilter;

        /**
         * @param testBloomFilter whether a present bloom filter should be verified
         * @param columnIndex index used to look up the column's statistics
         * @param expectedValues expected values; each must be a {@code Long} or null
         * @param noFileStats whether whole-file statistics are expected to be missing
         */
        public LongOrcPredicate(boolean testBloomFilter, int columnIndex, Iterable<?> expectedValues, boolean noFileStats)
        {
            super(columnIndex, expectedValues, Long.class, noFileStats);
            this.testBloomFilter = testBloomFilter;
        }

        @Override
        protected boolean chunkMatchesStats(List<Long> chunk, ColumnStatistics columnStatistics)
        {
            // a long column must not carry statistics of any other kind
            assertNull(columnStatistics.getBooleanStatistics());
            assertNull(columnStatistics.getDoubleStatistics());
            assertNull(columnStatistics.getStringStatistics());
            assertNull(columnStatistics.getDateStatistics());

            // check basic statistics
            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            // statistics can be missing for any reason
            if (columnStatistics.getIntegerStatistics() != null) {
                if (chunk.stream().allMatch(Objects::isNull)) {
                    // no values, so neither bound may be set
                    if (columnStatistics.getIntegerStatistics().getMin() != null || columnStatistics.getIntegerStatistics().getMax() != null) {
                        return false;
                    }
                }
                else {
                    // verify min
                    if (!columnStatistics.getIntegerStatistics().getMin().equals(Ordering.natural().nullsLast().min(chunk))) {
                        return false;
                    }

                    // verify max
                    if (!columnStatistics.getIntegerStatistics().getMax().equals(Ordering.natural().nullsFirst().max(chunk))) {
                        return false;
                    }
                }

                // verify sum; a missing sum is not checked
                long sum = chunk.stream()
                        .filter(Objects::nonNull)
                        .mapToLong(Long::longValue)
                        .sum();
                if (columnStatistics.getIntegerStatistics().getSum() != null && columnStatistics.getIntegerStatistics().getSum() != sum) {
                    return false;
                }
            }

            // The bloom filter does not depend on the integer statistics, so it is verified even
            // when those are missing (same as the double and string predicates in this file).
            HiveBloomFilter bloomFilter = columnStatistics.getBloomFilter();
            if (testBloomFilter && bloomFilter != null) {
                for (Long value : chunk) {
                    if (value != null && !bloomFilter.testLong(value)) {
                        return false;
                    }
                }
            }

            return true;
        }
    }

    /**
     * Predicate for varchar columns: checks the bloom filter (membership of every value plus a
     * false-positive rate probe) and the min/max string statistics. Whole-file statistics are always
     * expected to be present for this predicate.
     */
    public static class StringOrcPredicate
            extends BasicOrcPredicate<String>
    {
        private final Format format;
        private final boolean isHiveWriter;

        public StringOrcPredicate(int columnIndex, Iterable<?> expectedValues, Format format, boolean isHiveWriter)
        {
            super(columnIndex, expectedValues, String.class, false);
            this.format = format;
            this.isHiveWriter = isHiveWriter;
        }

        @Override
        protected boolean chunkMatchesStats(List<String> chunk, ColumnStatistics columnStatistics)
        {
            // a string column must not carry statistics of any other kind
            assertNull(columnStatistics.getBooleanStatistics());
            assertNull(columnStatistics.getIntegerStatistics());
            assertNull(columnStatistics.getDoubleStatistics());
            assertNull(columnStatistics.getDateStatistics());

            // basic statistics first
            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            // non-null values as UTF-8 slices, in chunk order
            List<Slice> slices = new ArrayList<>();
            for (String value : chunk) {
                if (value != null) {
                    slices.add(Slices.utf8Slice(value));
                }
            }

            HiveBloomFilter bloomFilter = columnStatistics.getBloomFilter();
            if (bloomFilter != null) {
                // every value must be reported as present
                for (Slice slice : slices) {
                    if (!bloomFilter.test(slice.getBytes())) {
                        return false;
                    }
                }

                // probe with random 32-byte keys and reject a filter whose hit rate exceeds 0.55
                int falsePositive = 0;
                byte[] probe = new byte[32];
                for (int attempt = 0; attempt < 100_000; attempt++) {
                    ThreadLocalRandom.current().nextBytes(probe);
                    if (bloomFilter.test(probe)) {
                        falsePositive++;
                    }
                }
                if (falsePositive != 0 && 1.0 * falsePositive / 100_000 > 0.55) {
                    return false;
                }
            }

            // statistics can be missing for any reason; nothing more to verify then
            if (columnStatistics.getStringStatistics() == null) {
                return true;
            }

            if (slices.isEmpty()) {
                // no values, so neither bound may be set
                return columnStatistics.getStringStatistics().getMin() == null
                        && columnStatistics.getStringStatistics().getMax() == null;
            }

            Slice chunkMin = Ordering.natural().nullsLast().min(slices);
            Slice chunkMax = Ordering.natural().nullsFirst().max(slices);
            if (format == DWRF && isHiveWriter) {
                // The tests use the OLD open source DWRF writer, which uses UTF-16be for string stats.
                // Our reader widens these, so the bounds only need to enclose the chunk's min and max.
                return columnStatistics.getStringStatistics().getMin().compareTo(chunkMin) <= 0
                        && columnStatistics.getStringStatistics().getMax().compareTo(chunkMax) >= 0;
            }

            // otherwise the bounds must be exact
            return columnStatistics.getStringStatistics().getMin().equals(chunkMin)
                    && columnStatistics.getStringStatistics().getMax().equals(chunkMax);
        }
    }

    /**
     * Predicate for char columns: values and the reported min/max are trimmed before comparison, and
     * the reported bounds only need to enclose the chunk's min and max.
     */
    public static class CharOrcPredicate
            extends BasicOrcPredicate<String>
    {
        public CharOrcPredicate(int columnIndex, Iterable<?> expectedValues, boolean noFileStats)
        {
            super(columnIndex, expectedValues, String.class, noFileStats);
        }

        @Override
        protected boolean chunkMatchesStats(List<String> chunk, ColumnStatistics columnStatistics)
        {
            // a char column must not carry statistics of any other kind
            assertNull(columnStatistics.getBooleanStatistics());
            assertNull(columnStatistics.getIntegerStatistics());
            assertNull(columnStatistics.getDoubleStatistics());
            assertNull(columnStatistics.getDateStatistics());

            // basic statistics first
            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            // non-null values with surrounding whitespace removed, in chunk order
            List<String> strings = new ArrayList<>();
            for (String value : chunk) {
                if (value != null) {
                    strings.add(value.trim());
                }
            }

            // statistics can be missing for any reason; nothing more to verify then
            if (columnStatistics.getStringStatistics() == null) {
                return true;
            }

            if (strings.isEmpty()) {
                // no values, so neither bound may be set
                return columnStatistics.getStringStatistics().getMin() == null
                        && columnStatistics.getStringStatistics().getMax() == null;
            }

            // the reported min must not be greater than the smallest value
            String chunkMin = Ordering.natural().nullsLast().min(strings);
            String statsMin = columnStatistics.getStringStatistics().getMin().toStringUtf8().trim();
            if (statsMin.compareTo(chunkMin) > 0) {
                return false;
            }

            // the reported max must not be less than the largest value
            String chunkMax = Ordering.natural().nullsFirst().max(strings);
            String statsMax = columnStatistics.getStringStatistics().getMax().toStringUtf8().trim();
            return statsMax.compareTo(chunkMax) >= 0;
        }
    }

    /**
     * Predicate for date columns whose expected values are given as longs: on top of the basic
     * value-count check it compares the reported min/max of the date statistics with the chunk's
     * min/max.
     */
    public static class DateOrcPredicate
            extends BasicOrcPredicate<Long>
    {
        public DateOrcPredicate(int columnIndex, Iterable<?> expectedValues, boolean noFileStats)
        {
            super(columnIndex, expectedValues, Long.class, noFileStats);
        }

        @Override
        protected boolean chunkMatchesStats(List<Long> chunk, ColumnStatistics columnStatistics)
        {
            // a date column must not carry statistics of any other kind
            assertNull(columnStatistics.getBooleanStatistics());
            assertNull(columnStatistics.getIntegerStatistics());
            assertNull(columnStatistics.getDoubleStatistics());
            assertNull(columnStatistics.getStringStatistics());

            // basic statistics first
            if (!super.chunkMatchesStats(chunk, columnStatistics)) {
                return false;
            }

            // statistics can be missing for any reason; nothing more to verify then
            if (columnStatistics.getDateStatistics() == null) {
                return true;
            }

            boolean allNull = chunk.stream().allMatch(Objects::isNull);
            if (allNull) {
                // no values, so neither bound may be set
                return columnStatistics.getDateStatistics().getMin() == null
                        && columnStatistics.getDateStatistics().getMax() == null;
            }

            // verify min
            Long statMin = columnStatistics.getDateStatistics().getMin().longValue();
            Long chunkMin = Ordering.natural().nullsLast().min(chunk);
            if (!statMin.equals(chunkMin)) {
                return false;
            }

            // verify max
            Long statMax = columnStatistics.getDateStatistics().getMax().longValue();
            Long chunkMax = Ordering.natural().nullsFirst().max(chunk);
            return statMax.equals(chunkMax);
        }
    }
}