SwitchCodeGenerator.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.gen;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.instruction.LabelNode;
import com.facebook.presto.bytecode.instruction.VariableInstruction;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.function.FunctionHandle;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.ConstantExpression;
import com.facebook.presto.spi.relation.RowExpression;
import com.facebook.presto.spi.relation.SpecialFormExpression;
import com.google.common.collect.ImmutableList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue;
import static com.facebook.presto.common.function.OperatorType.EQUAL;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.WHEN;
import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
import static com.facebook.presto.sql.gen.SpecialFormBytecodeGenerator.generateWrite;
import static com.google.common.base.Preconditions.checkArgument;
/**
 * Bytecode generator for the {@code SWITCH} special form, which covers both the simple
 * {@code CASE value WHEN v1 THEN r1 ...} and the searched {@code CASE WHEN p1 THEN r1 ...} forms.
 */
public class SwitchCodeGenerator
        implements SpecialFormBytecodeGenerator
{
    private static final String RESULT_LABEL_PREFIX = "_result_";

    // TODO - move this to a RowExpressionUtil class
    /**
     * Returns true if {@code expression} is a two-argument {@link CallExpression} whose display name equals the
     * object name of the {@code EQUAL} operator.
     * NOTE(review): not referenced anywhere in this class.
     */
    private static boolean isEqualsExpression(RowExpression expression)
    {
        return expression instanceof CallExpression
                && ((CallExpression) expression).getDisplayName().equals(EQUAL.getFunctionName().getObjectName())
                && ((CallExpression) expression).getArguments().size() == 2;
    }

    /**
     * Generates bytecode for a {@code SWITCH} special form.
     *
     * <p>{@code arguments} is laid out as {@code [value, WHEN_1, ..., WHEN_n, else?]}:
     * <ul>
     * <li>the first element is the value being switched on;</li>
     * <li>each {@code WHEN} clause is a {@link SpecialFormExpression} whose first argument is the operand and
     * whose second argument is the result;</li>
     * <li>the optional trailing element is the else result.</li>
     * </ul>
     * If the last argument is itself a {@code WHEN} clause, there is no else result. An unmatched switch then sets
     * {@code wasNull} and pushes the Java default for {@code returnType}.
     *
     * <p>A constant boolean {@code TRUE} value marks a searched case. In that case each operand is evaluated
     * directly as the predicate.
     *
     * <p>Otherwise the value is evaluated once and stored in a temporary variable, and each operand is compared
     * to it with the {@code EQUAL} operator. A null value jumps straight to the else branch.
     *
     * <p>Clauses are tested in order, and the first predicate that is non-null and true selects its result.
     * Clauses whose result expressions are equal share one generated result body.
     *
     * @param generatorContext supplies the scope, the {@code wasNull} variable and sub-expression generation
     * @param returnType type of the whole switch expression
     * @param arguments the value, the WHEN clauses and the optional else result, as described above
     * @param outputBlockVariable if present, {@code generateWrite} for this variable is appended after the result
     * @return the generated bytecode
     * @throws IllegalArgumentException if fewer than two arguments are given, or a clause before the optional
     * else result is not a WHEN form
     */
    @Override
    public BytecodeNode generateExpression(BytecodeGeneratorContext generatorContext, Type returnType, List<RowExpression> arguments, Optional<Variable> outputBlockVariable)
    {
        checkArgument(arguments.size() >= 2, "SWITCH requires a value followed by WHEN clauses and/or an else result, but got %s arguments", arguments.size());

        Scope scope = generatorContext.getScope();

        // Split the arguments into the WHEN clauses and the else result.
        BytecodeNode elseValue;
        List<RowExpression> whenClauses;
        RowExpression last = arguments.get(arguments.size() - 1);
        if (last instanceof SpecialFormExpression && ((SpecialFormExpression) last).getForm().equals(WHEN)) {
            // No explicit else result: an unmatched switch evaluates to null.
            whenClauses = arguments.subList(1, arguments.size());
            elseValue = new BytecodeBlock()
                    .append(generatorContext.wasNull().set(constantTrue()))
                    .pushJavaDefault(returnType.getJavaType());
        }
        else {
            whenClauses = arguments.subList(1, arguments.size() - 1);
            elseValue = generatorContext.generate(last, Optional.empty());
        }

        // determine the type of the value and result
        RowExpression value = arguments.get(0);
        Class<?> valueType = value.getType().getJavaType();

        // We generate SearchedCase as CASE TRUE WHEN p1 THEN v1 WHEN p2 THEN v2...
        // The boxed constant is compared with equals() rather than by reference.
        boolean searchedCase = value instanceof ConstantExpression
                && BOOLEAN.equals(((ConstantExpression) value).getType())
                && Boolean.TRUE.equals(((ConstantExpression) value).getValue());

        LabelNode elseLabel = new LabelNode("else");
        LabelNode endLabel = new LabelNode("end");
        BytecodeBlock block = new BytecodeBlock();

        // Simple case: evaluate the value once and store it in a temp variable.
        // If the value is null, jump straight to the else branch.
        Optional<BytecodeNode> getTempVariableNode;
        if (!searchedCase) {
            BytecodeNode valueBytecode = generatorContext.generate(value, Optional.empty());
            Variable tempVariable = scope.createTempVariable(valueType);
            block.append(valueBytecode)
                    .append(BytecodeUtils.ifWasNullClearPopAndGoto(scope, elseLabel, void.class, valueType))
                    .putVariable(tempVariable);
            getTempVariableNode = Optional.of(VariableInstruction.loadVariable(tempVariable));
        }
        else {
            getTempVariableNode = Optional.empty();
        }

        Variable wasNull = generatorContext.wasNull();
        block.putVariable(wasNull, false);

        // Result labels are keyed by result expression, so clauses with equal results share one result body.
        Map<RowExpression, LabelNode> resultLabels = new HashMap<>();
        // We already know the P1 .. Pn are all boolean: just evaluate them and search for true (false/null don't match).
        for (RowExpression clause : whenClauses) {
            checkArgument(
                    clause instanceof SpecialFormExpression && ((SpecialFormExpression) clause).getForm().equals(WHEN),
                    "Expected a WHEN clause but got %s",
                    clause);

            RowExpression operand = ((SpecialFormExpression) clause).getArguments().get(0);
            BytecodeNode operandBytecode;
            if (searchedCase) {
                operandBytecode = generatorContext.generate(operand, Optional.empty());
            }
            else {
                // Call equals(value, operand). The pushed argument order must match the (value, operand)
                // signature resolved here, otherwise argument types and null-handling pops would disagree.
                FunctionHandle equalsFunction = generatorContext.getFunctionManager().resolveOperator(EQUAL, fromTypes(value.getType(), operand.getType()));
                operandBytecode = generatorContext.generateCall(
                        EQUAL.name(),
                        generatorContext.getFunctionManager().getJavaScalarFunctionImplementation(equalsFunction),
                        ImmutableList.of(
                                getTempVariableNode.get(),
                                generatorContext.generate(operand, Optional.empty())));
            }

            block.append(operandBytecode);

            // A null predicate means "no match": reset wasNull, discard the boolean and fall through to the next clause.
            IfStatement ifWasNull = new IfStatement().condition(wasNull);
            ifWasNull.ifTrue()
                    .putVariable(wasNull, false)
                    .pop(boolean.class); // pop the result of the predicate eval

            // Here the TOS is the result of the predicate.
            RowExpression result = ((SpecialFormExpression) clause).getArguments().get(1);
            LabelNode target = resultLabels.get(result);
            if (target == null) {
                target = new LabelNode(RESULT_LABEL_PREFIX + resultLabels.size());
                resultLabels.put(result, target);
            }

            // Non-null predicate: jump to the result body when true, otherwise continue with the next clause.
            ifWasNull.ifFalse().ifTrueGoto(target);
            block.append(ifWasNull);
        }

        // Here we evaluate the else result.
        block.visitLabel(elseLabel)
                .append(elseValue)
                .gotoLabel(endLabel);

        // Now generate the result expression code.
        for (Map.Entry<RowExpression, LabelNode> resultLabel : resultLabels.entrySet()) {
            block.visitLabel(resultLabel.getValue())
                    .append(generatorContext.generate(resultLabel.getKey(), Optional.empty()))
                    .gotoLabel(endLabel);
        }

        block.visitLabel(endLabel);
        outputBlockVariable.ifPresent(output -> block.append(generateWrite(generatorContext, returnType, output)));
        return block;
    }
}