DruidAggregationProjectConverter.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.druid;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.TypeManager;
import com.facebook.presto.spi.ConnectorSession;
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.function.FunctionMetadata;
import com.facebook.presto.spi.function.FunctionMetadataManager;
import com.facebook.presto.spi.function.StandardFunctionResolution;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slice;
import org.joda.time.DateTimeZone;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static com.facebook.presto.druid.DruidErrorCode.DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION;
import static com.facebook.presto.druid.DruidExpression.derived;
import static com.facebook.presto.druid.DruidPushdownUtils.getLiteralAsString;
import static java.lang.String.format;
import static java.util.Locale.ENGLISH;
import static java.util.Objects.requireNonNull;
public class DruidAggregationProjectConverter
extends DruidProjectExpressionConverter
{
private static final Map<String, String> PRESTO_TO_DRUID_OPERATORS = ImmutableMap.of(
"-", "SUB",
"+", "ADD",
"*", "MULT",
"/", "DIV");
private static final String FROM_UNIXTIME = "from_unixtime";
private static final String DATE_TRUNC = "date_trunc";
private final FunctionMetadataManager functionMetadataManager;
private final ConnectorSession session;
public DruidAggregationProjectConverter(
ConnectorSession session,
TypeManager typeManager,
FunctionMetadataManager functionMetadataManager,
StandardFunctionResolution standardFunctionResolution)
{
super(typeManager, standardFunctionResolution);
this.functionMetadataManager = requireNonNull(functionMetadataManager, "functionMetadataManager is null");
this.session = requireNonNull(session, "session is null");
}
@Override
public DruidExpression visitCall(
CallExpression call,
Map<VariableReferenceExpression, DruidQueryGeneratorContext.Selection> context)
{
Optional<DruidExpression> basicCallHandlingResult = basicCallHandling(call, context);
if (basicCallHandlingResult.isPresent()) {
return basicCallHandlingResult.get();
}
FunctionMetadata functionMetadata = functionMetadataManager.getFunctionMetadata(call.getFunctionHandle());
Optional<OperatorType> operatorTypeOptional = functionMetadata.getOperatorType();
if (operatorTypeOptional.isPresent()) {
OperatorType operatorType = operatorTypeOptional.get();
if (operatorType.isArithmeticOperator()) {
return handleArithmeticExpression(call, operatorType, context);
}
if (operatorType.isComparisonOperator()) {
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Unsupported operator: " + call + " to pushdown for Druid connector.");
}
}
return handleFunction(call, context);
}
@Override
public DruidExpression visitConstant(
ConstantExpression literal,
Map<VariableReferenceExpression, DruidQueryGeneratorContext.Selection> context)
{
return new DruidExpression(getLiteralAsString(session, literal), DruidQueryGeneratorContext.Origin.LITERAL);
}
private DruidExpression handleDateTruncationViaDateTruncation(
CallExpression function,
Map<VariableReferenceExpression, DruidQueryGeneratorContext.Selection> context)
{
RowExpression timeInputParameter = function.getArguments().get(1);
String inputColumn;
String inputTimeZone;
String inputFormat;
CallExpression timeConversion = getExpressionAsFunction(timeInputParameter, timeInputParameter);
if (!timeConversion.getDisplayName().toLowerCase(ENGLISH).equals(FROM_UNIXTIME)) {
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Unsupported time function: " + timeConversion.getDisplayName() + " to pushdown for Druid connector.");
}
inputColumn = timeConversion.getArguments().get(0).accept(this, context).getDefinition();
inputTimeZone = timeConversion.getArguments().size() > 1 ? getStringFromConstant(timeConversion.getArguments().get(1)) : DateTimeZone.UTC.getID();
inputFormat = "seconds";
RowExpression intervalParameter = function.getArguments().get(0);
if (!(intervalParameter instanceof ConstantExpression)) {
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Unsupported interval unit: " + intervalParameter + " to pushdown for Druid connector.");
}
return derived("dateTrunc(" + inputColumn + "," + inputFormat + ", " + inputTimeZone + ", " + getStringFromConstant(intervalParameter) + ")");
}
private DruidExpression handleArithmeticExpression(
CallExpression expression,
OperatorType operatorType,
Map<VariableReferenceExpression, DruidQueryGeneratorContext.Selection> context)
{
List<RowExpression> arguments = expression.getArguments();
if (arguments.size() == 1) {
String prefix = operatorType == OperatorType.NEGATION ? "-" : "";
return derived(prefix + arguments.get(0).accept(this, context).getDefinition());
}
if (arguments.size() == 2) {
DruidExpression left = arguments.get(0).accept(this, context);
DruidExpression right = arguments.get(1).accept(this, context);
String prestoOperator = operatorType.getOperator();
String druidOperator = PRESTO_TO_DRUID_OPERATORS.get(prestoOperator);
if (druidOperator == null) {
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Unsupported binary expression: " + prestoOperator + " to pushdown for Druid connector.");
}
return derived(format("%s(%s, %s)", druidOperator, left.getDefinition(), right.getDefinition()));
}
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Unsupported arithmetic expression: " + expression + " to pushdown for Druid connector.");
}
private DruidExpression handleFunction(
CallExpression function,
Map<VariableReferenceExpression, DruidQueryGeneratorContext.Selection> context)
{
if (function.getDisplayName().toLowerCase(ENGLISH).equals(DATE_TRUNC)) {
return handleDateTruncationViaDateTruncation(function, context);
}
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Unsupported function: " + function.getDisplayName() + " to pushdown for Druid connector.");
}
private static String getStringFromConstant(RowExpression expression)
{
if (expression instanceof ConstantExpression) {
Object value = ((ConstantExpression) expression).getValue();
if (value instanceof String) {
return (String) value;
}
if (value instanceof Slice) {
return ((Slice) value).toStringUtf8();
}
}
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Expected string literal but found: " + expression + " to pushdown for Druid connector.");
}
private CallExpression getExpressionAsFunction(
RowExpression originalExpression,
RowExpression expression)
{
if (expression instanceof CallExpression) {
CallExpression call = (CallExpression) expression;
if (standardFunctionResolution.isCastFunction(call.getFunctionHandle())) {
if (isImplicitCast(call.getArguments().get(0).getType(), call.getType())) {
return getExpressionAsFunction(originalExpression, call.getArguments().get(0));
}
}
else {
return call;
}
}
throw new PrestoException(DRUID_PUSHDOWN_UNSUPPORTED_EXPRESSION, "Could not dig function out of expression: " + originalExpression + ", inside of: " + expression + " to pushdown for Druid connector.");
}
}