ArrayColumnValidator.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.presto.verifier.checksum;

import com.facebook.presto.common.type.AbstractVarcharType;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.RowType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.sql.tree.ArithmeticUnaryExpression;
import com.facebook.presto.sql.tree.Cast;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.Identifier;
import com.facebook.presto.sql.tree.LambdaArgumentDeclaration;
import com.facebook.presto.sql.tree.LambdaExpression;
import com.facebook.presto.sql.tree.LongLiteral;
import com.facebook.presto.sql.tree.SingleColumn;
import com.facebook.presto.sql.tree.TryExpression;
import com.facebook.presto.verifier.framework.Column;
import com.facebook.presto.verifier.framework.VerifierConfig;
import com.google.common.collect.ImmutableList;

import javax.inject.Inject;

import java.util.List;
import java.util.Objects;
import java.util.Optional;

import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.sql.QueryUtil.functionCall;
import static com.facebook.presto.sql.QueryUtil.identifier;
import static com.facebook.presto.verifier.framework.VerifierUtil.delimitedIdentifier;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Objects.requireNonNull;

public class ArrayColumnValidator
        implements ColumnValidator
{
    private final FloatingPointColumnValidator floatingPointValidator;
    private final boolean useErrorMarginForFloatingPointArrays;
    private final boolean validateStringAsDouble;

    @Inject
    public ArrayColumnValidator(VerifierConfig config, FloatingPointColumnValidator floatingPointValidator)
    {
        this.floatingPointValidator = requireNonNull(floatingPointValidator, "floatingPointValidator is null");
        this.useErrorMarginForFloatingPointArrays = config.isUseErrorMarginForFloatingPointArrays();
        this.validateStringAsDouble = config.isValidateStringAsDouble();
    }

    @Override
    public List<SingleColumn> generateChecksumColumns(Column column)
    {
        Type columnType = column.getType();
        ImmutableList.Builder<SingleColumn> builder = ImmutableList.builder();

        // coalesce(checksum(try(array_sort(array_column))), checksum(array_column))
        Expression checksum = generateArrayChecksum(column.getExpression(), columnType);
        // checksum(cardinality(array_column))
        Expression arrayCardinalityChecksum = functionCall("checksum", functionCall("cardinality", column.getExpression()));
        // coalesce(sum(cardinality(array_column)), 0)
        Expression arrayCardinalitySum = new CoalesceExpression(
                functionCall("sum", functionCall("cardinality", column.getExpression())), new LongLiteral("0"));

        // For arrays of floating point numbers we have a different processing, akin to FloatingPointColumnValidator.
        if (useFloatingPointPath(column)) {
            builder.addAll(generateFloatingPointArrayChecksumColumns(column));
        }
        else if (useStringAsDoublePath(column)) {
            builder.add(new SingleColumn(checksum, delimitedIdentifier(getChecksumColumnAlias(column))));
            builder.addAll(generateStringArrayChecksumColumns(column));
        }
        else {
            builder.add(new SingleColumn(checksum, delimitedIdentifier(getChecksumColumnAlias(column))));
        }
        builder.add(new SingleColumn(arrayCardinalityChecksum, delimitedIdentifier(getCardinalityChecksumColumnAlias(column))));
        builder.add(new SingleColumn(arrayCardinalitySum, delimitedIdentifier(getCardinalitySumColumnAlias(column))));

        return builder.build();
    }

    @Override
    public List<ColumnMatchResult<ArrayColumnChecksum>> validate(Column column, ChecksumResult controlResult, ChecksumResult testResult)
    {
        checkArgument(
                controlResult.getRowCount() == testResult.getRowCount(),
                "Test row count (%s) does not match control row count (%s)",
                testResult.getRowCount(),
                controlResult.getRowCount());

        boolean useFloatingPointPath = useFloatingPointPath(column);
        boolean useStringAsDoublePath = useStringAsDoublePath(column) && ColumnValidatorUtil.isStringAsDoubleColumn(column, controlResult, testResult);

        ArrayColumnChecksum controlChecksum = toColumnChecksum(column, controlResult, useFloatingPointPath, useStringAsDoublePath);
        ArrayColumnChecksum testChecksum = toColumnChecksum(column, testResult, useFloatingPointPath, useStringAsDoublePath);

        if (!Objects.equals(controlChecksum.getCardinalityChecksum(), testChecksum.getCardinalityChecksum()) ||
                !Objects.equals(controlChecksum.getCardinalitySum(), testChecksum.getCardinalitySum())) {
            return ImmutableList.of(new ColumnMatchResult<>(false, column, Optional.of("cardinality mismatch"), controlChecksum, testChecksum));
        }

        if (useFloatingPointPath) {
            ColumnMatchResult<FloatingPointColumnChecksum> result =
                    floatingPointValidator.validate(column, controlChecksum.getFloatingPointChecksum(), testChecksum.getFloatingPointChecksum());
            return ImmutableList.of(new ColumnMatchResult<>(result.isMatched(), column, result.getMessage(), controlChecksum, testChecksum));
        }

        if (useStringAsDoublePath) {
            Column asDoubleArrayColumn = getAsDoubleArrayColumn(column);
            ColumnMatchResult<FloatingPointColumnChecksum> result =
                    floatingPointValidator.validate(asDoubleArrayColumn, controlChecksum.getFloatingPointChecksum(), testChecksum.getFloatingPointChecksum());
            return ImmutableList.of(new ColumnMatchResult<>(result.isMatched(), column, result.getMessage(), controlChecksum, testChecksum));
        }

        return ImmutableList.of(new ColumnMatchResult<>(Objects.equals(controlChecksum, testChecksum), column, controlChecksum, testChecksum));
    }

    public static List<SingleColumn> generateFloatingPointArrayChecksumColumns(Column column)
    {
        checkArgument(column.getType() instanceof ArrayType, "Expect ArrayType, found %s", column.getType().getDisplayName());
        Type elementType = ((ArrayType) column.getType()).getElementType();
        checkArgument(Column.FLOATING_POINT_TYPES.contains(elementType), "Expect Double or Real, found %s", elementType.getDisplayName());

        Expression expression = elementType.equals(DOUBLE) ? column.getExpression() : new Cast(column.getExpression(), new ArrayType(DOUBLE).getDisplayName());

        // sum(array_sum(filter(array_column, x -> is_finite(x))))
        Expression sum = functionCall(
                "sum",
                functionCall("array_sum", functionCall("filter", expression, generateLambdaExpression("is_finite"))));
        // sum(cardinality(filter(array_column, x -> is_nan(x))))
        Expression nanCount = functionCall(
                "sum",
                functionCall("cardinality", functionCall("filter", expression, generateLambdaExpression("is_nan"))));
        // sum(cardinality(filter(array_column, x -> x = Infinite())))
        Expression posInfCount = functionCall(
                "sum",
                functionCall("cardinality", functionCall("filter", expression, generateInfinityLambdaExpression(ArithmeticUnaryExpression.Sign.PLUS))));
        // sum(cardinality(filter(array_column, x -> x = -Infinite())))
        Expression negInfCount = functionCall(
                "sum",
                functionCall("cardinality", functionCall("filter", expression, generateInfinityLambdaExpression(ArithmeticUnaryExpression.Sign.MINUS))));
        return ImmutableList.of(
                new SingleColumn(sum, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getSumColumnAlias(column)))),
                new SingleColumn(nanCount, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getNanCountColumnAlias(column)))),
                new SingleColumn(posInfCount, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getPositiveInfinityCountColumnAlias(column)))),
                new SingleColumn(negInfCount, Optional.of(delimitedIdentifier(FloatingPointColumnValidator.getNegativeInfinityCountColumnAlias(column)))));
    }

    public static List<SingleColumn> generateStringArrayChecksumColumns(Column column)
    {
        checkArgument(column.getType() instanceof ArrayType, "Expect ArrayType, found %s", column.getType().getDisplayName());
        Type elementType = ((ArrayType) column.getType()).getElementType();
        checkArgument(elementType instanceof AbstractVarcharType, "Expect VarcharType, found %s", elementType.getDisplayName());

        Column asDoubleArrayColumn = getAsDoubleArrayColumn(column);
        return ImmutableList.<SingleColumn>builder()
                .addAll(generateFloatingPointArrayChecksumColumns(asDoubleArrayColumn))
                .addAll(ColumnValidatorUtil.generateNullCountColumns(column, asDoubleArrayColumn))
                .build();
    }

    public static Expression generateArrayChecksum(Expression column, Type type)
    {
        checkArgument(type instanceof ArrayType, "Expect ArrayType, found %s", type.getDisplayName());
        Type elementType = ((ArrayType) type).getElementType();

        if (elementType.isOrderable()) {
            Expression arraySort = functionCall("array_sort", column);

            if (elementType instanceof ArrayType || elementType instanceof RowType) {
                return new CoalesceExpression(
                        functionCall("checksum", new TryExpression(arraySort)),
                        functionCall("checksum", column));
            }
            return functionCall("checksum", arraySort);
        }
        return functionCall("checksum", column);
    }

    private static Expression generateInfinityLambdaExpression(ArithmeticUnaryExpression.Sign sign)
    {
        ComparisonExpression lambdaBody = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                new Identifier("x"),
                new ArithmeticUnaryExpression(sign, functionCall("infinity")));
        return new LambdaExpression(ImmutableList.of(new LambdaArgumentDeclaration(identifier("x"))), lambdaBody);
    }

    private static Expression generateLambdaExpression(String functionName)
    {
        return new LambdaExpression(
                ImmutableList.of(new LambdaArgumentDeclaration(identifier("x"))),
                functionCall(functionName, new Identifier("x")));
    }

    private static ArrayColumnChecksum toColumnChecksum(Column column, ChecksumResult checksumResult, boolean useFloatingPointPath, boolean useStringAsDoublePath)
    {
        if (checksumResult.getRowCount() == 0) {
            return new ArrayColumnChecksum(
                    null, null, 0,
                    useFloatingPointPath || useStringAsDoublePath ? Optional.of(new FloatingPointColumnChecksum(null, 0, 0, 0, 0)) : Optional.empty());
        }

        Object cardinalityChecksum = checksumResult.getChecksum(getCardinalityChecksumColumnAlias(column));
        long cardinalitySum = (long) checksumResult.getChecksum(getCardinalitySumColumnAlias(column));

        if (useFloatingPointPath) {
            return new ArrayColumnChecksum(
                    null,
                    cardinalityChecksum,
                    cardinalitySum,
                    Optional.of(FloatingPointColumnValidator.toColumnChecksum(column, checksumResult, checksumResult.getRowCount())));
        }

        Object checksum = checksumResult.getChecksum(getChecksumColumnAlias(column));

        if (useStringAsDoublePath) {
            Column asDoubleArrayColumn = getAsDoubleArrayColumn(column);
            return new ArrayColumnChecksum(
                    null,
                    cardinalityChecksum,
                    cardinalitySum,
                    Optional.of(FloatingPointColumnValidator.toColumnChecksum(asDoubleArrayColumn, checksumResult, checksumResult.getRowCount())));
        }

        return new ArrayColumnChecksum(checksum, cardinalityChecksum, cardinalitySum, Optional.empty());
    }

    private boolean useFloatingPointPath(Column column)
    {
        Type columnType = column.getType();
        checkArgument(columnType instanceof ArrayType, "Expect ArrayType, found %s", columnType.getDisplayName());
        Type elementType = ((ArrayType) columnType).getElementType();
        return (useErrorMarginForFloatingPointArrays && Column.FLOATING_POINT_TYPES.contains(elementType));
    }

    private boolean useStringAsDoublePath(Column column)
    {
        Type columnType = column.getType();
        checkArgument(columnType instanceof ArrayType, "Expect ArrayType, found %s", columnType.getDisplayName());
        Type elementType = ((ArrayType) columnType).getElementType();
        return (validateStringAsDouble && elementType instanceof AbstractVarcharType);
    }

    public static Column getAsDoubleArrayColumn(Column column)
    {
        return Column.create(column.getName() + "_as_double", getAsDoubleArrayExpression(column), new ArrayType(DOUBLE));
    }

    private static Expression getAsDoubleArrayExpression(Column column)
    {
        // transform(array_column, x -> try_cast(x as double))
        return functionCall("transform", column.getExpression(), new LambdaExpression(ImmutableList.of(new LambdaArgumentDeclaration(identifier("x"))),
                new Cast(identifier("x"), DOUBLE.getDisplayName(), true, false)));
    }

    private static String getChecksumColumnAlias(Column column)
    {
        return column.getName() + "$checksum";
    }

    private static String getCardinalityChecksumColumnAlias(Column column)
    {
        return column.getName() + "$cardinality_checksum";
    }

    private static String getCardinalitySumColumnAlias(Column column)
    {
        return column.getName() + "$cardinality_sum";
    }
}