TestEncryption.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.parquet.reader;
import com.facebook.presto.common.block.Block;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.parquet.Field;
import com.facebook.presto.parquet.GroupField;
import com.facebook.presto.parquet.ParquetDataSource;
import com.facebook.presto.parquet.ParquetDataSourceId;
import com.facebook.presto.parquet.PrimitiveField;
import com.facebook.presto.parquet.RichColumnDescriptor;
import com.facebook.presto.parquet.cache.MetadataReader;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import io.airlift.units.DataSize;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.crypto.FileDecryptionProperties;
import org.apache.parquet.crypto.InternalFileDecryptor;
import org.apache.parquet.crypto.ParquetCipher;
import org.apache.parquet.hadoop.metadata.BlockMetaData;
import org.apache.parquet.hadoop.metadata.FileMetaData;
import org.apache.parquet.hadoop.metadata.ParquetMetadata;
import org.apache.parquet.io.ColumnIO;
import org.apache.parquet.io.GroupColumnIO;
import org.apache.parquet.io.MessageColumnIO;
import org.apache.parquet.io.PrimitiveColumnIO;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.PrimitiveType;
import org.testng.annotations.Test;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.parquet.ParquetTypeUtils.getArrayElementColumn;
import static com.facebook.presto.parquet.ParquetTypeUtils.getColumnIO;
import static com.facebook.presto.parquet.ParquetTypeUtils.getMapKeyValueColumn;
import static com.facebook.presto.parquet.ParquetTypeUtils.lookupColumnByName;
import static org.apache.parquet.io.ColumnIOUtil.columnDefinitionLevel;
import static org.apache.parquet.io.ColumnIOUtil.columnRepetitionLevel;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.BINARY;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT32;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
import static org.apache.parquet.schema.Type.Repetition.OPTIONAL;
import static org.apache.parquet.schema.Type.Repetition.REQUIRED;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNull;
public class TestEncryption
{
// Hadoop configuration shared by every test: handed to TestFileBuilder when writing the
// file and used to resolve the FileSystem when reading it back. It is constructed with
// `false`, which per Hadoop's Configuration(boolean loadDefaults) skips the default resources.
private final Configuration conf = new Configuration(false);
/**
 * Writes 10,000 records with the "name" and "gender" columns encrypted (GZIP codec,
 * page size 1000, dictionary enabled, footer encryption requested, extra key/value
 * metadata attached) and checks that the file reads back correctly.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testBasicDecryption()
        throws IOException
{
    // A plain HashMap replaces the former double-brace initialization, which created an
    // anonymous HashMap subclass holding an implicit reference to this test instance.
    Map<String, String> extraMetadata = new HashMap<>();
    extraMetadata.put("key1", "value1");
    extraMetadata.put("key2", "value2");

    TestFile encryptedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(10000)
            .withCodec("GZIP")
            .withExtraMeta(extraMetadata)
            .withPageSize(1000)
            .withDictionaryEnabled()
            .withFooterEncryption()
            .build();

    decryptAndValidate(encryptedFile);
}
/**
 * Encrypts every column of the schema ("id", "bal", "name", "gender") and checks
 * that all 10,000 records can still be decrypted and read back.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testAllColumnsDecryption()
        throws IOException
{
    String[] everyColumn = new String[] {"id", "bal", "name", "gender"};

    TestFile encryptedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(everyColumn)
            .withNumRecord(10000)
            .withCodec("GZIP")
            .withPageSize(1000)
            .withFooterEncryption()
            .withDictionaryEnabled()
            .build();

    decryptAndValidate(encryptedFile);
}
/**
 * Passes an empty list of columns to encrypt (footer encryption is still requested)
 * and checks that the 10,000 records read back correctly.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testNoColumnsDecryption()
        throws IOException
{
    String[] noColumns = new String[0];

    TestFile inputFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(noColumns)
            .withNumRecord(10000)
            .withCodec("GZIP")
            .withPageSize(1000)
            .withDictionaryEnabled()
            .withFooterEncryption()
            .build();

    decryptAndValidate(inputFile);
}
/**
 * Boundary case: a file containing a single record, with "name" and "gender" encrypted.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testOneRecord()
        throws IOException
{
    TestFile singleRecordFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(1)
            .withCodec("GZIP")
            .withPageSize(1000)
            .withDictionaryEnabled()
            .withFooterEncryption()
            .build();

    decryptAndValidate(singleRecordFile);
}
/**
 * Volume case: one million records with "name" and "gender" encrypted.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testMillionRows()
        throws IOException
{
    TestFile largeFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(1000000)
            .withCodec("GZIP")
            .withPageSize(1000)
            .withDictionaryEnabled()
            .withFooterEncryption()
            .build();

    decryptAndValidate(largeFile);
}
/**
 * Encrypts "name" and "gender" without requesting footer encryption on the builder,
 * using the SNAPPY codec, and checks that the 10,000 records read back correctly.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testPlainTextFooter()
        throws IOException
{
    TestFile plainFooterFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(10000)
            .withCodec("SNAPPY")
            .withDictionaryEnabled()
            .withPageSize(1000)
            .build();

    decryptAndValidate(plainFooterFile);
}
/**
 * Uses a page size of 100,000 (instead of the 1000 used by most tests here) with
 * 100,000 records and "name"/"gender" encrypted.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testLargePageSize()
        throws IOException
{
    TestFile largePageFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(100000)
            .withCodec("GZIP")
            .withPageSize(100000)
            .withDictionaryEnabled()
            .withFooterEncryption()
            .build();

    decryptAndValidate(largePageFile);
}
/**
 * Selects the {@code AES_GCM_CTR_V1} cipher on the builder for a file of 100,000
 * records with "name" and "gender" encrypted. Footer encryption is not requested.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testAesGcmCtr()
        throws IOException
{
    TestFile gcmCtrFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(100000)
            .withCodec("GZIP")
            .withPageSize(1000)
            .withDictionaryEnabled()
            .withEncrytionAlgorithm(ParquetCipher.AES_GCM_CTR_V1)
            .build();

    decryptAndValidate(gcmCtrFile);
}
/**
 * Data-masking case with a single column: "name" is encrypted and is also the only
 * column expected to be masked when the file is validated.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingSingleColumn()
        throws IOException
{
    String[] hiddenColumns = new String[] {"name"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name"})
            .withNumRecord(10000)
            .withCodec("GZIP")
            .withPageSize(1000)
            .withFooterEncryption()
            .withDataMaskingTest()
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking case with two columns: "name" and "gender" are encrypted and both are
 * expected to be masked. Extra key/value metadata is attached to the file as well.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingMultipleColumns()
        throws IOException
{
    // A plain HashMap replaces the former double-brace initialization, which created an
    // anonymous HashMap subclass holding an implicit reference to this test instance.
    Map<String, String> extraMetadata = new HashMap<>();
    extraMetadata.put("key1", "value1");
    extraMetadata.put("key2", "value2");

    String[] hiddenColumns = new String[] {"name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(10000)
            .withCodec("GZIP")
            .withExtraMeta(extraMetadata)
            .withPageSize(1000)
            .withFooterEncryption()
            .withDataMaskingTest()
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking case with footer encryption requested. "name" and "gender" are both
 * encrypted and expected to be masked. The record count and page size are left at
 * whatever the builder defaults to.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingEncryptedFooter()
        throws IOException
{
    String[] hiddenColumns = new String[] {"name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withDataMaskingTest()
            .withCodec("GZIP")
            .withFooterEncryption()
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking case where footer encryption is not requested on the builder.
 * "name" and "gender" are both encrypted and expected to be masked.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingPlaintextFooter()
        throws IOException
{
    String[] hiddenColumns = new String[] {"name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withDataMaskingTest()
            .withCodec("GZIP")
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking case using the {@code AES_GCM_CTR_V1} cipher. "name" and "gender"
 * are both encrypted and expected to be masked.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingGcmCtr()
        throws IOException
{
    String[] hiddenColumns = new String[] {"name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withDataMaskingTest()
            .withCodec("GZIP")
            .withEncrytionAlgorithm(ParquetCipher.AES_GCM_CTR_V1)
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking case with a page size of 100,000. "name" and "gender" are both
 * encrypted and expected to be masked.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingLargePage()
        throws IOException
{
    String[] hiddenColumns = new String[] {"name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withDataMaskingTest()
            .withCodec("GZIP")
            .withPageSize(100000)
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking boundary case: a file with a single record. "name" and "gender"
 * are both encrypted and expected to be masked.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingOneRecord()
        throws IOException
{
    String[] hiddenColumns = new String[] {"name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"name", "gender"})
            .withNumRecord(1)
            .withDataMaskingTest()
            .withCodec("GZIP")
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Data-masking case over "id", "name" and "gender", all encrypted and all expected
 * to be masked. Note that "bal" is in neither list, so it is still read and compared.
 *
 * @throws IOException if the test file cannot be written or read
 */
@Test
public void testDataMaskingAllColumns()
        throws IOException
{
    String[] hiddenColumns = new String[] {"id", "name", "gender"};

    TestFile maskedFile = new TestFileBuilder(conf, createSchema())
            .withEncryptColumns(new String[] {"id", "name", "gender"})
            .withNumRecord(10000)
            .withDataMaskingTest()
            .withCodec("GZIP")
            .build();

    validateMasking(maskedFile, hiddenColumns);
}
/**
 * Builds the four-column Parquet schema shared by all tests in this class.
 *
 * @return a message type named "schema" with optional INT64 "id", optional INT32 "bal",
 *         required BINARY "name" and optional BINARY "gender", in that order
 */
private MessageType createSchema()
{
    PrimitiveType id = new PrimitiveType(OPTIONAL, INT64, "id");
    PrimitiveType balance = new PrimitiveType(OPTIONAL, INT32, "bal");
    PrimitiveType name = new PrimitiveType(REQUIRED, BINARY, "name");
    PrimitiveType gender = new PrimitiveType(OPTIONAL, BINARY, "gender");
    return new MessageType("schema", id, balance, name, gender);
}
/**
 * Opens the given test file, reads its footer using a file decryptor, and verifies that the
 * columns read back through a {@link ParquetReader} match the records stored in {@code inputFile}.
 *
 * @param inputFile the written test file together with the records it was built from
 * @throws IOException if the file cannot be opened or read
 */
private void decryptAndValidate(TestFile inputFile)
        throws IOException
{
    Path path = new Path(inputFile.getFileName());
    FileSystem fileSystem = path.getFileSystem(conf);
    Optional<InternalFileDecryptor> fileDecryptor = createFileDecryptor();

    // try-with-resources: previously neither the stream nor the reader was ever closed,
    // so each test leaked an open file handle.
    try (FSDataInputStream inputStream = fileSystem.open(path)) {
        ParquetDataSource dataSource = new MockParquetDataSource(new ParquetDataSourceId(path.toString()), inputStream);
        ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, inputFile.getFileSize(), fileDecryptor, false).getParquetMetadata();
        MessageType fileSchema = parquetMetadata.getFileMetaData().getSchema();
        // The file schema is used as both arguments, i.e. all columns in the file are requested.
        MessageColumnIO messageColumn = getColumnIO(fileSchema, fileSchema);
        try (ParquetReader parquetReader = createParquetReader(parquetMetadata, messageColumn, dataSource, fileDecryptor)) {
            validateFile(parquetReader, messageColumn, inputFile);
        }
    }
}
/**
 * Opens the given test file, reads its footer with the last {@code readFooter} flag set to
 * {@code true}, and verifies that the columns in {@code maskingColumn} cannot be resolved
 * from the resulting schema while the remaining columns match the records in {@code inputFile}.
 *
 * NOTE(review): the {@code true} flag presumably tells the metadata reader to hide columns
 * that cannot be decrypted; confirm against {@code MetadataReader.readFooter}.
 *
 * @param inputFile the written test file together with the records it was built from
 * @param maskingColumn names of the columns expected to be absent from the file schema
 * @throws IOException if the file cannot be opened or read
 */
private void validateMasking(TestFile inputFile, String[] maskingColumn)
        throws IOException
{
    Path path = new Path(inputFile.getFileName());
    FileSystem fileSystem = path.getFileSystem(conf);
    Optional<InternalFileDecryptor> fileDecryptor = createFileDecryptor();

    // try-with-resources: previously neither the stream nor the reader was ever closed,
    // so each test leaked an open file handle.
    try (FSDataInputStream inputStream = fileSystem.open(path)) {
        ParquetDataSource dataSource = new MockParquetDataSource(new ParquetDataSourceId(path.toString()), inputStream);
        ParquetMetadata parquetMetadata = MetadataReader.readFooter(dataSource, inputFile.getFileSize(), fileDecryptor, true).getParquetMetadata();
        MessageType fileSchema = parquetMetadata.getFileMetaData().getSchema();
        MessageColumnIO messageColumn = getColumnIO(fileSchema, fileSchema);
        try (ParquetReader parquetReader = createParquetReader(parquetMetadata, messageColumn, dataSource, fileDecryptor)) {
            validateFile(parquetReader, messageColumn, inputFile, maskingColumn);
        }
    }
}
/**
 * Creates the decryptor used to read the test files.
 *
 * @return a decryptor wrapping the properties supplied by {@code EncryptDecryptUtil},
 *         or an empty Optional when that utility returns {@code null}
 */
private Optional<InternalFileDecryptor> createFileDecryptor()
{
    return Optional.ofNullable(EncryptDecryptUtil.getFileDecryptionProperties())
            .map(properties -> new InternalFileDecryptor(properties));
}
/**
 * Validates every column of the file, treating no column as masked.
 *
 * @param parquetReader reader positioned at the start of the file
 * @param messageColumn column IO tree built from the file schema
 * @param inputFile the written test file together with the records it was built from
 * @throws IOException if reading from the file fails
 */
private static void validateFile(ParquetReader parquetReader, MessageColumnIO messageColumn, TestFile inputFile)
        throws IOException
{
    validateFile(parquetReader, messageColumn, inputFile, new String[0]);
}
/**
 * Reads the file batch by batch and validates the "id", "bal", "name" and "gender" columns of
 * each batch, then checks that the total number of rows read equals the number of records in
 * {@code inputFile}.
 *
 * @param parquetReader reader positioned at the start of the file
 * @param messageColumn column IO tree built from the file schema
 * @param inputFile the written test file together with the records it was built from
 * @param maskingColumn names of the columns expected to be absent from the file schema
 * @throws IOException if reading from the file fails
 */
private static void validateFile(ParquetReader parquetReader, MessageColumnIO messageColumn, TestFile inputFile, String[] maskingColumn)
        throws IOException
{
    int rowIndex = 0;
    for (int batchSize = parquetReader.nextBatch(); batchSize > 0; batchSize = parquetReader.nextBatch()) {
        validateColumn("id", BIGINT, rowIndex, parquetReader, messageColumn, inputFile, maskingColumn);
        validateColumn("bal", INTEGER, rowIndex, parquetReader, messageColumn, inputFile, maskingColumn);
        validateColumn("name", VARCHAR, rowIndex, parquetReader, messageColumn, inputFile, maskingColumn);
        validateColumn("gender", VARCHAR, rowIndex, parquetReader, messageColumn, inputFile, maskingColumn);
        rowIndex += batchSize;
    }
    // Without this check the loop above passes vacuously when the reader yields no rows
    // (or stops early), so a reader that silently drops data would go unnoticed.
    assertEquals(rowIndex, inputFile.getFileContent().length);
}
/**
 * Validates one column of the current batch.
 *
 * If {@code name} is listed in {@code maskingColumn}, asserts that the column cannot be resolved
 * from {@code messageColumn}. Otherwise reads the column's block and compares each position with
 * the corresponding record of {@code inputFile}, starting at {@code rowIndex}. Only BIGINT, INTEGER
 * and VARCHAR values are compared; blocks of any other type are read but not checked.
 *
 * NOTE(review): null positions are not checked before reading values; presumably the generated
 * test data contains no nulls.
 *
 * @param name column name to validate
 * @param type Presto type used to read and compare the column
 * @param rowIndex index in the file content of the first row of the current batch
 * @param parquetReader reader whose current batch is being validated
 * @param messageColumn column IO tree built from the file schema
 * @param inputFile the written test file together with the records it was built from
 * @param maskingColumn names of the columns expected to be absent from the file schema
 * @throws IOException if reading the block fails
 */
@VisibleForTesting
static void validateColumn(String name, Type type, int rowIndex, ParquetReader parquetReader, MessageColumnIO messageColumn, TestFile inputFile, String[] maskingColumn)
        throws IOException
{
    Optional<Field> field = constructField(type, lookupColumnByName(messageColumn, name));

    if (Arrays.asList(maskingColumn).contains(name)) {
        assertNull(field.orElse(null));
        return;
    }

    // Fail with a clear message instead of handing null to the reader.
    Field column = field.orElseThrow(() -> new AssertionError("Column not found in file schema: " + name));
    Block block = parquetReader.readBlock(column);

    // TestNG's assertEquals takes (actual, expected): the value read from the block comes first.
    for (int position = 0; position < block.getPositionCount(); position++) {
        int row = rowIndex + position;
        if (type.equals(BIGINT)) {
            assertEquals(block.getLong(position), inputFile.getFileContent()[row].getLong(name, 0));
        }
        else if (type.equals(INTEGER)) {
            assertEquals(block.getInt(position), inputFile.getFileContent()[row].getInteger(name, 0));
        }
        else if (type.equals(VARCHAR)) {
            assertEquals(block.getSlice(position, 0, block.getSliceLength(position)).toStringUtf8(), inputFile.getFileContent()[row].getString(name, 0));
        }
    }
}
/**
 * Converts a Parquet column IO node into the reader's {@link Field} representation for the
 * given Presto type, recursing into row, map and array types.
 *
 * @param type Presto type the column should be read as
 * @param columnIO Parquet column IO node, or {@code null} when the column is absent
 * @return the constructed field; empty when {@code columnIO} is {@code null}, when none of a
 *         row's fields resolve, when a map's key/value group does not have exactly two children,
 *         or when an array group does not have exactly one child
 */
@VisibleForTesting
static Optional<Field> constructField(Type type, ColumnIO columnIO)
{
    if (columnIO == null) {
        return Optional.empty();
    }

    boolean required = columnIO.getType().getRepetition() != OPTIONAL;
    int repetitionLevel = columnRepetitionLevel(columnIO);
    int definitionLevel = columnDefinitionLevel(columnIO);

    if (type instanceof RowType) {
        GroupColumnIO structColumn = (GroupColumnIO) columnIO;
        ImmutableList.Builder<Optional<Field>> children = ImmutableList.builder();
        boolean anyChildResolved = false;
        for (RowType.Field rowField : ((RowType) type).getFields()) {
            // Row field names are lower-cased before being looked up in the Parquet group.
            String childName = rowField.getName().get().toLowerCase(Locale.ENGLISH);
            Optional<Field> child = constructField(rowField.getType(), lookupColumnByName(structColumn, childName));
            if (child.isPresent()) {
                anyChildResolved = true;
            }
            children.add(child);
        }
        if (!anyChildResolved) {
            return Optional.empty();
        }
        return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, children.build()));
    }

    if (type instanceof MapType) {
        MapType mapType = (MapType) type;
        GroupColumnIO keyValueColumn = getMapKeyValueColumn((GroupColumnIO) columnIO);
        if (keyValueColumn.getChildrenCount() != 2) {
            return Optional.empty();
        }
        Optional<Field> key = constructField(mapType.getKeyType(), keyValueColumn.getChild(0));
        Optional<Field> value = constructField(mapType.getValueType(), keyValueColumn.getChild(1));
        return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(key, value)));
    }

    if (type instanceof ArrayType) {
        GroupColumnIO listColumn = (GroupColumnIO) columnIO;
        if (listColumn.getChildrenCount() != 1) {
            return Optional.empty();
        }
        Type elementType = ((ArrayType) type).getElementType();
        Optional<Field> element = constructField(elementType, getArrayElementColumn(listColumn.getChild(0)));
        return Optional.of(new GroupField(type, repetitionLevel, definitionLevel, required, ImmutableList.of(element)));
    }

    // Anything that is not a row, map or array is treated as a primitive column.
    PrimitiveColumnIO primitiveColumn = (PrimitiveColumnIO) columnIO;
    RichColumnDescriptor descriptor = new RichColumnDescriptor(primitiveColumn.getColumnDescriptor(), columnIO.getType().asPrimitiveType());
    return Optional.of(new PrimitiveField(type, repetitionLevel, definitionLevel, required, descriptor, primitiveColumn.getId()));
}
/**
 * Creates a {@link ParquetReader} over all row groups of the file, using a maximum read
 * block size of 100000 bytes.
 *
 * @param parquetMetadata footer metadata of the file
 * @param messageColumn column IO tree built from the file schema
 * @param dataSource data source the reader pulls bytes from
 * @param fileDecryptor decryptor to use, or empty when none is configured
 * @return the newly created reader
 */
@VisibleForTesting
static ParquetReader createParquetReader(ParquetMetadata parquetMetadata,
        MessageColumnIO messageColumn,
        ParquetDataSource dataSource,
        Optional<InternalFileDecryptor> fileDecryptor)
{
    DataSize defaultMaxReadBlockSize = new DataSize(100000, DataSize.Unit.BYTE);
    return createParquetReader(parquetMetadata, messageColumn, dataSource, fileDecryptor, defaultMaxReadBlockSize);
}
/**
 * Creates a {@link ParquetReader} over all row groups listed in {@code parquetMetadata}.
 *
 * @param parquetMetadata footer metadata of the file; all of its blocks are handed to the reader
 * @param messageColumn column IO tree built from the file schema
 * @param dataSource data source the reader pulls bytes from
 * @param fileDecryptor decryptor to use, or empty when none is configured
 * @param maxReadBlockSize maximum read block size passed to the reader
 * @return the newly created reader
 */
@VisibleForTesting
static ParquetReader createParquetReader(ParquetMetadata parquetMetadata,
        MessageColumnIO messageColumn,
        ParquetDataSource dataSource,
        Optional<InternalFileDecryptor> fileDecryptor,
        DataSize maxReadBlockSize)
{
    // The row-group start offsets that used to be accumulated here were never passed to
    // the reader, so that dead bookkeeping has been removed.
    List<BlockMetaData> blocks = ImmutableList.copyOf(parquetMetadata.getBlocks());

    // NOTE(review): the Optional.empty(), boolean and null arguments are positional options of
    // ParquetReader's constructor; their meaning is not visible from this file.
    return new ParquetReader(
            messageColumn,
            blocks,
            Optional.empty(),
            dataSource,
            com.facebook.presto.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext(),
            maxReadBlockSize,
            false,
            false,
            null,
            null,
            false,
            fileDecryptor);
}
}