FilteringPageSource.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.hive;
import com.facebook.presto.common.Page;
import com.facebook.presto.common.RuntimeStats;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.predicate.Domain;
import com.facebook.presto.common.predicate.FilterFunction;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.common.predicate.TupleDomainFilter;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.expressions.DynamicFilters;
import com.facebook.presto.spi.ConnectorPageSource;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.relation.InputReferenceExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionService;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import io.airlift.slice.Slice;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.IntStream;
import static com.facebook.presto.common.predicate.TupleDomainFilter.IS_NOT_NULL;
import static com.facebook.presto.common.predicate.TupleDomainFilter.IS_NULL;
import static com.facebook.presto.common.predicate.TupleDomainFilterUtils.toFilter;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.Chars.isCharType;
import static com.facebook.presto.common.type.DateType.DATE;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.RealType.REAL;
import static com.facebook.presto.common.type.SmallintType.SMALLINT;
import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
import static com.facebook.presto.common.type.TinyintType.TINYINT;
import static com.facebook.presto.common.type.Varchars.isVarcharType;
import static com.facebook.presto.expressions.DynamicFilters.extractDynamicFilters;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.and;
import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.google.common.collect.ImmutableMap.toImmutableMap;
import static java.lang.Double.longBitsToDouble;
import static java.lang.Float.intBitsToFloat;
import static java.util.Objects.requireNonNull;
import static sun.misc.Unsafe.ARRAY_BYTE_BASE_OFFSET;
/**
 * A {@link ConnectorPageSource} that filters the pages produced by a delegate page source.
 * <p>
 * Two kinds of predicates are applied, in this order:
 * <ol>
 * <li>per-column {@link TupleDomainFilter}s derived from {@code domainPredicate}, tested position by
 * position against the block in the channel that corresponds to the column mapping;</li>
 * <li>an optional compiled {@link FilterFunction} built from the static conjuncts of
 * {@code remainingPredicate}. Dynamic-filter conjuncts are stripped because they are applied
 * through subfield pushdown instead.</li>
 * </ol>
 * Only the first {@code originalIndices.size()} channels of each delegate page are returned. Any
 * channels beyond that are dropped from the output; presumably they exist only to feed the filters
 * (NOTE(review): confirm with the caller that builds the column mappings).
 */
public class FilteringPageSource
        implements ConnectorPageSource
{
    private final ConnectorPageSource delegate;
    // Indexed by position in columnMappings (== page channel); null where the column has no domain.
    private final TupleDomainFilter[] domainFilters;
    // Type of each filtered column; populated only where domainFilters[i] != null.
    private final Type[] columnTypes;
    private final Map<Integer, Integer> functionInputs; // key: hiveColumnIndex, value: page channel
    // Null when the remaining predicate optimizes to TRUE (nothing left to evaluate).
    private final FilterFunction filterFunction;
    private final int outputBlockCount;
    // True when domainPredicate is TupleDomain.none(): no row can satisfy it.
    private final boolean alwaysFalse;

    /**
     * Creates a filtering wrapper around {@code delegate}.
     *
     * @param columnMappings one mapping per channel of the delegate's pages, in channel order
     * @param domainPredicate per-column domains; {@code all()} disables domain filtering and
     *        {@code none()} causes every row to be filtered out
     * @param remainingPredicate residual predicate, expressed over column variables
     * @param typeManager used to resolve Hive column types
     * @param rowExpressionService used to optimize and compile {@code remainingPredicate}
     * @param session the connector session
     * @param originalIndices only its size is used, as the number of leading channels to output
     * @param delegate the page source being filtered
     * @throws NullPointerException if any argument is null
     */
    public FilteringPageSource(
            List<HivePageSourceProvider.ColumnMapping> columnMappings,
            TupleDomain<HiveColumnHandle> domainPredicate,
            RowExpression remainingPredicate,
            TypeManager typeManager,
            RowExpressionService rowExpressionService,
            ConnectorSession session,
            Set<Integer> originalIndices,
            ConnectorPageSource delegate)
    {
        requireNonNull(columnMappings, "columnMappings is null");
        requireNonNull(domainPredicate, "domainPredicate is null");
        requireNonNull(rowExpressionService, "rowExpressionService is null");
        requireNonNull(remainingPredicate, "remainingPredicate is null");
        requireNonNull(typeManager, "typeManager is null");
        requireNonNull(session, "session is null");
        this.delegate = requireNonNull(delegate, "delegate is null");
        this.outputBlockCount = requireNonNull(originalIndices, "originalIndices is null").size();

        this.domainFilters = new TupleDomainFilter[columnMappings.size()];
        this.columnTypes = new Type[columnMappings.size()];
        // TupleDomain.none() carries no domain map (getDomains() is empty). Treat it as "no row
        // matches" rather than failing on Optional.get() below.
        this.alwaysFalse = domainPredicate.isNone();
        if (!alwaysFalse && !domainPredicate.isAll()) {
            Map<Integer, Domain> domains = domainPredicate.transform(HiveColumnHandle::getHiveColumnIndex).getDomains().get();
            for (int channel = 0; channel < columnMappings.size(); channel++) {
                HiveColumnHandle columnHandle = columnMappings.get(channel).getHiveColumnHandle();
                Domain domain = domains.get(columnHandle.getHiveColumnIndex());
                if (domain != null) {
                    domainFilters[channel] = toFilter(domain);
                    columnTypes[channel] = columnHandle.getHiveType().getType(typeManager);
                }
            }
        }

        // The compiled filter refers to inputs by Hive column index; this maps them back to page channels.
        this.functionInputs = IntStream.range(0, columnMappings.size())
                .boxed()
                .collect(toImmutableMap(i -> columnMappings.get(i).getHiveColumnHandle().getHiveColumnIndex(), Function.identity()));

        this.filterFunction = createFilterFunction(columnMappings, remainingPredicate, typeManager, rowExpressionService, session);
    }

    /**
     * Optimizes {@code remainingPredicate}, rewrites its column variables into input references
     * keyed by Hive column index, drops dynamic-filter conjuncts and compiles the rest.
     *
     * @return the compiled filter, or {@code null} if the predicate optimizes to TRUE
     */
    private static FilterFunction createFilterFunction(
            List<HivePageSourceProvider.ColumnMapping> columnMappings,
            RowExpression remainingPredicate,
            TypeManager typeManager,
            RowExpressionService rowExpressionService,
            ConnectorSession session)
    {
        Map<VariableReferenceExpression, InputReferenceExpression> variableToInput = columnMappings.stream()
                .map(HivePageSourceProvider.ColumnMapping::getHiveColumnHandle)
                .collect(toImmutableMap(
                        columnHandle -> new VariableReferenceExpression(Optional.empty(), columnHandle.getName(), columnHandle.getHiveType().getType(typeManager)),
                        columnHandle -> new InputReferenceExpression(Optional.empty(), columnHandle.getHiveColumnIndex(), columnHandle.getHiveType().getType(typeManager))));

        RowExpression optimizedRemainingPredicate = rowExpressionService.getExpressionOptimizer(session).optimize(remainingPredicate, OPTIMIZED, session);
        if (TRUE_CONSTANT.equals(optimizedRemainingPredicate)) {
            return null;
        }

        RowExpression expression = replaceExpression(optimizedRemainingPredicate, variableToInput);
        DynamicFilters.DynamicFilterExtractResult extractDynamicFilterResult = extractDynamicFilters(expression);
        // dynamic filter will be added through subfield pushdown
        expression = and(extractDynamicFilterResult.getStaticConjuncts());
        return new FilterFunction(
                session.getSqlFunctionProperties(),
                rowExpressionService.getDeterminismEvaluator().isDeterministic(expression),
                rowExpressionService.getPredicateCompiler().compilePredicate(session.getSqlFunctionProperties(), session.getSessionFunctions(), expression).get());
    }

    /**
     * Returns the next delegate page with non-matching rows removed.
     * <p>
     * A {@code null} or empty delegate page is returned unchanged. If every row is filtered out,
     * an empty zero-channel page is returned.
     */
    @Override
    public Page getNextPage()
    {
        Page page = delegate.getNextPage();
        if (page == null || page.getPositionCount() == 0) {
            return page;
        }
        if (alwaysFalse) {
            return new Page(0);
        }

        int[] positions = new int[page.getPositionCount()];
        for (int i = 0; i < positions.length; i++) {
            positions[i] = i;
        }

        int positionCount = applyDomainFilters(page, positions, positions.length);
        if (positionCount == 0) {
            return new Page(0);
        }

        if (filterFunction != null) {
            positionCount = applyFilterFunction(page, positions, positionCount);
            if (positionCount == 0) {
                return new Page(0);
            }
        }

        return buildOutputPage(page, positions, positionCount);
    }

    /**
     * Applies each channel's domain filter in place. Stops as soon as no positions remain.
     *
     * @return the number of leading entries of {@code positions} that survived
     */
    private int applyDomainFilters(Page page, int[] positions, int positionCount)
    {
        for (int channel = 0; channel < page.getChannelCount(); channel++) {
            TupleDomainFilter domainFilter = domainFilters[channel];
            if (domainFilter == null) {
                continue;
            }
            positionCount = filterBlock(page.getBlock(channel), columnTypes[channel], domainFilter, positions, positionCount);
            if (positionCount == 0) {
                return 0;
            }
        }
        return positionCount;
    }

    /**
     * Evaluates the compiled filter function over the surviving positions, compacting
     * {@code positions} in place. Throws the first error the filter recorded for any
     * position that is kept.
     *
     * @return the number of leading entries of {@code positions} that survived
     */
    private int applyFilterFunction(Page page, int[] positions, int positionCount)
    {
        RuntimeException[] errors = new RuntimeException[positionCount];

        // The filter's input channels are Hive column indexes; map them to page channels.
        int[] inputChannels = filterFunction.getInputChannels();
        Block[] inputBlocks = new Block[inputChannels.length];
        for (int i = 0; i < inputChannels.length; i++) {
            inputBlocks[i] = page.getBlock(functionInputs.get(inputChannels[i]));
        }

        Page inputPage = new Page(page.getPositionCount(), inputBlocks);
        int survivingCount = filterFunction.filter(inputPage, positions, positionCount, errors);
        for (int i = 0; i < survivingCount; i++) {
            if (errors[i] != null) {
                throw errors[i];
            }
        }
        return survivingCount;
    }

    /**
     * Projects the selected positions. If the page has more channels than
     * {@code outputBlockCount}, only the leading {@code outputBlockCount} channels are kept.
     */
    private Page buildOutputPage(Page page, int[] positions, int positionCount)
    {
        if (outputBlockCount == page.getChannelCount()) {
            return page.getPositions(positions, 0, positionCount);
        }
        Block[] blocks = new Block[outputBlockCount];
        for (int i = 0; i < outputBlockCount; i++) {
            blocks[i] = page.getBlock(i);
        }
        return new Page(page.getPositionCount(), blocks).getPositions(positions, 0, positionCount);
    }

    @Override
    public long getSystemMemoryUsage()
    {
        return delegate.getSystemMemoryUsage();
    }

    @Override
    public RuntimeStats getRuntimeStats()
    {
        return delegate.getRuntimeStats();
    }

    @Override
    public void close()
            throws IOException
    {
        delegate.close();
    }

    /**
     * Keeps, in place, only those of the first {@code positionCount} entries of {@code positions}
     * whose value in {@code block} passes {@code filter}. Nulls are decided by
     * {@link TupleDomainFilter#testNull()}.
     *
     * @return the number of positions kept (they occupy the front of {@code positions})
     */
    private static int filterBlock(Block block, Type type, TupleDomainFilter filter, int[] positions, int positionCount)
    {
        int outputPositionsCount = 0;
        for (int i = 0; i < positionCount; i++) {
            int position = positions[i];
            boolean passes = block.isNull(position)
                    ? filter.testNull()
                    : testNonNullPosition(block, position, type, filter);
            if (passes) {
                positions[outputPositionsCount] = position;
                outputPositionsCount++;
            }
        }
        return outputPositionsCount;
    }

    /**
     * Tests a single non-null value against {@code filter}, dispatching on {@code type}.
     *
     * @throws UnsupportedOperationException for types not handled here, including complex types
     *         whose filter is anything other than IS_NULL / IS_NOT_NULL
     */
    private static boolean testNonNullPosition(Block block, int position, Type type, TupleDomainFilter filter)
    {
        if (type == BIGINT || type == INTEGER || type == SMALLINT || type == TINYINT || type == TIMESTAMP || type == DATE) {
            return filter.testLong(type.getLong(block, position));
        }
        if (type == BOOLEAN) {
            return filter.testBoolean(type.getBoolean(block, position));
        }
        if (type.equals(DOUBLE)) {
            return filter.testDouble(longBitsToDouble(block.getLong(position)));
        }
        if (type.equals(REAL)) {
            return filter.testFloat(intBitsToFloat(block.getInt(position)));
        }
        if (type instanceof DecimalType) {
            if (((DecimalType) type).isShort()) {
                return filter.testLong(block.getLong(position));
            }
            // Long decimals are read as two 64-bit words, at byte offsets 0 and Long.BYTES.
            return filter.testDecimal(block.getLong(position, 0), block.getLong(position, Long.BYTES));
        }
        if (isVarcharType(type) || isCharType(type)) {
            Slice slice = block.getSlice(position, 0, block.getSliceLength(position));
            // Assumes an array-backed slice (base is byte[]); the address is that array's base offset plus the slice start.
            return filter.testBytes((byte[]) slice.getBase(), (int) (slice.getAddress() - ARRAY_BYTE_BASE_OFFSET), slice.length());
        }
        if (type instanceof ArrayType || type instanceof MapType || type instanceof RowType) {
            // Only null-ness filters are supported for complex types. The position is known to be
            // non-null here (see filterBlock), so IS_NULL never matches and IS_NOT_NULL always does.
            if (IS_NULL == filter) {
                return false;
            }
            if (IS_NOT_NULL == filter) {
                return true;
            }
        }
        throw new UnsupportedOperationException("Unexpected column type " + type);
    }

    @Override
    public long getCompletedBytes()
    {
        return delegate.getCompletedBytes();
    }

    @Override
    public long getCompletedPositions()
    {
        return delegate.getCompletedPositions();
    }

    @Override
    public long getReadTimeNanos()
    {
        return delegate.getReadTimeNanos();
    }

    @Override
    public boolean isFinished()
    {
        return delegate.isFinished();
    }
}