BytecodeUtils.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.sql.gen;

import com.facebook.presto.bytecode.Binding;
import com.facebook.presto.bytecode.BytecodeBlock;
import com.facebook.presto.bytecode.BytecodeNode;
import com.facebook.presto.bytecode.CallSiteBinder;
import com.facebook.presto.bytecode.Scope;
import com.facebook.presto.bytecode.Variable;
import com.facebook.presto.bytecode.control.IfStatement;
import com.facebook.presto.bytecode.expression.BytecodeExpression;
import com.facebook.presto.bytecode.instruction.LabelNode;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.function.SqlFunctionProperties;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.operator.scalar.BuiltInScalarFunctionImplementation;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice;
import com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentProperty;
import com.facebook.presto.spi.function.JavaScalarFunctionImplementation;
import com.facebook.presto.sql.gen.InputReferenceCompiler.InputReferenceNode;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.primitives.Primitives;
import io.airlift.slice.Slice;

import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;

import static com.facebook.presto.bytecode.OpCode.NOP;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantFalse;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.constantTrue;
import static com.facebook.presto.bytecode.expression.BytecodeExpressions.invokeDynamic;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ArgumentType.VALUE_TYPE;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.NullConvention.BLOCK_AND_POSITION;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ReturnPlaceConvention.PROVIDED_BLOCKBUILDER;
import static com.facebook.presto.operator.scalar.ScalarFunctionImplementationChoice.ReturnPlaceConvention.STACK;
import static com.facebook.presto.sql.gen.Bootstrap.BOOTSTRAP_METHOD;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.String.format;
import static java.util.Objects.requireNonNull;

public final class BytecodeUtils
{
    private BytecodeUtils()
    {
    }

    public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class<?> returnType, Class<?>... stackArgsToPop)
    {
        return handleNullValue(scope, label, returnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), false);
    }

    public static BytecodeNode ifWasNullPopAndGoto(Scope scope, LabelNode label, Class<?> methodReturnType, Iterable<? extends Class<?>> stackArgsToPop)
    {
        return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), false);
    }

    public static BytecodeNode ifWasNullClearPopAndGoto(Scope scope, LabelNode label, Class<?> methodReturnType, Class<?>... stackArgsToPop)
    {
        return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.empty(), true);
    }

    public static BytecodeNode ifWasNullClearPopAppendAndGoto(Scope scope, LabelNode label, Class<?> methodReturnType, Variable outputBlockVariable, Iterable<? extends Class<?>> stackArgsToPop)
    {
        return handleNullValue(scope, label, methodReturnType, ImmutableList.copyOf(stackArgsToPop), Optional.of(outputBlockVariable), true);
    }

    public static BytecodeNode handleNullValue(Scope scope,
            LabelNode label,
            Class<?> methodReturnType,
            List<Class<?>> stackArgsToPop,
            Optional<Variable> outputBlockVariable,
            boolean clearNullFlag)
    {
        if (outputBlockVariable.isPresent()) {
            checkArgument(methodReturnType == void.class);
        }

        Variable wasNull = scope.getVariable("wasNull");

        BytecodeBlock nullCheck = new BytecodeBlock()
                .setDescription("ifWasNullGoto")
                .append(wasNull);

        String clearComment = null;
        if (clearNullFlag) {
            nullCheck.append(wasNull.set(constantFalse()));
            clearComment = "clear wasNull";
        }

        BytecodeBlock isNull = new BytecodeBlock();
        for (Class<?> parameterType : stackArgsToPop) {
            isNull.pop(parameterType);
        }

        String loadDefaultOrAppendNullComment;
        if (!outputBlockVariable.isPresent()) {
            isNull.pushJavaDefault(methodReturnType);
            loadDefaultOrAppendNullComment = format("loadJavaDefault(%s)", methodReturnType.getName());
        }
        else {
            isNull.append(outputBlockVariable.get()
                    .invoke("appendNull", BlockBuilder.class)
                    .pop());
            loadDefaultOrAppendNullComment = "appendNullToOutputBlock";
        }

        isNull.gotoLabel(label);

        String popComment = null;
        if (!stackArgsToPop.isEmpty()) {
            popComment = format("pop(%s)", Joiner.on(", ").join(stackArgsToPop));
        }

        return new IfStatement("if wasNull then %s", Joiner.on(", ").skipNulls().join(clearComment, popComment, loadDefaultOrAppendNullComment, "goto " + label.getLabel()))
                .condition(nullCheck)
                .ifTrue(isNull);
    }

    public static BytecodeNode boxPrimitive(Class<?> type)
    {
        BytecodeBlock block = new BytecodeBlock().comment("box primitive");
        if (type == long.class) {
            return block.invokeStatic(Long.class, "valueOf", Long.class, long.class);
        }
        if (type == double.class) {
            return block.invokeStatic(Double.class, "valueOf", Double.class, double.class);
        }
        if (type == boolean.class) {
            return block.invokeStatic(Boolean.class, "valueOf", Boolean.class, boolean.class);
        }
        if (type.isPrimitive()) {
            throw new UnsupportedOperationException("not yet implemented: " + type);
        }

        return NOP;
    }

    public static BytecodeNode unboxPrimitive(Class<?> unboxedType)
    {
        BytecodeBlock block = new BytecodeBlock().comment("unbox primitive");
        if (unboxedType == long.class) {
            return block.invokeVirtual(Long.class, "longValue", long.class);
        }
        if (unboxedType == double.class) {
            return block.invokeVirtual(Double.class, "doubleValue", double.class);
        }
        if (unboxedType == boolean.class) {
            return block.invokeVirtual(Boolean.class, "booleanValue", boolean.class);
        }
        throw new UnsupportedOperationException("not yet implemented: " + unboxedType);
    }

    public static BytecodeExpression loadConstant(CallSiteBinder callSiteBinder, Object constant, Class<?> type)
    {
        Binding binding = callSiteBinder.bind(MethodHandles.constant(type, constant));
        return loadConstant(binding);
    }

    public static BytecodeExpression loadConstant(Binding binding)
    {
        return invokeDynamic(
                BOOTSTRAP_METHOD,
                ImmutableList.of(binding.getBindingId()),
                "constant_" + binding.getBindingId(),
                binding.getType().returnType());
    }

    public static BytecodeNode generateInvocation(
            Scope scope,
            String name,
            JavaScalarFunctionImplementation function,
            Optional<BytecodeNode> instance,
            List<BytecodeNode> arguments,
            CallSiteBinder binder)
    {
        return generateInvocation(
                scope,
                name,
                function,
                instance,
                arguments,
                binder,
                Optional.empty());
    }

    public static BytecodeNode generateInvocation(
            Scope scope,
            String name,
            JavaScalarFunctionImplementation function,
            Optional<BytecodeNode> instance,
            List<BytecodeNode> arguments,
            CallSiteBinder binder,
            Optional<OutputBlockVariableAndType> outputBlockVariableAndType)
    {
        LabelNode end = new LabelNode("end");
        BytecodeBlock block = new BytecodeBlock()
                .setDescription("invoke " + name);

        List<Class<?>> stackTypes = new ArrayList<>();
        if (function instanceof BuiltInScalarFunctionImplementation && ((BuiltInScalarFunctionImplementation) function).getInstanceFactory().isPresent()) {
            checkArgument(instance.isPresent());
        }

        // Index of current parameter in the MethodHandle
        int currentParameterIndex = 0;

        // Index of parameter (without @IsNull) in Presto function
        int realParameterIndex = 0;

        // Go through all the choices in the function and then pick the best one
        List<ScalarFunctionImplementationChoice> choices = getAllScalarFunctionImplementationChoices(function);
        ScalarFunctionImplementationChoice bestChoice = null;
        for (ScalarFunctionImplementationChoice currentChoice : choices) {
            boolean isValid = true;
            for (int i = 0; i < arguments.size(); i++) {
                if (currentChoice.getArgumentProperty(i).getArgumentType() != VALUE_TYPE) {
                    continue;
                }
                if (currentChoice.getArgumentProperty(i).getNullConvention() == BLOCK_AND_POSITION && !(arguments.get(i) instanceof InputReferenceNode)
                        || currentChoice.getReturnPlaceConvention() == PROVIDED_BLOCKBUILDER && (!outputBlockVariableAndType.isPresent())) {
                    isValid = false;
                    break;
                }
            }
            if (isValid) {
                bestChoice = currentChoice;
            }
        }

        checkState(bestChoice != null, "None of the scalar function implementation choices are valid");
        Binding binding = binder.bind(bestChoice.getMethodHandle());

        MethodType methodType = binding.getType();
        Class<?> returnType = methodType.returnType();
        Class<?> unboxedReturnType = Primitives.unwrap(returnType);

        boolean boundInstance = false;
        while (currentParameterIndex < methodType.parameterArray().length) {
            Class<?> type = methodType.parameterArray()[currentParameterIndex];
            stackTypes.add(type);
            if (bestChoice.getInstanceFactory().isPresent() && !boundInstance) {
                checkState(type.equals(bestChoice.getInstanceFactory().get().type().returnType()), "Mismatched type for instance parameter");
                block.append(instance.get());
                boundInstance = true;
            }
            else if (type == SqlFunctionProperties.class) {
                block.append(scope.getVariable("properties"));
            }
            else if (type == BlockBuilder.class) {
                block.append(outputBlockVariableAndType.get().getOutputBlockVariable());
            }
            else {
                ArgumentProperty argumentProperty = bestChoice.getArgumentProperty(realParameterIndex);
                switch (argumentProperty.getArgumentType()) {
                    case VALUE_TYPE:
                        // Apply null convention for value type argument
                        switch (argumentProperty.getNullConvention()) {
                            case RETURN_NULL_ON_NULL:
                                block.append(arguments.get(realParameterIndex));
                                checkArgument(!Primitives.isWrapperType(type), "Non-nullable argument must not be primitive wrapper type");
                                switch (bestChoice.getReturnPlaceConvention()) {
                                    case STACK:
                                        block.append(ifWasNullPopAndGoto(scope, end, unboxedReturnType, Lists.reverse(stackTypes)));
                                        break;
                                    case PROVIDED_BLOCKBUILDER:
                                        checkArgument(unboxedReturnType == void.class);
                                        block.append(ifWasNullClearPopAppendAndGoto(scope, end, unboxedReturnType, outputBlockVariableAndType.get().getOutputBlockVariable(), Lists.reverse(stackTypes)));
                                        break;
                                    default:
                                        throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention()));
                                }
                                break;
                            case USE_NULL_FLAG:
                                block.append(arguments.get(realParameterIndex));
                                block.append(scope.getVariable("wasNull"));
                                block.append(scope.getVariable("wasNull").set(constantFalse()));
                                stackTypes.add(boolean.class);
                                currentParameterIndex++;
                                break;
                            case USE_BOXED_TYPE:
                                block.append(arguments.get(realParameterIndex));
                                block.append(boxPrimitiveIfNecessary(scope, type));
                                block.append(scope.getVariable("wasNull").set(constantFalse()));
                                break;
                            case BLOCK_AND_POSITION:
                                InputReferenceNode inputReferenceNode = (InputReferenceNode) arguments.get(realParameterIndex);
                                block.append(inputReferenceNode.produceBlockAndPosition());
                                stackTypes.add(int.class);
                                currentParameterIndex++;
                                break;
                            default:
                                throw new UnsupportedOperationException(format("Unsupported null convention: %s", argumentProperty.getNullConvention()));
                        }
                        break;
                    case FUNCTION_TYPE:
                        block.append(arguments.get(realParameterIndex));
                        break;
                    default:
                        throw new UnsupportedOperationException(format("Unsupported argument type: %s", argumentProperty.getArgumentType()));
                }
                realParameterIndex++;
            }
            currentParameterIndex++;
        }
        block.append(invoke(binding, name));

        if (bestChoice.isNullable()) {
            switch (bestChoice.getReturnPlaceConvention()) {
                case STACK:
                    block.append(unboxPrimitiveIfNecessary(scope, returnType));
                    break;
                case PROVIDED_BLOCKBUILDER:
                    // no-op
                    break;
                default:
                    throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention()));
            }
        }
        block.visitLabel(end);

        if (outputBlockVariableAndType.isPresent()) {
            switch (bestChoice.getReturnPlaceConvention()) {
                case STACK:
                    block.append(generateWrite(binder, scope, scope.getVariable("wasNull"), outputBlockVariableAndType.get().getType(), outputBlockVariableAndType.get().getOutputBlockVariable()));
                    break;
                case PROVIDED_BLOCKBUILDER:
                    // no-op
                    break;
                default:
                    throw new UnsupportedOperationException(format("Unsupported return place convention: %s", bestChoice.getReturnPlaceConvention()));
            }
        }
        return block;
    }

    public static BytecodeBlock unboxPrimitiveIfNecessary(Scope scope, Class<?> boxedType)
    {
        BytecodeBlock block = new BytecodeBlock();
        LabelNode end = new LabelNode("end");
        Class<?> unboxedType = Primitives.unwrap(boxedType);
        Variable wasNull = scope.getVariable("wasNull");

        if (unboxedType.isPrimitive()) {
            LabelNode notNull = new LabelNode("notNull");
            block.dup(boxedType)
                    .ifNotNullGoto(notNull)
                    .append(wasNull.set(constantTrue()))
                    .comment("swap boxed null with unboxed default")
                    .pop(boxedType)
                    .pushJavaDefault(unboxedType)
                    .gotoLabel(end)
                    .visitLabel(notNull)
                    .append(unboxPrimitive(unboxedType));
        }
        else {
            block.dup(boxedType)
                    .ifNotNullGoto(end)
                    .append(wasNull.set(constantTrue()));
        }
        block.visitLabel(end);

        return block;
    }

    public static BytecodeNode boxPrimitiveIfNecessary(Scope scope, Class<?> type)
    {
        checkArgument(!type.isPrimitive(), "cannot box into primitive type");
        if (!Primitives.isWrapperType(type)) {
            return NOP;
        }
        BytecodeBlock notNull = new BytecodeBlock().comment("box primitive");
        Class<?> expectedCurrentStackType;
        if (type == Long.class) {
            notNull.invokeStatic(Long.class, "valueOf", Long.class, long.class);
            expectedCurrentStackType = long.class;
        }
        else if (type == Double.class) {
            notNull.invokeStatic(Double.class, "valueOf", Double.class, double.class);
            expectedCurrentStackType = double.class;
        }
        else if (type == Boolean.class) {
            notNull.invokeStatic(Boolean.class, "valueOf", Boolean.class, boolean.class);
            expectedCurrentStackType = boolean.class;
        }
        else {
            throw new UnsupportedOperationException("not yet implemented: " + type);
        }

        BytecodeBlock condition = new BytecodeBlock().append(scope.getVariable("wasNull"));

        BytecodeBlock wasNull = new BytecodeBlock()
                .pop(expectedCurrentStackType)
                .pushNull()
                .checkCast(type);

        return new IfStatement()
                .condition(condition)
                .ifTrue(wasNull)
                .ifFalse(notNull);
    }

    public static BytecodeExpression invoke(Binding binding, String name)
    {
        return invokeDynamic(BOOTSTRAP_METHOD, ImmutableList.of(binding.getBindingId()), name, binding.getType());
    }

    public static BytecodeNode generateWrite(
            CallSiteBinder callSiteBinder,
            Scope scope,
            Variable wasNullVariable,
            Type type,
            Variable outputBlockVariable)
    {
        Class<?> valueJavaType = type.getJavaType();
        if (!valueJavaType.isPrimitive() && valueJavaType != Slice.class) {
            valueJavaType = Object.class;
        }
        String methodName = "write" + Primitives.wrap(valueJavaType).getSimpleName();

        // the value to be written is at the top of stack
        Variable tempValue = scope.createTempVariable(valueJavaType);
        return new BytecodeBlock()
                .comment("if (wasNull)")
                .append(new IfStatement()
                        .condition(wasNullVariable)
                        .ifTrue(new BytecodeBlock()
                                .comment("output.appendNull();")
                                .pop(valueJavaType)
                                .getVariable(outputBlockVariable)
                                .invokeInterface(BlockBuilder.class, "appendNull", BlockBuilder.class)
                                .pop())
                        .ifFalse(new BytecodeBlock()
                                .comment("%s.%s(output, %s)", type.getTypeSignature(), methodName, valueJavaType.getSimpleName())
                                .putVariable(tempValue)
                                .append(loadConstant(callSiteBinder.bind(type, Type.class)))
                                .getVariable(outputBlockVariable)
                                .getVariable(tempValue)
                                .invokeInterface(Type.class, methodName, void.class, BlockBuilder.class, valueJavaType)));
    }

    public static List<ScalarFunctionImplementationChoice> getAllScalarFunctionImplementationChoices(JavaScalarFunctionImplementation function)
    {
        if (function instanceof BuiltInScalarFunctionImplementation) {
            return ((BuiltInScalarFunctionImplementation) function).getAllChoices();
        }
        return ImmutableList.of(new ScalarFunctionImplementationChoice(
                function.isNullable(),
                function.getInvocationConvention().getArgumentConventions().stream().map(ArgumentProperty::valueTypeArgumentProperty).collect(toImmutableList()),
                STACK,
                function.getMethodHandle(),
                Optional.empty()));
    }

    public static class OutputBlockVariableAndType
    {
        private final Variable outputBlockVariable;
        private final Type type;

        public OutputBlockVariableAndType(Variable outputBlockVariable, Type type)
        {
            this.outputBlockVariable = requireNonNull(outputBlockVariable);
            this.type = requireNonNull(type);
        }

        public Variable getOutputBlockVariable()
        {
            return outputBlockVariable;
        }

        public Type getType()
        {
            return type;
        }
    }
}