TestNoisyAggregationUtils.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.operator.aggregation.noisyaggregation;
import com.facebook.presto.common.type.StandardTypes;
import java.util.ArrayList;
import java.util.List;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import static com.facebook.presto.common.type.Decimals.MAX_PRECISION;
/**
 * Static helpers shared by the noisy-aggregation tests: assertion predicates for comparing
 * results, a generator for in-memory test values, a builder for an inline SQL dataset, and
 * plain reference implementations of simple aggregates over nullable lists.
 */
public final class TestNoisyAggregationUtils
{
    // Largest absolute difference at which equalDoubleAssertion still reports two doubles as equal.
    private static final double EQUAL_DOUBLE_TOLERANCE = 1e-12;

    // Half-width of the withinSomeStdAssertion window, in multiples of DEFAULT_TEST_STANDARD_DEVIATION.
    private static final int ACCEPTED_STANDARD_DEVIATIONS = 50;

    /**
     * Returns {@code true} when both arguments, parsed from their {@code toString()} form, are
     * different doubles. Uses the same notion of equality as {@link Double#equals}: NaN equals
     * NaN, and {@code 0.0} differs from {@code -0.0}.
     */
    public static final BiFunction<Object, Object, Boolean> notEqualDoubleAssertion =
            (actual, expected) -> Double.compare(toDouble(actual), toDouble(expected)) != 0;

    /**
     * Returns {@code true} when both arguments, parsed as doubles, differ by at most 1e-12.
     */
    public static final BiFunction<Object, Object, Boolean> equalDoubleAssertion =
            (actual, expected) -> Math.abs(toDouble(actual) - toDouble(expected)) <= EQUAL_DOUBLE_TOLERANCE;

    /**
     * Returns {@code true} when both arguments, parsed as longs, are equal.
     */
    public static final BiFunction<Object, Object, Boolean> equalLongAssertion =
            (actual, expected) -> toLong(actual) == toLong(expected);

    /**
     * Standard deviation used by {@link #withinSomeStdAssertion} to size its tolerance window.
     */
    public static final double DEFAULT_TEST_STANDARD_DEVIATION = 1.0;

    /**
     * Returns {@code true} when the actual value lies within 50 times
     * {@link #DEFAULT_TEST_STANDARD_DEVIATION} of the expected value (bounds inclusive).
     */
    public static final BiFunction<Object, Object, Boolean> withinSomeStdAssertion = (actual, expected) -> {
        double actualValue = toDouble(actual);
        double expectedValue = toDouble(expected);
        double tolerance = ACCEPTED_STANDARD_DEVIATIONS * DEFAULT_TEST_STANDARD_DEVIATION;
        return expectedValue - tolerance <= actualValue && actualValue <= expectedValue + tolerance;
    };

    private TestNoisyAggregationUtils()
    {
    }

    /**
     * Creates a mutable list of {@code numRows} test values.
     * <p>
     * With {@code fixedValue} set, every element is {@code value}. Otherwise element {@code i} is
     * derived from the row index using the runtime type of {@code value}: {@code i} itself for
     * {@code Double}, {@code Integer} and {@code Long}, and {@code i % 2 == 0} for {@code Boolean}.
     * <p>
     * With {@code includeNull} set, the first element is removed and a {@code null} is appended,
     * so the list keeps its size.
     *
     * @param numRows number of elements to produce
     * @param includeNull whether to replace one element with {@code null}
     * @param value the fixed value, or a sample whose runtime type selects the generated values
     * @param fixedValue whether every element should be {@code value}
     * @return a new list of {@code numRows} elements
     * @throws IllegalArgumentException if {@code includeNull} is set but {@code numRows < 1}, or
     * if values must be generated and the type of {@code value} is not one of the four above
     */
    public static <T> List<T> createTestValues(int numRows, boolean includeNull, T value, boolean fixedValue)
    {
        if (includeNull && numRows < 1) {
            // There is no element to swap for the null; fail clearly instead of inside remove(0).
            throw new IllegalArgumentException("numRows must be at least 1 when includeNull is true, got " + numRows);
        }

        List<T> values = new ArrayList<>(Math.max(numRows, 0));
        for (int i = 0; i < numRows; i++) {
            values.add(fixedValue ? value : generateValue(value, i));
        }

        if (includeNull) {
            values.remove(0);
            values.add(null);
        }
        return values;
    }

    // Produces the value for one row, using the runtime type of the sample to pick the boxed type.
    @SuppressWarnings("unchecked") // every branch returns an instance of the sample's own runtime class
    private static <T> T generateValue(T sample, int rowIndex)
    {
        if (sample instanceof Double) {
            return (T) Double.valueOf(rowIndex);
        }
        if (sample instanceof Integer) {
            return (T) Integer.valueOf(rowIndex);
        }
        if (sample instanceof Long) {
            return (T) Long.valueOf(rowIndex);
        }
        if (sample instanceof Boolean) {
            return (T) Boolean.valueOf(rowIndex % 2 == 0);
        }
        // Silently skipping the row would hand the caller a list shorter than requested.
        throw new IllegalArgumentException(
                "Cannot generate test values for type " + (sample == null ? "null" : sample.getClass().getName()));
    }

    /**
     * Builds a dataset that can be selected from. The result has the form:
     * <pre>
     * (SELECT
     *     CAST(index AS bigint) AS index,
     *     CAST(col_bigint AS bigint) AS col_bigint,
     *     CAST(col_varchar AS varchar) AS col_varchar
     * FROM (
     *     VALUES
     *         (1, 1, '{}'),
     *         (NULL, NULL, NULL)
     * ) AS t (index, col_bigint, col_varchar))
     * </pre>
     * The CASTs make sure each column's data type is explicitly provided, not inferred.
     * <p>
     * The dataset always has {@code numRows} rows. With {@code includeNullValue} set, the last of
     * them is an all-NULL row and the preceding rows are indexed from 0.
     *
     * @param numRows total number of rows, including the NULL row if requested
     * @param includeNullValue whether the last row should consist of NULLs only
     * @param types type names of the value columns; one {@code col_<type>} column is produced per entry
     * @return the parenthesized SELECT as a string
     * @throws IllegalArgumentException if {@code numRows < 1} (VALUES needs at least one row) or
     * a type is not supported by {@link #buildRow}
     */
    public static String buildData(int numRows, boolean includeNullValue, List<String> types)
    {
        if (numRows < 1) {
            throw new IllegalArgumentException("numRows must be at least 1, got " + numRows);
        }
        int nonNullRows = includeNullValue ? numRows - 1 : numRows;

        StringBuilder sb = new StringBuilder();
        sb.append("(SELECT ");
        sb.append("CAST(index AS bigint) AS index");
        // The separator is written before each column so an empty type list leaves no dangling comma.
        for (String type : types) {
            String typeString = type.equals(StandardTypes.DECIMAL) ? "DECIMAL(" + MAX_PRECISION + ")" : type;
            String column = buildColumnName(type);
            sb.append(", CAST(").append(column).append(" AS ").append(typeString).append(") AS ").append(column);
        }

        sb.append(" FROM (VALUES ");
        for (int i = 0; i < nonNullRows; i++) {
            if (i > 0) {
                sb.append(",");
            }
            buildRow(sb, i, types, false);
        }
        if (includeNullValue) {
            // Only separate from a preceding row; with numRows == 1 the NULL row is the first one.
            if (nonNullRows > 0) {
                sb.append(",");
            }
            buildRow(sb, nonNullRows, types, true);
        }

        sb.append(") AS t (").append("index");
        for (String type : types) {
            sb.append(", ").append(buildColumnName(type));
        }
        sb.append("))");
        return sb.toString();
    }

    /**
     * Returns the name of the value column for the given type: {@code "col_"} followed by the type.
     */
    public static String buildColumnName(String type)
    {
        return "col_" + type;
    }

    /**
     * Appends one parenthesized VALUES row to {@code sb}: the index followed by one literal per type.
     * <p>
     * For a NULL row every field, including the index, is {@code NULL}. Otherwise integer types
     * get the index, REAL/DOUBLE/DECIMAL get the index followed by {@code .0}, VARCHAR, CHAR,
     * VARBINARY and JSON get {@code '{}'}, and BOOLEAN gets {@code true} for even indexes.
     *
     * @param sb builder the row is appended to
     * @param index row index, also used as the value of numeric columns
     * @param types type names of the value columns, in column order
     * @param isNullRow whether to emit NULL for every field
     * @throws IllegalArgumentException if a non-NULL row contains a type with no literal defined here
     */
    public static void buildRow(StringBuilder sb, int index, List<String> types, boolean isNullRow)
    {
        // index column
        sb.append("(").append(isNullRow ? "NULL" : index);
        // value column(s)
        for (String type : types) {
            sb.append(", ");
            if (isNullRow) {
                sb.append("NULL");
                continue;
            }
            switch (type) {
                case StandardTypes.TINYINT:
                case StandardTypes.SMALLINT:
                case StandardTypes.INTEGER:
                case StandardTypes.BIGINT:
                    sb.append(index);
                    break;
                case StandardTypes.REAL:
                case StandardTypes.DOUBLE:
                case StandardTypes.DECIMAL:
                    sb.append(index).append(".0");
                    break;
                case StandardTypes.VARCHAR:
                case StandardTypes.CHAR:
                case StandardTypes.VARBINARY:
                case StandardTypes.JSON:
                    sb.append("'{}'");
                    break;
                case StandardTypes.BOOLEAN:
                    sb.append(index % 2 == 0 ? "true" : "false");
                    break;
                default:
                    // Appending nothing would leave an empty field and produce unparseable SQL.
                    throw new IllegalArgumentException("Unsupported type for test data: " + type);
            }
        }
        sb.append(")");
    }

    /**
     * Returns the sum of the values, counting {@code null} elements as 0.
     */
    public static double sum(List<Double> values)
    {
        return values.stream().mapToDouble(f -> f == null ? 0 : f).sum();
    }

    /**
     * Returns the sum of the values as a double, counting {@code null} elements as 0.
     */
    public static double sumLong(List<Long> values)
    {
        return values.stream().mapToLong(v -> v == null ? 0 : v).sum();
    }

    /**
     * Returns the number of elements that are {@code true}; {@code null} and {@code false} do not count.
     */
    public static double countTrue(List<Boolean> values)
    {
        return values.stream().mapToLong(v -> v == null || !v ? 0 : 1).sum();
    }

    /**
     * Returns the mean of the non-null values. The result is NaN when there are no non-null values.
     */
    public static double avg(List<Double> values)
    {
        return sum(values) / countNonNull(values);
    }

    /**
     * Returns the mean of the non-null values. The result is NaN when there are no non-null values.
     */
    public static double avgLong(List<Long> values)
    {
        // sumLong already returns a double, so this is floating-point division.
        return sumLong(values) / countNonNullLong(values);
    }

    /**
     * Returns the number of non-null elements, as a double.
     */
    public static double countNonNull(List<Double> values)
    {
        return values.stream().mapToLong(f -> f == null ? 0 : 1).sum();
    }

    /**
     * Returns the number of non-null elements.
     */
    public static long countNonNullLong(List<Long> values)
    {
        return values.stream().mapToLong(v -> v == null ? 0 : 1).sum();
    }

    /**
     * Converts each value to its decimal string form, keeping {@code null} elements as {@code null}.
     */
    public static List<String> toNullableStringList(List<Long> values)
    {
        return values.stream().map(v -> v == null ? null : String.valueOf(v)).collect(Collectors.toList());
    }

    // Parses the toString() form of a value as a double.
    private static double toDouble(Object value)
    {
        return Double.parseDouble(value.toString());
    }

    // Parses the toString() form of a value as a long.
    private static long toLong(Object value)
    {
        return Long.parseLong(value.toString());
    }
}