BaseSubfieldExtractionRewriter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.hive.rule;
import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.predicate.NullableValue;
import com.facebook.presto.common.predicate.TupleDomain;
import com.facebook.presto.expressions.DefaultRowExpressionTraversalVisitor;
import com.facebook.presto.expressions.LogicalRowExpressions;
import com.facebook.presto.hive.BaseHiveTableHandle;
import com.facebook.presto.hive.BaseHiveTableLayoutHandle;
import com.facebook.presto.hive.SubfieldExtractor;
import com.facebook.presto.spi.ColumnHandle;
import com.facebook.presto.spi.ConnectorPlanRewriter;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.ConnectorTableHandle;
import com.facebook.presto.spi.ConnectorTableLayout;
import com.facebook.presto.spi.ConnectorTableLayoutHandle;
import com.facebook.presto.spi.Constraint;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.PrestoWarning;
import com.facebook.presto.spi.TableHandle;
import com.facebook.presto.spi.connector.ConnectorMetadata;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.plan.FilterNode;
import com.facebook.presto.spi.plan.PlanNode;
import com.facebook.presto.spi.plan.PlanNodeIdAllocator;
import com.facebook.presto.spi.plan.TableScanNode;
import com.facebook.presto.spi.plan.ValuesNode;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.DomainTranslator.ExtractionResult;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.RowExpressionService;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import static com.facebook.presto.expressions.DynamicFilters.isDynamicFilter;
import static com.facebook.presto.expressions.DynamicFilters.removeNestedDynamicFilters;
import static com.facebook.presto.expressions.LogicalRowExpressions.FALSE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
import static com.facebook.presto.expressions.RowExpressionNodeInliner.replaceExpression;
import static com.facebook.presto.hive.HiveWarningCode.HIVE_TABLESCAN_CONVERTED_TO_VALUESNODE;
import static com.facebook.presto.hive.MetadataUtils.isEntireColumn;
import static com.facebook.presto.spi.StandardErrorCode.DIVISION_BY_ZERO;
import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.spi.relation.ExpressionOptimizer.Level.OPTIMIZED;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableBiMap.toImmutableBiMap;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Sets.intersection;
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Objects.requireNonNull;
/**
 * Base class for connector plan rewriters that push filter predicates down into table scans.
 * <p>
 * {@code visitFilter} and {@code visitTableScan} both call {@code pushdownFilter}, which decomposes
 * the predicate and delegates construction of the final layout to the subclass through
 * {@code getConnectorPushdownFilterResult}. Subclasses also decide, through
 * {@code isPushdownFilterSupported}, whether a given table is rewritten at all.
 */
public abstract class BaseSubfieldExtractionRewriter
extends ConnectorPlanRewriter<Void>
{
// Placeholder layout returned by pushdownFilter() when the filter can never be satisfied.
// It is built with TupleDomain.none() (presumably the layout's predicate - confirm against
// ConnectorTableLayout); visitFilter() and visitTableScan() test the returned layout's predicate
// with isNone() and substitute an empty ValuesNode for the scan when it is.
private static final ConnectorTableLayout EMPTY_TABLE_LAYOUT = new ConnectorTableLayout(
new ConnectorTableLayoutHandle() {},
Optional.empty(),
TupleDomain.none(),
Optional.empty(),
Optional.empty(),
Optional.empty(),
emptyList());
// Source of the domain translator, expression optimizer and determinism evaluator used below.
protected final RowExpressionService rowExpressionService;
// Resolves the ConnectorMetadata instance to use for a given table handle.
protected final Function<TableHandle, ConnectorMetadata> transactionToMetadata;
// Session passed to every metadata/optimizer call; also receives the "converted to ValuesNode" warning.
private final ConnectorSession session;
// Supplies ids for the FilterNode and ValuesNode instances created by this rewriter.
private final PlanNodeIdAllocator idAllocator;
// Passed to SubfieldExtractor and LogicalRowExpressions.
private final StandardFunctionResolution functionResolution;
// Passed to LogicalRowExpressions.
private final FunctionMetadataManager functionMetadataManager;
/**
 * Creates a rewriter bound to one session.
 *
 * @param session session forwarded to metadata, optimizer and warning-collector calls
 * @param idAllocator allocator for the ids of plan nodes created during rewriting
 * @param rowExpressionService provider of the domain translator, expression optimizer and determinism evaluator
 * @param functionResolution function resolution handed to {@code SubfieldExtractor} and {@code LogicalRowExpressions}
 * @param functionMetadataManager function metadata handed to {@code LogicalRowExpressions}
 * @param transactionToMetadata resolves the {@link ConnectorMetadata} to use for a table handle
 * @throws NullPointerException if any argument is null
 */
public BaseSubfieldExtractionRewriter(
        ConnectorSession session,
        PlanNodeIdAllocator idAllocator,
        RowExpressionService rowExpressionService,
        StandardFunctionResolution functionResolution,
        FunctionMetadataManager functionMetadataManager,
        Function<TableHandle, ConnectorMetadata> transactionToMetadata)
{
    this.session = requireNonNull(session, "session is null");
    this.idAllocator = requireNonNull(idAllocator, "idAllocator is null");
    this.rowExpressionService = requireNonNull(rowExpressionService, "rowExpressionService is null");
    this.functionResolution = requireNonNull(functionResolution, "functionResolution is null");
    this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
    // Fail fast here rather than with an anonymous NPE the first time a filter or scan is visited.
    this.transactionToMetadata = requireNonNull(transactionToMetadata, "transactionToMetadata is null");
}
/**
 * Pushes the predicate of a filter that sits directly on a table scan into the scan's layout.
 * <p>
 * A filter whose source is not a {@link TableScanNode} is handed to {@code visitPlan}; a filter
 * over a table for which {@code isPushdownFilterSupported} returns false is returned unchanged.
 *
 * @param filter the filter node being visited
 * @param context rewrite context, only forwarded to {@code visitPlan}
 * @return an empty {@link ValuesNode} when the predicate can never match, the rewritten scan when
 * nothing is left unenforced, or a new {@link FilterNode} over the rewritten scan otherwise
 */
@Override
public PlanNode visitFilter(FilterNode filter, RewriteContext<Void> context)
{
    PlanNode source = filter.getSource();
    if (!(source instanceof TableScanNode)) {
        return visitPlan(filter, context);
    }

    TableScanNode scan = (TableScanNode) source;
    TableHandle table = scan.getTable();
    if (!isPushdownFilterSupported(session, table)) {
        return filter;
    }

    ConnectorMetadata metadata = transactionToMetadata.apply(table);
    ConnectorTableHandle connectorHandle = table.getConnectorHandle();

    // Plan variable -> variable named after the underlying column (same type). A BiMap is used
    // so that the unenforced part can be translated back to plan variables afterwards.
    BiMap<VariableReferenceExpression, VariableReferenceExpression> variableToColumn = scan.getAssignments()
            .entrySet()
            .stream()
            .collect(toImmutableBiMap(
                    Map.Entry::getKey,
                    assignment -> new VariableReferenceExpression(
                            Optional.empty(),
                            getColumnName(session, metadata, connectorHandle, assignment.getValue()),
                            assignment.getKey().getType())));

    // replaceExpression() may further optimize the expression;
    // an always-false result means the scan produces nothing.
    RowExpression columnPredicate = replaceExpression(filter.getPredicate(), variableToColumn);
    if (FALSE_CONSTANT.equals(columnPredicate)) {
        return getValuesNode(scan);
    }

    ConnectorPushdownFilterResult result = pushdownFilter(session, metadata, connectorHandle, columnPredicate, table.getLayout());
    if (result.getLayout().getPredicate().isNone()) {
        return getValuesNode(scan);
    }

    TableScanNode rewrittenScan = getTableScanNode(scan, table, result);
    RowExpression unenforced = result.getUnenforcedConstraint();
    if (TRUE_CONSTANT.equals(unenforced)) {
        return rewrittenScan;
    }

    // Whatever the connector does not enforce stays in the plan, expressed in plan variables again.
    return new FilterNode(
            scan.getSourceLocation(),
            idAllocator.getNextId(),
            rewrittenScan,
            replaceExpression(unenforced, variableToColumn.inverse()));
}
/**
 * Rewrites a table scan that has no filter above it by running filter pushdown with an
 * always-true predicate.
 *
 * @param tableScan the scan node being visited
 * @param context rewrite context (not used by this method)
 * @return the scan unchanged when pushdown is not supported for its table, an empty
 * {@link ValuesNode} when the resulting layout predicate is none, otherwise the rewritten scan
 * @throws PrestoException with {@code GENERIC_INTERNAL_ERROR} if pushdown reports an unenforced
 * filter other than {@code TRUE_CONSTANT}
 */
@Override
public PlanNode visitTableScan(TableScanNode tableScan, RewriteContext<Void> context)
{
    TableHandle table = tableScan.getTable();
    if (!isPushdownFilterSupported(session, table)) {
        return tableScan;
    }

    ConnectorPushdownFilterResult result = pushdownFilter(
            session,
            transactionToMetadata.apply(table),
            table.getConnectorHandle(),
            TRUE_CONSTANT,
            table.getLayout());
    if (result.getLayout().getPredicate().isNone()) {
        return getValuesNode(tableScan);
    }

    TableScanNode rewrittenScan = getTableScanNode(tableScan, table, result);
    RowExpression unenforced = result.getUnenforcedConstraint();
    if (TRUE_CONSTANT.equals(unenforced)) {
        return rewrittenScan;
    }

    // There is no FilterNode here that could carry an unenforced predicate.
    throw new PrestoException(
            GENERIC_INTERNAL_ERROR,
            format("Unenforced filter found %s but not handled", unenforced));
}
/**
 * Decomposes {@code filter}, merges it with the predicates already held by the current layout
 * (if any), and asks the subclass to build the resulting layout.
 *
 * @param session session used for metadata, translation and optimization calls
 * @param metadata metadata of the connector owning {@code tableHandle}
 * @param tableHandle table the filter applies to
 * @param filter predicate expressed over column-named variables; must not be {@code FALSE_CONSTANT}
 * @param currentLayoutHandle layout already attached to the table, if any; when present and the
 * filter is not trivially true it is cast to {@link BaseHiveTableLayoutHandle}
 * @return the existing layout with {@code TRUE_CONSTANT} when there is nothing to push, a result
 * built from {@code EMPTY_TABLE_LAYOUT} and {@code FALSE_CONSTANT} when the filter can never
 * match, otherwise whatever {@code getConnectorPushdownFilterResult} returns
 * @throws IllegalArgumentException if {@code filter} equals {@code FALSE_CONSTANT}
 */
public ConnectorPushdownFilterResult pushdownFilter(
        ConnectorSession session,
        ConnectorMetadata metadata,
        ConnectorTableHandle tableHandle,
        RowExpression filter,
        Optional<ConnectorTableLayoutHandle> currentLayoutHandle)
{
    checkArgument(!FALSE_CONSTANT.equals(filter), "Cannot pushdown filter that is always false");

    boolean hasCurrentLayout = currentLayoutHandle.isPresent();
    if (hasCurrentLayout && TRUE_CONSTANT.equals(filter)) {
        // Nothing new to push: keep the layout that is already there.
        return new ConnectorPushdownFilterResult(metadata.getTableLayout(session, currentLayoutHandle.get()), TRUE_CONSTANT);
    }

    // Split the filter into 3 groups of conjuncts:
    //  - range filters that apply to entire columns,
    //  - range filters that apply to subfields,
    //  - the rest. Intersect these with possibly pre-existing filters.
    ExtractionResult<Subfield> decomposedFilter = rowExpressionService.getDomainTranslator()
            .fromPredicate(session, filter, new SubfieldExtractor(
                    functionResolution,
                    rowExpressionService.getExpressionOptimizer(session),
                    session).toColumnExtractor());

    if (hasCurrentLayout) {
        BaseHiveTableLayoutHandle hiveLayout = (BaseHiveTableLayoutHandle) currentLayoutHandle.get();
        ExtractionResult alreadyPushed = new ExtractionResult(hiveLayout.getDomainPredicate(), hiveLayout.getRemainingPredicate());
        decomposedFilter = intersectExtractionResult(alreadyPushed, decomposedFilter);
    }

    if (decomposedFilter.getTupleDomain().isNone()) {
        return new ConnectorPushdownFilterResult(EMPTY_TABLE_LAYOUT, FALSE_CONSTANT);
    }

    RowExpression optimizedRemainingExpression = rowExpressionService.getExpressionOptimizer(session)
            .optimize(decomposedFilter.getRemainingExpression(), OPTIMIZED, session);
    if (optimizedRemainingExpression instanceof ConstantExpression) {
        ConstantExpression constant = (ConstantExpression) optimizedRemainingExpression;
        // A remaining expression that folds to FALSE or to a null value never selects a row.
        if (constant.getValue() == null || FALSE_CONSTANT.equals(constant)) {
            return new ConnectorPushdownFilterResult(EMPTY_TABLE_LAYOUT, FALSE_CONSTANT);
        }
    }

    Map<String, ColumnHandle> columnHandles = metadata.getColumnHandles(session, tableHandle);

    // Keep only the domains that constrain whole columns and re-key them by column handle.
    TupleDomain<ColumnHandle> wholeColumnDomain = decomposedFilter.getTupleDomain()
            .transform(subfield -> isEntireColumn(subfield) ? subfield.getRootName() : null)
            .transform(columnHandles::get);
    TupleDomain<ColumnHandle> entireColumnDomain = hasCurrentLayout
            ? wholeColumnDomain.intersect(((BaseHiveTableLayoutHandle) currentLayoutHandle.get()).getPartitionColumnPredicate())
            : wholeColumnDomain;

    // Extract deterministic conjuncts that apply to partition columns and specify these as Constraint#predicate
    Constraint<ColumnHandle> constraint = extractDeterministicConjuncts(
            session,
            decomposedFilter,
            columnHandles,
            entireColumnDomain,
            new Constraint<>(entireColumnDomain));

    RemainingExpressions remainingExpressions = getRemainingExpressions(tableHandle, decomposedFilter, columnHandles);

    return getConnectorPushdownFilterResult(
            columnHandles,
            metadata,
            session,
            remainingExpressions,
            decomposedFilter,
            optimizedRemainingExpression,
            constraint,
            currentLayoutHandle,
            tableHandle);
}
/**
 * Builds the connector-specific pushdown result. Called as the last step of
 * {@code pushdownFilter}, after the filter has been decomposed and found satisfiable.
 *
 * @param columnHandles result of {@code metadata.getColumnHandles(session, tableHandle)}
 * @param metadata metadata passed to {@code pushdownFilter}
 * @param session session passed to {@code pushdownFilter}
 * @param remainingExpressions dynamic and static parts of the remaining expression
 * @param decomposedFilter decomposed filter, already intersected with the current layout's predicates when a layout was present
 * @param optimizedRemainingExpression optimized form of {@code decomposedFilter.getRemainingExpression()}
 * @param constraint constraint over the entire-column domains, optionally carrying a predicate built from deterministic conjuncts
 * @param currentLayoutHandle layout handle passed to {@code pushdownFilter}
 * @param tableHandle table handle passed to {@code pushdownFilter}
 * @return the layout to use for the scan together with the constraint left unenforced
 */
protected abstract ConnectorPushdownFilterResult getConnectorPushdownFilterResult(
Map<String, ColumnHandle> columnHandles,
ConnectorMetadata metadata,
ConnectorSession session,
RemainingExpressions remainingExpressions,
ExtractionResult<Subfield> decomposedFilter,
RowExpression optimizedRemainingExpression,
Constraint<ColumnHandle> constraint,
Optional<ConnectorTableLayoutHandle> currentLayoutHandle,
ConnectorTableHandle tableHandle);
/**
 * Tells whether filter pushdown should be attempted for the given table. When this returns
 * false, {@code visitFilter} and {@code visitTableScan} return their input node unchanged.
 */
protected abstract boolean isPushdownFilterSupported(ConnectorSession session, TableHandle tableHandle);
/**
 * Looks up the name of a column through the connector metadata.
 *
 * @return the name reported by the column's metadata
 */
private static String getColumnName(
        ConnectorSession session,
        ConnectorMetadata metadata,
        ConnectorTableHandle tableHandle,
        ColumnHandle columnHandle)
{
    String columnName = metadata.getColumnMetadata(session, tableHandle, columnHandle).getName();
    return columnName;
}
/**
 * Creates a copy of {@code tableScan} whose table handle carries the layout produced by pushdown.
 * <p>
 * The node id, output variables, assignments, table constraints, source location and CTE
 * materialization info are taken from the original scan; the two domain arguments are the
 * layout's predicate and {@code TupleDomain.all()}.
 */
private static TableScanNode getTableScanNode(
        TableScanNode tableScan,
        TableHandle handle,
        ConnectorPushdownFilterResult pushdownFilterResult)
{
    ConnectorTableLayout pushedLayout = pushdownFilterResult.getLayout();

    // Same connector, table and transaction as before; only the layout is replaced.
    TableHandle handleWithLayout = new TableHandle(
            handle.getConnectorId(),
            handle.getConnectorHandle(),
            handle.getTransaction(),
            Optional.of(pushedLayout.getHandle()));

    return new TableScanNode(
            tableScan.getSourceLocation(),
            tableScan.getId(),
            handleWithLayout,
            tableScan.getOutputVariables(),
            tableScan.getAssignments(),
            tableScan.getTableConstraints(),
            pushedLayout.getPredicate(),
            TupleDomain.all(),
            tableScan.getCteMaterializationInfo());
}
/**
 * Combines two extraction results: the tuple domains are intersected and the remaining
 * expressions are AND-ed, skipping a side whose remaining expression is {@code TRUE_CONSTANT}.
 */
private static ExtractionResult intersectExtractionResult(
        ExtractionResult left,
        ExtractionResult right)
{
    RowExpression leftRemaining = left.getRemainingExpression();
    RowExpression rightRemaining = right.getRemainingExpression();

    // TRUE is the identity of AND, so a trivially-true side contributes nothing.
    RowExpression combinedRemaining = rightRemaining.equals(TRUE_CONSTANT)
            ? leftRemaining
            : leftRemaining.equals(TRUE_CONSTANT)
                    ? rightRemaining
                    : LogicalRowExpressions.and(leftRemaining, rightRemaining);

    return new ExtractionResult(left.getTupleDomain().intersect(right.getTupleDomain()), combinedRemaining);
}
/**
 * Attaches a predicate to the constraint when the remaining expression contains deterministic conjuncts.
 *
 * @param session session used when the predicate is evaluated
 * @param decomposedFilter decomposed filter whose remaining expression is inspected
 * @param columnHandles column handles keyed by column name, used to resolve variables
 * @param entireColumnDomain domain used for the newly created constraint
 * @param constraint constraint returned as-is when there is nothing deterministic to add
 * @return {@code constraint}, or a new constraint over {@code entireColumnDomain} whose predicate
 * is {@code ConstraintEvaluator#isCandidate} for the deterministic conjuncts
 */
private Constraint<ColumnHandle> extractDeterministicConjuncts(
        ConnectorSession session,
        ExtractionResult<Subfield> decomposedFilter,
        Map<String, ColumnHandle> columnHandles,
        TupleDomain<ColumnHandle> entireColumnDomain,
        Constraint<ColumnHandle> constraint)
{
    RowExpression remaining = decomposedFilter.getRemainingExpression();
    if (TRUE_CONSTANT.equals(remaining)) {
        return constraint;
    }

    LogicalRowExpressions logicalRowExpressions = new LogicalRowExpressions(
            rowExpressionService.getDeterminismEvaluator(),
            functionResolution,
            functionMetadataManager);
    RowExpression deterministicPredicate = logicalRowExpressions.filterDeterministicConjuncts(remaining);
    if (TRUE_CONSTANT.equals(deterministicPredicate)) {
        return constraint;
    }

    ConstraintEvaluator evaluator = new ConstraintEvaluator(rowExpressionService, session, columnHandles, deterministicPredicate);
    List<ColumnHandle> predicateInputs = ImmutableList.copyOf(evaluator.getArguments());
    return new Constraint<>(entireColumnDomain, Optional.of(evaluator::isCandidate), Optional.of(predicateInputs));
}
/**
 * Builds the empty {@link ValuesNode} that replaces a scan known to produce no rows, and records
 * a {@code HIVE_TABLESCAN_CONVERTED_TO_VALUESNODE} warning on the session.
 * <p>
 * The scan's connector handle is cast to {@link BaseHiveTableHandle} to obtain the table name.
 */
private ValuesNode getValuesNode(TableScanNode tableScan)
{
    ConnectorTableHandle connectorHandle = tableScan.getTable().getConnectorHandle();

    String warningMessage = format(
            "No rows from table '%s' matched the filter",
            ((BaseHiveTableHandle) connectorHandle).getTableName());
    session.getWarningCollector().add(new PrestoWarning(HIVE_TABLESCAN_CONVERTED_TO_VALUESNODE, warningMessage));

    // Same outputs as the scan, no rows.
    return new ValuesNode(
            tableScan.getSourceLocation(),
            idAllocator.getNextId(),
            tableScan.getOutputVariables(),
            ImmutableList.of(),
            Optional.of(connectorHandle.toString()));
}
/**
 * Splits the conjuncts of the remaining expression into a dynamic-filter part and a static part.
 * <p>
 * A conjunct is dynamic when {@code isDynamicFilter} accepts it or when {@code useDynamicFilter}
 * returns true for it. Each group is recombined with {@code combineConjuncts}; the static part
 * is additionally passed through {@code removeNestedDynamicFilters}.
 */
private RemainingExpressions getRemainingExpressions(
        ConnectorTableHandle tableHandle,
        ExtractionResult<Subfield> decomposedFilter,
        Map<String, ColumnHandle> columnHandles)
{
    LogicalRowExpressions logicalRowExpressions = new LogicalRowExpressions(
            rowExpressionService.getDeterminismEvaluator(),
            functionResolution,
            functionMetadataManager);

    ImmutableList.Builder<RowExpression> dynamicPart = ImmutableList.builder();
    ImmutableList.Builder<RowExpression> staticPart = ImmutableList.builder();
    for (RowExpression conjunct : extractConjuncts(decomposedFilter.getRemainingExpression())) {
        boolean dynamic = isDynamicFilter(conjunct) || useDynamicFilter(conjunct, tableHandle, columnHandles);
        (dynamic ? dynamicPart : staticPart).add(conjunct);
    }

    RowExpression dynamicFilterExpression = logicalRowExpressions.combineConjuncts(dynamicPart.build());
    RowExpression staticExpression = removeNestedDynamicFilters(logicalRowExpressions.combineConjuncts(staticPart.build()));
    return new RemainingExpressions(dynamicFilterExpression, staticExpression);
}
/**
 * Evaluates a predicate against partial column bindings to decide whether the bound values
 * could still satisfy it. {@code isCandidate} is used as the predicate of a {@code Constraint}.
 */
private static class ConstraintEvaluator
{
    private final Map<String, ColumnHandle> assignments;
    private final RowExpressionService evaluator;
    private final ConnectorSession session;
    private final RowExpression expression;
    // Column handles of every variable referenced by the expression.
    private final Set<ColumnHandle> arguments;

    /**
     * @param evaluator service providing the expression optimizer
     * @param session session passed to the optimizer
     * @param assignments column handles keyed by the variable (column) name used in {@code expression}
     * @param expression predicate to evaluate
     */
    public ConstraintEvaluator(
            RowExpressionService evaluator,
            ConnectorSession session,
            Map<String, ColumnHandle> assignments,
            RowExpression expression)
    {
        this.assignments = assignments;
        this.evaluator = evaluator;
        this.session = session;
        this.expression = expression;
        this.arguments = extractVariableExpressions(expression).stream()
                .map(VariableReferenceExpression::getName)
                .map(assignments::get)
                .collect(toImmutableSet());
    }

    public Set<ColumnHandle> getArguments()
    {
        return arguments;
    }

    /**
     * Returns false only when the expression, with the bound columns substituted, folds to a
     * constant that is FALSE or null; every other outcome keeps the candidate.
     */
    private boolean isCandidate(Map<ColumnHandle, NullableValue> bindings)
    {
        // None of the bound columns is used by the expression: nothing can be decided.
        if (intersection(bindings.keySet(), arguments).isEmpty()) {
            return true;
        }

        // Bound columns resolve to their value, unbound ones stay symbolic.
        Function<VariableReferenceExpression, Object> variableResolver = variable -> {
            ColumnHandle column = assignments.get(variable.getName());
            checkArgument(column != null, "Missing column assignment for %s", variable);
            return bindings.containsKey(column) ? bindings.get(column).getValue() : variable;
        };

        // Skip pruning if evaluation fails in a recoverable way. Failing here can cause
        // spurious query failures for partitions that would otherwise be filtered out.
        RowExpression optimized;
        try {
            optimized = evaluator.getExpressionOptimizer(session).optimize(expression, OPTIMIZED, session, variableResolver);
        }
        catch (PrestoException e) {
            propagateIfUnhandled(e);
            return true;
        }

        if (optimized instanceof ConstantExpression) {
            // If any conjuncts evaluate to FALSE or null, then the whole predicate will never be true and so the partition should be pruned
            ConstantExpression constant = (ConstantExpression) optimized;
            boolean neverTrue = Boolean.FALSE.equals(constant.getValue()) || constant.isNull();
            return !neverTrue;
        }
        return true;
    }

    /**
     * Rethrows {@code e} unless its error code is one of DIVISION_BY_ZERO, INVALID_CAST_ARGUMENT,
     * INVALID_FUNCTION_ARGUMENT or NUMERIC_VALUE_OUT_OF_RANGE.
     */
    private static void propagateIfUnhandled(PrestoException e)
            throws PrestoException
    {
        int errorCode = e.getErrorCode().getCode();
        boolean recoverable = errorCode == DIVISION_BY_ZERO.toErrorCode().getCode()
                || errorCode == INVALID_CAST_ARGUMENT.toErrorCode().getCode()
                || errorCode == INVALID_FUNCTION_ARGUMENT.toErrorCode().getCode()
                || errorCode == NUMERIC_VALUE_OUT_OF_RANGE.toErrorCode().getCode();
        if (!recoverable) {
            throw e;
        }
    }
}
/**
 * Collects the variable references found while traversing {@code expression}.
 *
 * @param expression expression to traverse
 * @return the set of variable references visited
 */
protected static Set<VariableReferenceExpression> extractVariableExpressions(RowExpression expression)
{
    ImmutableSet.Builder<VariableReferenceExpression> collected = ImmutableSet.builder();
    VariableReferenceBuilderVisitor collector = new VariableReferenceBuilderVisitor();
    expression.accept(collector, collected);
    return collected.build();
}
/**
 * Traversal visitor that adds every visited variable reference to the builder passed as context.
 */
private static class VariableReferenceBuilderVisitor
        extends DefaultRowExpressionTraversalVisitor<ImmutableSet.Builder<VariableReferenceExpression>>
{
    @Override
    public Void visitVariableReference(
            VariableReferenceExpression reference,
            ImmutableSet.Builder<VariableReferenceExpression> collected)
    {
        collected.add(reference);
        return null;
    }
}
/**
 * Hook deciding whether a conjunct that is not itself a dynamic filter should still be placed
 * in the dynamic-filter group by {@code getRemainingExpressions}.
 * <p>
 * This base implementation always returns false; the method is public and non-final, so a
 * subclass can override it.
 *
 * @param expression the conjunct under consideration
 * @param tableHandle table the filter applies to
 * @param columnHandleMap column handles keyed by column name
 * @return false
 */
public boolean useDynamicFilter(
RowExpression expression,
ConnectorTableHandle tableHandle,
Map<String, ColumnHandle> columnHandleMap)
{
return false;
}
/**
 * Outcome of a filter pushdown: the table layout to scan with, plus the part of the filter
 * that is reported as not enforced ({@code TRUE_CONSTANT} when nothing is left).
 */
public static class ConnectorPushdownFilterResult
{
private final ConnectorTableLayout layout;
private final RowExpression unenforcedConstraint;
/**
 * @throws NullPointerException if either argument is null
 */
public ConnectorPushdownFilterResult(ConnectorTableLayout layout, RowExpression unenforcedConstraint)
{
this.layout = requireNonNull(layout, "layout is null");
this.unenforcedConstraint = requireNonNull(unenforcedConstraint, "unenforcedConstraint is null");
}
// Layout produced by the pushdown; never null.
public ConnectorTableLayout getLayout()
{
return layout;
}
// Filter reported as not enforced; never null.
public RowExpression getUnenforcedConstraint()
{
return unenforcedConstraint;
}
}
/**
 * Pair of expressions obtained by splitting the remaining (non-domain) part of a filter into
 * its dynamic-filter conjuncts and its other conjuncts.
 */
public static class RemainingExpressions
{
    // Kept public for source compatibility with code that reads the fields directly;
    // new code should go through the getters.
    public final RowExpression dynamicFilterExpression;
    public final RowExpression remainingExpression;

    /**
     * @param dynamicFilterExpression conjunction of the dynamic-filter conjuncts
     * @param remainingExpression conjunction of all other conjuncts
     * @throws NullPointerException if either argument is null
     */
    public RemainingExpressions(RowExpression dynamicFilterExpression, RowExpression remainingExpression)
    {
        // Null checks match ConnectorPushdownFilterResult so a bad value fails at construction time.
        this.dynamicFilterExpression = requireNonNull(dynamicFilterExpression, "dynamicFilterExpression is null");
        this.remainingExpression = requireNonNull(remainingExpression, "remainingExpression is null");
    }

    public RowExpression getDynamicFilterExpression()
    {
        return dynamicFilterExpression;
    }

    public RowExpression getRemainingExpression()
    {
        return remainingExpression;
    }
}
}