/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.hive.parquet.write;

import com.facebook.airlift.log.Logger;
import org.apache.hadoop.hive.common.type.HiveDecimal;
import org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe;
import org.apache.hadoop.hive.ql.io.parquet.timestamp.NanoTimeUtils;
import org.apache.hadoop.hive.ql.io.parquet.write.DataWritableWriter;
import org.apache.hadoop.hive.serde2.io.DateWritable;
import org.apache.hadoop.hive.serde2.io.ParquetHiveRecord;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.ByteObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DateObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveCharObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.HiveVarcharObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.ShortObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.TimestampObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.parquet.io.api.Binary;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.OriginalType;
import org.apache.parquet.schema.Type;

import java.sql.Date;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;

/**
 * This class is copied from org.apache.hadoop.hive.ql.io.parquet.write.DataWritableWriter
 * and extended to support empty arrays and maps (HIVE-13632).
 * Additionally, there is support for arrays without an inner element layer and for
 * maps where MAP_KEY_VALUE is incorrectly used in place of MAP, for testing the
 * backward-compatibility rules (https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists).
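 * <p>
 * A minimal usage sketch (variable names are hypothetical, assuming ParquetHiveRecord's
 * (Object, StructObjectInspector) constructor): wrap the Parquet RecordConsumer and the
 * file schema, then call write() once per row:
 * <pre>{@code
 * TestDataWritableWriter writer = new TestDataWritableWriter(recordConsumer, fileSchema, false);
 * writer.write(new ParquetHiveRecord(rowData, rowInspector));
 * }</pre>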
 */
public class TestDataWritableWriter
{
    private static final Logger log = Logger.get(TestDataWritableWriter.class);
    private final RecordConsumer recordConsumer;
    private final GroupType schema;
    private final boolean singleLevelArray;

    public TestDataWritableWriter(final RecordConsumer recordConsumer, final GroupType schema, boolean singleLevelArray)
    {
        this.recordConsumer = recordConsumer;
        this.schema = schema;
        this.singleLevelArray = singleLevelArray;
    }

    /**
     * It writes all record values to the Parquet RecordConsumer.
     *
     * @param record Contains the record values that are going to be written.
     */
    public void write(final ParquetHiveRecord record)
    {
        if (record != null) {
            recordConsumer.startMessage();
            try {
                writeGroupFields(record.getObject(), record.getObjectInspector(), schema);
            }
            catch (RuntimeException e) {
                String errorMessage = "Parquet record is malformed: " + e.getMessage();
                log.error(e, errorMessage);
                throw new RuntimeException(errorMessage, e);
            }
            recordConsumer.endMessage();
        }
    }

    /**
     * It writes all the fields contained inside a group to the RecordConsumer.
     *
     * @param value The list of values contained in the group.
     * @param inspector The object inspector used to get the correct value type.
     * @param type Type that contains information about the group schema.
     */
    private void writeGroupFields(final Object value, final StructObjectInspector inspector, final GroupType type)
    {
        if (value != null) {
            List<? extends StructField> fields = inspector.getAllStructFieldRefs();
            List<Object> fieldValuesList = inspector.getStructFieldsDataAsList(value);

            for (int i = 0; i < type.getFieldCount(); i++) {
                Type fieldType = type.getType(i);
                String fieldName = fieldType.getName();
                Object fieldValue = fieldValuesList.get(i);

                if (fieldValue != null) {
                    ObjectInspector fieldInspector = fields.get(i).getFieldObjectInspector();
                    recordConsumer.startField(fieldName, i);
                    writeValue(fieldValue, fieldInspector, fieldType);
                    recordConsumer.endField(fieldName, i);
                }
            }
        }
    }

    /**
     * It writes the field value to the Parquet RecordConsumer. It detects the field type, and calls
     * the correct write function.
     *
     * @param value The writable object that contains the value.
     * @param inspector The object inspector used to get the correct value type.
     * @param type Type that contains information about the type schema.
     */
    private void writeValue(final Object value, final ObjectInspector inspector, final Type type)
    {
        if (type.isPrimitive()) {
            checkInspectorCategory(inspector, ObjectInspector.Category.PRIMITIVE);
            writePrimitive(value, (PrimitiveObjectInspector) inspector);
        }
        else {
            GroupType groupType = type.asGroupType();
            OriginalType originalType = type.getOriginalType();

            if (originalType != null && originalType.equals(OriginalType.LIST)) {
                checkInspectorCategory(inspector, ObjectInspector.Category.LIST);
                if (singleLevelArray) {
                    writeSingleLevelArray(value, (ListObjectInspector) inspector, groupType);
                }
                else {
                    writeArray(value, (ListObjectInspector) inspector, groupType);
                }
            }
            else if (originalType != null && (originalType.equals(OriginalType.MAP) || originalType.equals(OriginalType.MAP_KEY_VALUE))) {
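                // MAP_KEY_VALUE is accepted in addition to MAP so that files exercising the
                // parquet-format backward-compatibility rules (see class javadoc) can be written.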
                checkInspectorCategory(inspector, ObjectInspector.Category.MAP);
                writeMap(value, (MapObjectInspector) inspector, groupType);
            }
            else {
                checkInspectorCategory(inspector, ObjectInspector.Category.STRUCT);
                writeGroup(value, (StructObjectInspector) inspector, groupType);
            }
        }
    }

    /**
     * Checks that an inspector matches the category indicated as a parameter.
     *
     * @param inspector The object inspector to check
     * @param category The category to match
     * @throws IllegalArgumentException if inspector does not match the category
     */
    private void checkInspectorCategory(ObjectInspector inspector, ObjectInspector.Category category)
    {
        if (!inspector.getCategory().equals(category)) {
            throw new IllegalArgumentException("Invalid data type: expected " + category
                    + " type, but found: " + inspector.getCategory());
        }
    }

    /**
     * It writes a group type and all its values to the Parquet RecordConsumer.
     * This is used only for optional and required groups.
     *
     * @param value Object that contains the group values.
     * @param inspector The object inspector used to get the correct value type.
     * @param type Type that contains information about the group schema.
     */
    private void writeGroup(final Object value, final StructObjectInspector inspector, final GroupType type)
    {
        recordConsumer.startGroup();
        writeGroupFields(value, inspector, type);
        recordConsumer.endGroup();
    }

    /**
     * It writes a list type and its array elements to the Parquet RecordConsumer.
     * This is called when the original type (LIST) is detected by writeValue().
     * This function assumes the following schema:
     * optional group arrayCol (LIST) {
     *   repeated group array {
     *     optional TYPE array_element;
     *   }
     * }
     *
     * @param value The object that contains the array values.
     * @param inspector The object inspector used to get the correct value type.
     * @param type Type that contains information about the group (LIST) schema.
     */
    private void writeArray(final Object value, final ListObjectInspector inspector, final GroupType type)
    {
        // Get the internal array structure
        GroupType repeatedType = type.getType(0).asGroupType();
        recordConsumer.startGroup();

        List<?> arrayValues = inspector.getList(value);
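        // HIVE-13632: an empty array is still written as a LIST group, just with no repeated field.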
        if (!arrayValues.isEmpty()) {
            recordConsumer.startField(repeatedType.getName(), 0);
            ObjectInspector elementInspector = inspector.getListElementObjectInspector();

            Type elementType = repeatedType.getType(0);
            String elementName = elementType.getName();

            for (Object element : arrayValues) {
                recordConsumer.startGroup();
                if (element != null) {
                    recordConsumer.startField(elementName, 0);
                    writeValue(element, elementInspector, elementType);
                    recordConsumer.endField(elementName, 0);
                }
                recordConsumer.endGroup();
            }

            recordConsumer.endField(repeatedType.getName(), 0);
        }
        recordConsumer.endGroup();
    }

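    /**
     * It writes a list type using the legacy single-level layout, where the repeated element is
     * attached directly to the LIST group without an inner element wrapper:
     * optional group arrayCol (LIST) {
     *   repeated TYPE array;
     * }
     * Because the element field itself is repeated, null elements cannot be represented here.
     *
     * @param value The object that contains the array values.
     * @param inspector The object inspector used to get the correct value type.
     * @param type Type that contains information about the group (LIST) schema.
     */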
    private void writeSingleLevelArray(final Object value, final ListObjectInspector inspector, final GroupType type)
    {
        // Get the internal array structure
        Type elementType = type.getType(0);

        recordConsumer.startGroup();

        List<?> arrayValues = inspector.getList(value);
        if (!arrayValues.isEmpty()) {
            recordConsumer.startField(elementType.getName(), 0);
            ObjectInspector elementInspector = inspector.getListElementObjectInspector();

            for (Object element : arrayValues) {
                if (element == null) {
                    throw new IllegalArgumentException("Array elements are requires in given schema definition");
                }
                writeValue(element, elementInspector, elementType);
            }

            recordConsumer.endField(elementType.getName(), 0);
        }
        recordConsumer.endGroup();
    }

    /**
     * It writes a map type and its key-value pairs to the Parquet RecordConsumer.
     * This is called when the original type (MAP) is detected by writeValue().
     * This function assumes the following schema:
     * optional group mapCol (MAP) {
     *   repeated group map (MAP_KEY_VALUE) {
     *     required TYPE key;
     *     optional TYPE value;
     *   }
     * }
     *
     * @param value The object that contains the map key-values.
     * @param inspector The object inspector used to get the correct value type.
     * @param type Type that contains information about the group (MAP) schema.
     */
    private void writeMap(final Object value, final MapObjectInspector inspector, final GroupType type)
    {
        // Get the internal map structure (MAP_KEY_VALUE)
        GroupType repeatedType = type.getType(0).asGroupType();

        recordConsumer.startGroup();
        Map<?, ?> mapValues = inspector.getMap(value);
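        // HIVE-13632: an empty (or null) map is still written as a MAP group, just with no repeated field.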
        if (mapValues != null && !mapValues.isEmpty()) {
            recordConsumer.startField(repeatedType.getName(), 0);

            Type keyType = repeatedType.getType(0);
            String keyName = keyType.getName();
            ObjectInspector keyInspector = inspector.getMapKeyObjectInspector();

            Type valueType = repeatedType.getType(1);
            String valueName = valueType.getName();
            ObjectInspector valueInspector = inspector.getMapValueObjectInspector();

            for (Map.Entry<?, ?> keyValue : mapValues.entrySet()) {
                recordConsumer.startGroup();
                if (keyValue != null) {
                    // write key element
                    Object keyElement = keyValue.getKey();
                    recordConsumer.startField(keyName, 0);
                    writeValue(keyElement, keyInspector, keyType);
                    recordConsumer.endField(keyName, 0);

                    // write value element
                    Object valueElement = keyValue.getValue();
                    if (valueElement != null) {
                        recordConsumer.startField(valueName, 1);
                        writeValue(valueElement, valueInspector, valueType);
                        recordConsumer.endField(valueName, 1);
                    }
                }
                recordConsumer.endGroup();
            }

            recordConsumer.endField(repeatedType.getName(), 0);
        }
        recordConsumer.endGroup();
    }

    /**
     * It writes the primitive value to the Parquet RecordConsumer.
     *
     * @param value The object that contains the primitive value.
     * @param inspector The object inspector used to get the correct value type.
     */
    private void writePrimitive(final Object value, final PrimitiveObjectInspector inspector)
    {
        if (value == null) {
            return;
        }

        switch (inspector.getPrimitiveCategory()) {
            case VOID:
                return;
            case DOUBLE:
                recordConsumer.addDouble(((DoubleObjectInspector) inspector).get(value));
                break;
            case BOOLEAN:
                recordConsumer.addBoolean(((BooleanObjectInspector) inspector).get(value));
                break;
            case FLOAT:
                recordConsumer.addFloat(((FloatObjectInspector) inspector).get(value));
                break;
            case BYTE:
                recordConsumer.addInteger(((ByteObjectInspector) inspector).get(value));
                break;
            case INT:
                recordConsumer.addInteger(((IntObjectInspector) inspector).get(value));
                break;
            case LONG:
                recordConsumer.addLong(((LongObjectInspector) inspector).get(value));
                break;
            case SHORT:
                recordConsumer.addInteger(((ShortObjectInspector) inspector).get(value));
                break;
            case STRING:
                String v = ((StringObjectInspector) inspector).getPrimitiveJavaObject(value);
                recordConsumer.addBinary(Binary.fromString(v));
                break;
            case CHAR:
                String vChar = ((HiveCharObjectInspector) inspector).getPrimitiveJavaObject(value).getStrippedValue();
                recordConsumer.addBinary(Binary.fromString(vChar));
                break;
            case VARCHAR:
                String vVarchar = ((HiveVarcharObjectInspector) inspector).getPrimitiveJavaObject(value).getValue();
                recordConsumer.addBinary(Binary.fromString(vVarchar));
                break;
            case BINARY:
                byte[] vBinary = ((BinaryObjectInspector) inspector).getPrimitiveJavaObject(value);
                recordConsumer.addBinary(Binary.fromByteArray(vBinary));
                break;
            case TIMESTAMP:
                Timestamp ts = ((TimestampObjectInspector) inspector).getPrimitiveJavaObject(value);
                recordConsumer.addBinary(NanoTimeUtils.getNanoTime(ts, false).toBinary());
                break;
            case DECIMAL:
                HiveDecimal vDecimal = ((HiveDecimal) inspector.getPrimitiveJavaObject(value));
                DecimalTypeInfo decTypeInfo = (DecimalTypeInfo) inspector.getTypeInfo();
                recordConsumer.addBinary(decimalToBinary(vDecimal, decTypeInfo));
                break;
            case DATE:
                Date vDate = ((DateObjectInspector) inspector).getPrimitiveJavaObject(value);
                recordConsumer.addInteger(DateWritable.dateToDays(vDate));
                break;
            default:
                throw new IllegalArgumentException("Unsupported primitive data type: " + inspector.getPrimitiveCategory());
        }
    }

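    /**
     * Converts a HiveDecimal to the Parquet FIXED_LEN_BYTE_ARRAY representation: the unscaled
     * value as a big-endian two's complement integer, sign-extended to the fixed byte width that
     * ParquetHiveSerDe.PRECISION_TO_BYTE_COUNT assigns to the decimal's precision.
     */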
    private Binary decimalToBinary(final HiveDecimal hiveDecimal, final DecimalTypeInfo decimalTypeInfo)
    {
        int prec = decimalTypeInfo.precision();
        int scale = decimalTypeInfo.scale();
        byte[] decimalBytes = hiveDecimal.setScale(scale).unscaledValue().toByteArray();

        // Number of bytes needed to hold a decimal of the given precision.
        int precToBytes = ParquetHiveSerDe.PRECISION_TO_BYTE_COUNT[prec - 1];
        if (precToBytes == decimalBytes.length) {
            // No padding needed.
            return Binary.fromByteArray(decimalBytes);
        }

        byte[] tgt = new byte[precToBytes];
        if (hiveDecimal.signum() == -1) {
            // For negative numbers, initialize all bits to 1 (two's complement sign extension).
            for (int i = 0; i < precToBytes; i++) {
                tgt[i] |= 0xFF;
            }
        }

        System.arraycopy(decimalBytes, 0, tgt, precToBytes - decimalBytes.length, decimalBytes.length); // Padding leading zeroes/ones.
        return Binary.fromByteArray(tgt);
    }
}