DeltaExpressionUtils.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.delta;
import com.facebook.airlift.log.Logger;
import com.facebook.presto.common.predicate.Domain;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.predicate.ValueSet;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.PrestoException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Streams;
import io.airlift.slice.Slice;
import io.delta.kernel.data.FilteredColumnarBatch;
import io.delta.kernel.data.Row;
import io.delta.kernel.internal.InternalScanFileUtils;
import io.delta.kernel.utils.CloseableIterator;
import java.io.IOException;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.function.Predicate;
import static com.facebook.presto.delta.DeltaColumnHandle.ColumnType.PARTITION;
import static com.facebook.presto.delta.DeltaErrorCode.DELTA_INVALID_PARTITION_VALUE;
import static com.facebook.presto.delta.DeltaErrorCode.DELTA_UNSUPPORTED_COLUMN_TYPE;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.airlift.slice.Slices.utf8Slice;
import static java.lang.Double.doubleToRawLongBits;
import static java.lang.Double.parseDouble;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.parseFloat;
import static java.lang.Long.parseLong;
import static java.lang.String.format;
public final class DeltaExpressionUtils
{
private static final Logger logger = Logger.get(DeltaExpressionUtils.class);
// Private constructor: this final class only exposes static helpers and is not meant to be instantiated.
private DeltaExpressionUtils()
{
}
/**
 * Splits a predicate into its partition-column part and its regular-column part.
 * <p>
 * Every column handle in the predicate is cast to {@link DeltaColumnHandle}.
 * When the predicate carries no domains, both returned tuple domains are built from empty maps.
 *
 * @param predicate predicate over Delta column handles
 * @return a two-element list: the partition-column predicate first, the regular-column predicate second
 */
public static List<TupleDomain<ColumnHandle>> splitPredicate(
        TupleDomain<ColumnHandle> predicate)
{
    ImmutableMap.Builder<ColumnHandle, Domain> partitionDomains = ImmutableMap.builder();
    ImmutableMap.Builder<ColumnHandle, Domain> regularDomains = ImmutableMap.builder();

    Optional<Map<ColumnHandle, Domain>> allDomains = predicate.getDomains();
    if (allDomains.isPresent()) {
        for (Map.Entry<ColumnHandle, Domain> entry : allDomains.get().entrySet()) {
            DeltaColumnHandle deltaColumn = (DeltaColumnHandle) entry.getKey();
            // Route each column domain to the builder matching the column's kind
            ImmutableMap.Builder<ColumnHandle, Domain> target =
                    deltaColumn.getColumnType() == PARTITION ? partitionDomains : regularDomains;
            target.put(entry.getKey(), entry.getValue());
        }
    }

    TupleDomain<ColumnHandle> partitionPredicate = TupleDomain.withColumnDomains(partitionDomains.build());
    TupleDomain<ColumnHandle> regularPredicate = TupleDomain.withColumnDomains(regularDomains.build());
    return ImmutableList.of(partitionPredicate, regularPredicate);
}
/**
 * Turns an iterator of {@link FilteredColumnarBatch}es into an iterator over the {@link Row}s of those
 * batches, dropping the rows whose partition values cannot satisfy the partition-column part of the
 * given predicate.
 * <ul>
 * <li>If the partition-column predicate is "all", every row of every batch is returned.</li>
 * <li>If it is "none", the returned iterator is empty.</li>
 * <li>Otherwise each row is tested against the partition-column predicate.</li>
 * </ul>
 * In all three cases closing the returned iterator closes {@code inputIterator}.
 *
 * @param inputIterator batches whose rows are to be pruned
 * @param predicate predicate over Delta columns; only domains on partition columns are used for pruning
 * @param typeManager used to resolve the Presto type of each partition column
 * @return iterator over the rows that pass partition pruning
 */
public static CloseableIterator<Row> iterateWithPartitionPruning(
        CloseableIterator<FilteredColumnarBatch> inputIterator,
        TupleDomain<DeltaColumnHandle> predicate,
        TypeManager typeManager)
{
    TupleDomain<String> partitionPredicate = extractPartitionColumnsPredicate(predicate);

    if (partitionPredicate.isAll()) {
        return new AllFilesIterator(inputIterator);
    }

    if (partitionPredicate.isNone()) {
        // nothing passes the partition predicate, return empty iterator
        return new NoneFilesIterator(inputIterator);
    }

    // Collect the partition columns that the predicate constrains
    ImmutableList.Builder<DeltaColumnHandle> partitionColumns = ImmutableList.builder();
    Optional<List<TupleDomain.ColumnDomain<DeltaColumnHandle>>> columnDomains = predicate.getColumnDomains();
    if (columnDomains.isPresent()) {
        for (TupleDomain.ColumnDomain<DeltaColumnHandle> columnDomain : columnDomains.get()) {
            DeltaColumnHandle column = columnDomain.getColumn();
            if (column.getColumnType() == PARTITION) {
                partitionColumns.add(column);
            }
        }
    }

    return new FilteredByPredicateIterator(inputIterator, partitionPredicate, partitionColumns.build(), typeManager);
}
/**
 * Re-keys the predicate by column name, mapping partition columns to their name and every other
 * column to {@code null}.
 * NOTE(review): presumably {@code TupleDomain.transform} drops columns mapped to {@code null}, leaving
 * only partition-column domains - confirm against the TupleDomain documentation.
 */
private static TupleDomain<String> extractPartitionColumnsPredicate(TupleDomain<DeltaColumnHandle> predicate)
{
    return predicate.transform(column -> column.getColumnType() == PARTITION ? column.getName() : null);
}
/**
 * Iterator that never returns a row. It still holds on to the batch iterator it was given
 * so that {@link #close()} can close it.
 */
private static class NoneFilesIterator
        implements CloseableIterator<Row>
{
    private final CloseableIterator<FilteredColumnarBatch> batches;

    NoneFilesIterator(CloseableIterator<FilteredColumnarBatch> inputIterator)
    {
        batches = inputIterator;
    }

    /** Always reports that there are no rows. */
    @Override
    public boolean hasNext()
    {
        return false;
    }

    /** Always fails because this iterator has no rows. */
    @Override
    public Row next()
    {
        throw new NoSuchElementException();
    }

    /** Closes the wrapped batch iterator. */
    @Override
    public void close()
            throws IOException
    {
        batches.close();
    }
}
/**
 * Flattens an iterator of {@link FilteredColumnarBatch}es into a single iterator over their rows,
 * optionally keeping only the rows accepted by a filter.
 * <p>
 * The row iterator of a batch is closed when the rows of the next batch are requested.
 * {@link #close()} closes the row iterator of the last opened batch and then the input iterator.
 */
private static class BatchRowIterator
        implements CloseableIterator<Row>
{
    private final CloseableIterator<FilteredColumnarBatch> inputIterator;
    private final Iterator<Row> rows;
    // Row iterator of the most recently opened batch; null before the first batch is opened and after close()
    private CloseableIterator<Row> prev;

    /**
     * @param inputIterator batches to read rows from; closed by {@link #close()}
     * @param rowFilter when present, only rows for which the filter returns true are returned
     */
    public BatchRowIterator(CloseableIterator<FilteredColumnarBatch> inputIterator,
            Optional<Predicate<Row>> rowFilter)
    {
        this.inputIterator = inputIterator;
        this.rows = Streams.stream(inputIterator)
                .flatMap(batch -> {
                    // Release the rows of the previous batch before opening the next one
                    if (prev != null) {
                        try {
                            prev.close();
                        }
                        catch (IOException e) {
                            throw new RuntimeException("Failed to close previous row batch", e);
                        }
                    }
                    prev = batch.getRows();
                    return Streams.stream(prev);
                })
                // if there is a filter to be applied, it applies it
                .filter(row -> rowFilter.map(predicate -> predicate.test(row)).orElse(true))
                .iterator();
    }

    @Override
    public boolean hasNext()
    {
        return rows.hasNext();
    }

    @Override
    public Row next()
    {
        return rows.next();
    }

    /**
     * Closes the row iterator of the last opened batch (if any) and the input iterator.
     * The input iterator is closed even when closing the row iterator fails.
     *
     * @throws IOException if closing either iterator fails
     */
    @Override
    public void close()
            throws IOException
    {
        // Detach the open row iterator first so that a repeated close() does not close it twice
        CloseableIterator<Row> openBatchRows = prev;
        prev = null;
        try {
            if (openBatchRows != null) {
                openBatchRows.close();
            }
        }
        finally {
            // Must run even if closing the batch rows threw, otherwise the input iterator leaks
            if (inputIterator != null) {
                inputIterator.close();
            }
        }
    }
}
/**
* Row iterator that applies no row filter: it hands {@code Optional.empty()} to
* {@link BatchRowIterator} as the filter, so every row of every input batch is returned.
*/
private static class AllFilesIterator
extends BatchRowIterator
{
public AllFilesIterator(CloseableIterator<FilteredColumnarBatch> inputIterator)
{
// Empty filter: BatchRowIterator keeps all rows.
super(inputIterator, Optional.empty());
}
}
/**
 * Row iterator that keeps only the rows whose partition values can satisfy the given
 * partition-column predicate.
 */
private static class FilteredByPredicateIterator
        extends BatchRowIterator
{
    /**
     * @param inputIterator batches to read rows from
     * @param partitionPredicate predicate keyed by partition column name; must not be "none"
     * @param partitionColumns partition columns to evaluate for each row
     * @param typeManager used to resolve the Presto type of each partition column
     */
    public FilteredByPredicateIterator(CloseableIterator<FilteredColumnarBatch> inputIterator,
            TupleDomain<String> partitionPredicate,
            List<DeltaColumnHandle> partitionColumns,
            TypeManager typeManager)
    {
        super(inputIterator,
                Optional.of(row -> evaluatePartitionPredicate(partitionPredicate, partitionColumns, typeManager, row)));
    }

    /**
     * Tests one row against the partition predicate.
     *
     * @return false as soon as the partition value of one column cannot satisfy the predicate
     * on that column, true otherwise
     * @throws PrestoException if a constrained partition value cannot be parsed or its type is unsupported
     */
    private static boolean evaluatePartitionPredicate(
            TupleDomain<String> partitionPredicate,
            List<DeltaColumnHandle> partitionColumns,
            TypeManager typeManager,
            Row row)
    {
        checkArgument(!partitionPredicate.isNone(), "Expecting a predicate with at least one expression");

        // These lookups depend only on the row and the predicate, so do them once rather than per column
        String filePath = InternalScanFileUtils.getAddFileStatus(row).getPath();
        Optional<Map<String, Domain>> domains = partitionPredicate.getDomains();
        if (!domains.isPresent()) {
            logger.debug("Domain is not present in file: %s", filePath);
            return false;
        }
        Map<String, Domain> columnPredicates = domains.get();
        Map<String, String> partitionValues = InternalScanFileUtils.getPartitionValues(row);

        for (DeltaColumnHandle partitionColumn : partitionColumns) {
            String columnName = partitionColumn.getName();
            Domain columnPredicate = columnPredicates.get(columnName);
            if (columnPredicate == null) {
                continue; // there is no predicate on this column, so its value does not need to be parsed
            }

            logger.debug("Obtaining domain of file: %s", filePath);
            Domain domain = getDomain(partitionColumn, partitionValues.get(columnName), typeManager, filePath);
            if (columnPredicate.intersect(domain).isNone()) {
                logger.debug("Empty set after domain intersection with file: %s", filePath);
                return false;
            }
        }
        return true;
    }

    /**
     * Builds the single-value (or only-null) domain that describes the partition value of one column.
     *
     * @param columnHandle partition column the value belongs to
     * @param partitionValue partition value in its string form, or null for a null partition value
     * @param typeManager used to resolve the Presto type of the column
     * @param filePath path of the file the value comes from; only used in error messages
     * @return {@code Domain.onlyNull} when the value is null, otherwise a non-null single-value domain
     * @throws PrestoException with {@code DELTA_UNSUPPORTED_COLUMN_TYPE} for an unsupported column type,
     * or {@code DELTA_INVALID_PARTITION_VALUE} when the value cannot be parsed
     */
    private static Domain getDomain(DeltaColumnHandle columnHandle, String partitionValue, TypeManager typeManager, String filePath)
    {
        Type type = typeManager.getType(columnHandle.getDataType());
        if (partitionValue == null) {
            return Domain.onlyNull(type);
        }

        String typeBase = columnHandle.getDataType().getBase();
        try {
            switch (typeBase) {
                case StandardTypes.TINYINT:
                case StandardTypes.SMALLINT:
                case StandardTypes.INTEGER:
                case StandardTypes.BIGINT:
                    Long intValue = parseLong(partitionValue);
                    return Domain.create(ValueSet.of(type, intValue), false);
                case StandardTypes.REAL:
                    // the float bits are widened to a long before being handed to ValueSet
                    Long realValue = (long) floatToRawIntBits(parseFloat(partitionValue));
                    return Domain.create(ValueSet.of(type, realValue), false);
                case StandardTypes.DOUBLE:
                    // Presto's DOUBLE type uses double as its native Java type, so the parsed value is
                    // passed as a Double; raw long bits do not match the type and are rejected
                    Double doubleValue = parseDouble(partitionValue);
                    return Domain.create(ValueSet.of(type, doubleValue), false);
                case StandardTypes.VARCHAR:
                case StandardTypes.VARBINARY:
                    Slice sliceValue = utf8Slice(partitionValue);
                    return Domain.create(ValueSet.of(type, sliceValue), false);
                case StandardTypes.DATE:
                    // Presto represents DATE as the number of days since 1970-01-01, not as milliseconds
                    Long dateValue = Date.valueOf(partitionValue).toLocalDate().toEpochDay();
                    return Domain.create(ValueSet.of(type, dateValue), false);
                case StandardTypes.TIMESTAMP:
                    Long timestampValue = Timestamp.valueOf(partitionValue).getTime(); // convert to millis
                    return Domain.create(ValueSet.of(type, timestampValue), false);
                case StandardTypes.BOOLEAN:
                    Boolean booleanValue = Boolean.valueOf(partitionValue);
                    return Domain.create(ValueSet.of(type, booleanValue), false);
                default:
                    throw new PrestoException(DELTA_UNSUPPORTED_COLUMN_TYPE,
                            format("Unsupported data type '%s' for partition column %s", columnHandle.getDataType(), columnHandle.getName()));
            }
        }
        catch (IllegalArgumentException exception) {
            // NumberFormatException and the failures of Date.valueOf / Timestamp.valueOf all land here
            throw new PrestoException(DELTA_INVALID_PARTITION_VALUE,
                    format("Can not parse partition value '%s' of type '%s' for partition column '%s' in file '%s'",
                            partitionValue, columnHandle.getDataType(), columnHandle.getName(), filePath),
                    exception);
        }
    }
}
}