ObjectEncoders.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.hive.functions.type;

import com.facebook.presto.common.PageBuilder;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.DuplicateMapKeyException;
import com.facebook.presto.common.block.MapBlockBuilder;
import com.facebook.presto.common.block.RowBlockBuilder;
import com.facebook.presto.common.block.SingleRowBlockWriter;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.DecimalType;
import com.facebook.presto.common.type.Decimals;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.PrestoException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Streams;
import io.airlift.slice.Slices;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveCharObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveDecimalObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveVarcharObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;

import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.function.Function;
import java.util.stream.Collectors;

import static com.facebook.presto.common.type.StandardTypes.ARRAY;
import static com.facebook.presto.common.type.StandardTypes.BIGINT;
import static com.facebook.presto.common.type.StandardTypes.BOOLEAN;
import static com.facebook.presto.common.type.StandardTypes.CHAR;
import static com.facebook.presto.common.type.StandardTypes.DATE;
import static com.facebook.presto.common.type.StandardTypes.DECIMAL;
import static com.facebook.presto.common.type.StandardTypes.DOUBLE;
import static com.facebook.presto.common.type.StandardTypes.INTEGER;
import static com.facebook.presto.common.type.StandardTypes.MAP;
import static com.facebook.presto.common.type.StandardTypes.REAL;
import static com.facebook.presto.common.type.StandardTypes.ROW;
import static com.facebook.presto.common.type.StandardTypes.SMALLINT;
import static com.facebook.presto.common.type.StandardTypes.TIMESTAMP;
import static com.facebook.presto.common.type.StandardTypes.TINYINT;
import static com.facebook.presto.common.type.StandardTypes.VARBINARY;
import static com.facebook.presto.common.type.StandardTypes.VARCHAR;
import static com.facebook.presto.common.type.TypeUtils.writeNativeValue;
import static com.facebook.presto.hive.functions.HiveFunctionErrorCode.unsupportedType;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.google.common.base.Preconditions.checkArgument;
import static java.lang.Float.floatToRawIntBits;
import static java.util.Objects.requireNonNull;

public final class ObjectEncoders
{
    // Holder for static factory methods and nested encoder classes; never instantiated.
    private ObjectEncoders() {}

    /**
     * Builds an {@link ObjectEncoder} that turns values read through a Hive
     * {@link ObjectInspector} into the Java representation written for the given Presto type.
     *
     * <p>Dispatch is on the base name of the type signature. Primitive types require a
     * {@link PrimitiveObjectInspector} (decimals a {@link HiveDecimalObjectInspector});
     * row, array and map types are delegated to the nested struct, list and map encoders.
     * Encoders built here for primitive types map a {@code null} input, or a {@code null}
     * produced by the inspector, to {@code null}.
     *
     * <p>If no encoder exists for the type/inspector combination, the exception created by
     * {@code unsupportedType(type)} is thrown.
     *
     * @param type Presto type the encoded values are destined for
     * @param inspector Hive inspector describing the objects to encode
     * @return an encoder for the given type/inspector pair
     * @throws IllegalArgumentException if the inspector is not of the kind the type requires
     */
    public static ObjectEncoder createEncoder(Type type, ObjectInspector inspector)
    {
        String base = type.getTypeSignature().getBase();
        switch (base) {
            case BIGINT:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> ((Long) o));
            case INTEGER:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> ((Integer) o).longValue());
            case SMALLINT:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> ((Short) o).longValue());
            case TINYINT:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> ((Byte) o).longValue());
            case BOOLEAN:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> ((Boolean) o));
            case DATE:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                // Presto DATE values are a count of days since 1970-01-01, not epoch milliseconds.
                // java.sql.Date models a calendar date in the JVM default zone, so convert through
                // LocalDate rather than dividing the raw millisecond value.
                return compose(primitive(inspector), o -> ((Date) o).toLocalDate().toEpochDay());
            case DECIMAL:
                // Validate up front like the other primitive cases instead of failing on the first value
                checkArgument(inspector instanceof HiveDecimalObjectInspector);
                if (Decimals.isShortDecimal(type)) {
                    DecimalType decimalType = (DecimalType) type;
                    return compose(decimal(inspector), o -> DecimalUtils.encodeToLong((BigDecimal) o, decimalType));
                }
                else if (Decimals.isLongDecimal(type)) {
                    DecimalType decimalType = (DecimalType) type;
                    return compose(decimal(inspector), o -> DecimalUtils.encodeToSlice((BigDecimal) o, decimalType));
                }
                break;
            case REAL:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                // REAL is carried as the raw int bits of the float value
                return compose(primitive(inspector), o -> floatToRawIntBits(((Number) o).floatValue()));
            case DOUBLE:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> (Double) o);
            case TIMESTAMP:
                checkArgument(inspector instanceof PrimitiveObjectInspector);
                return compose(primitive(inspector), o -> ((Timestamp) o).getTime());
            case VARBINARY:
                if (inspector instanceof BinaryObjectInspector) {
                    return compose(primitive(inspector), o -> Slices.wrappedBuffer(((byte[]) o)));
                }
                break;
            case VARCHAR:
                if (inspector instanceof StringObjectInspector) {
                    return compose(primitive(inspector), o -> Slices.utf8Slice(o.toString()));
                }
                else if (inspector instanceof HiveVarcharObjectInspector) {
                    return compose(o -> ((HiveVarcharObjectInspector) inspector).getPrimitiveJavaObject(o).getValue(),
                            o -> Slices.utf8Slice(((String) o)));
                }
                break;
            case CHAR:
                if (inspector instanceof StringObjectInspector) {
                    return compose(primitive(inspector), o -> Slices.utf8Slice(o.toString()));
                }
                else if (inspector instanceof HiveCharObjectInspector) {
                    return compose(o -> ((HiveCharObjectInspector) inspector).getPrimitiveJavaObject(o).getValue(),
                            o -> Slices.utf8Slice(((String) o)));
                }
                break;
            case ROW:
                return StructObjectEncoder.create(type, inspector);
            case ARRAY:
                return ListObjectEncoder.create(type, inspector);
            case MAP:
                return MapObjectEncoder.create(type, inspector);
        }
        // Reached for unknown base types and for known types paired with an unsupported inspector
        throw unsupportedType(type);
    }

    /**
     * Chains an inspection step and an encoding step into a single null-tolerant encoder.
     *
     * <p>The resulting encoder returns {@code null} when its input is {@code null} or when
     * the inspection step yields {@code null}; the encoding step is only ever invoked with
     * a non-null value.
     *
     * @param inspector extracts the Java value from the raw object
     * @param encoder converts the extracted Java value to its encoded form
     * @return the combined encoder
     */
    private static ObjectEncoder compose(Function<Object, Object> inspector, Function<Object, Object> encoder)
    {
        return raw -> {
            if (raw == null) {
                return null;
            }
            Object javaValue = inspector.apply(raw);
            return javaValue == null ? null : encoder.apply(javaValue);
        };
    }

    /**
     * Exposes {@link PrimitiveObjectInspector#getPrimitiveJavaObject} of the given inspector
     * as a plain function.
     *
     * @param inspector inspector to adapt; cast to {@link PrimitiveObjectInspector} when this
     *        method is called, not when the returned function is applied
     * @return a function delegating to the inspector's {@code getPrimitiveJavaObject}
     */
    private static Function<Object, Object> primitive(ObjectInspector inspector)
    {
        PrimitiveObjectInspector primitiveInspector = (PrimitiveObjectInspector) inspector;
        return primitiveInspector::getPrimitiveJavaObject;
    }

    /**
     * Returns a function that reads a Hive decimal through the given inspector and converts
     * it to a {@code BigDecimal}.
     *
     * <p>The function yields {@code null} instead of throwing when the inspector reports
     * {@code null} for a non-null input, so that callers which already treat a {@code null}
     * extraction result as a null value keep working.
     *
     * @param inspector inspector to read through; cast to {@link HiveDecimalObjectInspector}
     *        when this method is called
     * @return a function producing the {@code BigDecimal} value, or {@code null}
     * @throws ClassCastException if {@code inspector} is not a {@link HiveDecimalObjectInspector}
     */
    private static Function<Object, Object> decimal(ObjectInspector inspector)
    {
        // Cast once here rather than on every value
        HiveDecimalObjectInspector decimalInspector = (HiveDecimalObjectInspector) inspector;
        return o -> mapIfPresent(decimalInspector.getPrimitiveJavaObject(o), hiveDecimal -> hiveDecimal.bigDecimalValue());
    }

    /**
     * Applies {@code mapper} to {@code value} unless {@code value} is {@code null}, in which
     * case {@code null} is returned.
     */
    private static <T> Object mapIfPresent(T value, Function<T, Object> mapper)
    {
        return value == null ? null : mapper.apply(value);
    }

    /**
     * Encodes a Hive list, viewed through a {@link ListObjectInspector}, into a block holding
     * one position per list element.
     */
    public static class ListObjectEncoder
            implements ObjectEncoder
    {
        private final ListObjectInspector listInspector;
        private final Type elementType;
        private final BlockObjectWriter writer;

        /**
         * Creates an encoder for the given array type and list inspector; the element encoder
         * is derived from the array's element type and the list's element inspector.
         *
         * @throws IllegalArgumentException if {@code type} is not an {@link ArrayType} or
         *         {@code inspector} is not a {@link ListObjectInspector}
         */
        public static ListObjectEncoder create(Type type, ObjectInspector inspector)
        {
            checkArgument(inspector instanceof ListObjectInspector && type instanceof ArrayType);

            Type elementType = ((ArrayType) type).getElementType();
            ListObjectInspector listInspector = (ListObjectInspector) inspector;
            ObjectEncoder elementEncoder = createEncoder(elementType,
                    listInspector.getListElementObjectInspector());
            return new ListObjectEncoder(listInspector, elementType, elementEncoder);
        }

        private ListObjectEncoder(ListObjectInspector listInspector, Type elementType, ObjectEncoder elementEncoder)
        {
            this.listInspector = requireNonNull(listInspector, "listInspector is null");
            this.elementType = requireNonNull(elementType, "elementType is null");
            // Same factory the map and struct encoders use for their writers
            this.writer = createBlockObjectWriter(elementEncoder, elementType);
        }

        /**
         * Encodes the list into a block built with the element type's block builder.
         *
         * @param o the Hive list object, may be {@code null}
         * @return the built block, or {@code null} if {@code o} is {@code null} or the
         *         inspector reports a negative length for it
         */
        @Override
        public Object encode(Object o)
        {
            if (o == null) {
                return null;
            }
            final int length = listInspector.getListLength(o);
            if (length < 0) {
                // A negative length is the inspector's way of reporting null list data;
                // encode it as null rather than as an empty array
                return null;
            }
            final BlockBuilder blockBuilder = elementType.createBlockBuilder(null, length);
            for (int i = 0; i < length; i++) {
                writer.write(blockBuilder, listInspector.getListElement(o, i));
            }
            return blockBuilder.build();
        }
    }

    /**
     * Encodes a Hive map, viewed through a {@link MapObjectInspector}, into the object
     * representation of a Presto {@link MapType}.
     */
    public static class MapObjectEncoder
            implements ObjectEncoder
    {
        private final MapType mapType;
        private final MapObjectInspector mapObjectInspector;
        private final BlockObjectWriter keyWriter;
        private final BlockObjectWriter valueWriter;

        /**
         * Creates an encoder for the given map type and map inspector.
         *
         * @throws IllegalArgumentException if {@code type} is not a {@link MapType} or
         *         {@code inspector} is not a {@link MapObjectInspector}
         */
        public static MapObjectEncoder create(Type type, Object inspector)
        {
            checkArgument(type instanceof MapType &&
                    inspector instanceof MapObjectInspector);
            return new MapObjectEncoder(((MapType) type), ((MapObjectInspector) inspector));
        }

        private MapObjectEncoder(MapType type, MapObjectInspector inspector)
        {
            this.mapType = requireNonNull(type, "mapType is null");
            this.mapObjectInspector = requireNonNull(inspector, "inspector is null");
            Type keyType = type.getKeyType();
            Type valueType = type.getValueType();
            ObjectEncoder keyEncoder = createEncoder(keyType, inspector.getMapKeyObjectInspector());
            ObjectEncoder valueEncoder = createEncoder(valueType, inspector.getMapValueObjectInspector());
            this.keyWriter = requireNonNull(createBlockObjectWriter(keyEncoder, keyType), "keyWriter is null");
            this.valueWriter = requireNonNull(createBlockObjectWriter(valueEncoder, valueType), "valueWriter is null");
        }

        /**
         * Writes every entry of the map as a key/value pair into a single map entry and
         * returns that entry as read back through the map type.
         *
         * @param object the Hive map object, may be {@code null}
         * @return the encoded map, or {@code null} if {@code object} is {@code null} or the
         *         inspector yields no map for it
         * @throws PrestoException with {@code INVALID_FUNCTION_ARGUMENT} if a key is
         *         {@code null} or the strict close detects a duplicate key
         */
        @Override
        public Object encode(Object object)
        {
            if (object == null) {
                return null;
            }
            Map<?, ?> rawMap = mapObjectInspector.getMap(object);
            if (rawMap == null) {
                // The inspector reports the map data as null; encode it as a null value
                return null;
            }

            MapBlockBuilder mapBlockBuilder = (MapBlockBuilder) mapType.createBlockBuilder(null, rawMap.size());
            BlockBuilder blockBuilder = mapBlockBuilder.beginBlockEntry();
            for (Entry<?, ?> entry : rawMap.entrySet()) {
                if (entry.getKey() == null) {
                    // The builder is local to this call and is discarded, so the open entry
                    // does not need to be closed before failing
                    throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "map key cannot be null");
                }
                // TODO check indeterminate
                keyWriter.write(blockBuilder, entry.getKey());
                valueWriter.write(blockBuilder, entry.getValue());
            }
            try {
                // Strict close checks the written keys with the map type's key equals/hashCode
                mapBlockBuilder.closeEntryStrict(mapType.getKeyBlockEquals(), mapType.getKeyBlockHashCode());
            }
            catch (DuplicateMapKeyException e) {
                throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e);
            }
            return mapType.getObject(mapBlockBuilder, mapBlockBuilder.getPositionCount() - 1);
        }
    }

    /**
     * Encodes a Hive struct, viewed through a {@link StructObjectInspector}, into the object
     * representation of a Presto {@link RowType}.
     */
    public static class StructObjectEncoder
            implements ObjectEncoder
    {
        private final RowType type;
        private final StructObjectInspector inspector;
        private final List<BlockObjectWriter> fieldWriters;

        /**
         * Creates an encoder for the given row type and struct inspector.
         *
         * @throws IllegalArgumentException if {@code type} is not a {@link RowType} or
         *         {@code inspector} is not a {@link StructObjectInspector}
         */
        public static StructObjectEncoder create(Type type, Object inspector)
        {
            checkArgument((type instanceof RowType) && (inspector instanceof StructObjectInspector));
            return new StructObjectEncoder((RowType) type, (StructObjectInspector) inspector);
        }

        // Builds the writer for one row field from its Presto type and the matching struct field
        private static BlockObjectWriter createFieldBlockObjectWriter(Type type, StructField field)
        {
            ObjectEncoder encoder = createEncoder(type, field.getFieldObjectInspector());
            return createBlockObjectWriter(encoder, type);
        }

        /**
         * Pairs the row's field types with the inspector's struct fields by position and
         * creates one writer per pair.
         */
        public StructObjectEncoder(RowType type, StructObjectInspector inspector)
        {
            this.type = requireNonNull(type, "type is null");
            this.inspector = requireNonNull(inspector, "inspector is null");
            this.fieldWriters = Streams.zip(
                    type.getFields().stream().map(RowType.Field::getType),
                    inspector.getAllStructFieldRefs().stream(),
                    StructObjectEncoder::createFieldBlockObjectWriter)
                    .collect(Collectors.toList());
        }

        /**
         * Writes the struct's field values into a single row entry and returns that entry as
         * read back through the row type.
         *
         * @param object the Hive struct object, may be {@code null}
         * @return the encoded row, or {@code null} if {@code object} is {@code null} or the
         *         inspector yields no field data for it
         */
        @Override
        public Object encode(Object object)
        {
            if (object == null) {
                return null;
            }

            List<Object> fieldObjects = inspector.getStructFieldsDataAsList(object);
            if (fieldObjects == null) {
                // The inspector reports the struct data as null; encode it as a null value
                return null;
            }

            PageBuilder pageBuilder = new PageBuilder(ImmutableList.of(type));
            RowBlockBuilder rowBlockBuilder = (RowBlockBuilder) pageBuilder.getBlockBuilder(0);
            SingleRowBlockWriter blockBuilder = rowBlockBuilder.beginBlockEntry();
            final int totalNumField = fieldWriters.size();
            final int numField = fieldObjects.size();
            for (int i = 0; i < totalNumField; i++) {
                // Fields missing from the data list are written as null
                fieldWriters.get(i).write(blockBuilder, i < numField ? fieldObjects.get(i) : null);
            }
            rowBlockBuilder.closeEntry();
            pageBuilder.declarePosition();
            return type.getObject(rowBlockBuilder, rowBlockBuilder.getPositionCount() - 1);
        }
    }

    /**
     * Wraps an encoder and its target type in a {@link BlockObjectWriter}.
     *
     * @param encoder encoder applied to each object before it is written
     * @param type type used to write the encoded value
     * @return a new {@link SimpleBlockObjectWriter} over the given encoder and type
     */
    private static BlockObjectWriter createBlockObjectWriter(ObjectEncoder encoder, Type type)
    {
        BlockObjectWriter writer = new SimpleBlockObjectWriter(encoder, type);
        return writer;
    }

    /**
     * Writes a single object into a {@link BlockBuilder}.
     */
    private interface BlockObjectWriter
    {
        /**
         * Writes {@code object} to {@code out}.
         *
         * @param out builder receiving the value
         * @param object the object to write; the implementation in this file accepts {@code null}
         */
        void write(BlockBuilder out, Object object);
    }

    /**
     * {@link BlockObjectWriter} that encodes each object with an {@link ObjectEncoder} and
     * writes the result as a native value of a fixed type.
     */
    private static class SimpleBlockObjectWriter
            implements BlockObjectWriter
    {
        private final ObjectEncoder objectEncoder;
        private final Type objectType;

        private SimpleBlockObjectWriter(ObjectEncoder objectEncoder, Type objectType)
        {
            this.objectEncoder = requireNonNull(objectEncoder, "objectEncoder is null");
            this.objectType = requireNonNull(objectType, "objectType is null");
        }

        /**
         * Appends one position to {@code out}: the encoded form of {@code object}, or a null
         * when {@code object} is {@code null} or encodes to {@code null}.
         */
        @Override
        public void write(BlockBuilder out, Object object)
        {
            // The encoder is only consulted for non-null input
            Object encodedValue = (object == null) ? null : objectEncoder.encode(object);
            if (encodedValue == null) {
                out.appendNull();
                return;
            }
            writeNativeValue(objectType, out, encodedValue);
        }
    }
}