// TestHiveAggregationFunctions.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.hive.functions;

import com.facebook.airlift.testing.Closeables;
import com.facebook.presto.common.type.ArrayType;
import com.facebook.presto.common.type.TimeZoneKey;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.server.testing.TestingPrestoServer;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.MaterializedRow;
import com.facebook.presto.tests.TestingPrestoClient;
import com.facebook.presto.tpch.TpchPlugin;
import com.google.inject.Key;
import org.intellij.lang.annotations.Language;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Stream;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
import static com.google.common.base.Preconditions.checkArgument;
import static org.testng.Assert.assertEquals;

/**
 * Integration test that runs Hive aggregation functions through a Presto
 * server and compares the query output against hard-coded expectations.
 *
 * <p>A single {@link TestingPrestoServer} is started per test class with the
 * TPCH plugin and the {@link HiveFunctionNamespacePlugin} installed; queries
 * reference the functions with the {@code hive.default.} prefix and read from
 * {@code tpch.tiny.nation}.
 */
public class TestHiveAggregationFunctions
{
    private static final Type BIGINT_ARRAY = new ArrayType(BIGINT);

    // Absolute tolerance used when comparing DOUBLE columns.
    private static final double DOUBLE_TOLERANCE = 0.000001;

    private TestingPrestoServer server;
    private TestingPrestoClient client;

    /**
     * Starts the test server and creates a client whose session uses the
     * {@code America/Bahia_Banderas} time zone key.
     *
     * @throws Exception if the server cannot be created
     */
    @BeforeClass
    public void setup()
            throws Exception
    {
        server = createServer();
        client = new TestingPrestoClient(server, testSessionBuilder()
                .setTimeZoneKey(TimeZoneKey.getTimeZoneKey("America/Bahia_Banderas"))
                .build());
    }

    /**
     * Closes the server and the client.
     *
     * <p>{@code alwaysRun = true} makes TestNG invoke this method even when an
     * earlier configuration or test method failed, so a server created by a
     * partially failed {@link #setup()} is still released.
     */
    @AfterClass(alwaysRun = true)
    public void destroy()
    {
        Closeables.closeQuietly(server, client);
    }

    /**
     * Exercises each supported Hive aggregation function, both as a global
     * aggregate and (for some functions) grouped by {@code regionkey}.
     */
    @Test
    public void aggregationFunctions()
    {
        check("select hive.default.avg(nationkey) from tpch.tiny.nation",
                column(DOUBLE, 12.0));
        check("select regionkey, hive.default.avg(nationkey) from tpch.tiny.nation group by regionkey order by 1",
                column(BIGINT, 0L, 1L, 2L, 3L, 4L),
                column(DOUBLE, 10.0, 9.4, 13.6, 15.4, 11.6));
        check("select hive.default.collect_list(regionkey) from tpch.tiny.nation",
                column(BIGINT_ARRAY, Arrays.asList(0L, 1L, 1L, 1L, 4L, 0L, 3L, 3L, 2L, 2L, 4L, 4L, 2L, 4L, 0L, 0L, 0L, 1L, 2L, 3L, 4L, 2L, 3L, 3L, 1L)));
        check("select hive.default.collect_set(regionkey) from tpch.tiny.nation",
                column(BIGINT_ARRAY, Arrays.asList(0L, 1L, 4L, 3L, 2L)));
        check("select hive.default.corr(nationkey, regionkey) from tpch.tiny.nation",
                column(DOUBLE, 0.18042685));
        check("select hive.default.covar_pop(nationkey, regionkey) from tpch.tiny.nation",
                column(DOUBLE, 1.84));
        check("select hive.default.covar_samp(nationkey, regionkey) from tpch.tiny.nation",
                column(DOUBLE, 1.9166666));
        check("select hive.default.max(name) from tpch.tiny.nation",
                column(VARCHAR, "VIETNAM"));
        check("select regionkey, hive.default.max(name) from tpch.tiny.nation group by regionkey order by 1",
                column(BIGINT, 0L, 1L, 2L, 3L, 4L),
                column(VARCHAR, "MOZAMBIQUE", "UNITED STATES", "VIETNAM", "UNITED KINGDOM", "SAUDI ARABIA"));
        check("select hive.default.min(name) from tpch.tiny.nation",
                column(VARCHAR, "ALGERIA"));
        check("select regionkey, hive.default.min(name) from tpch.tiny.nation group by regionkey order by 1",
                column(BIGINT, 0L, 1L, 2L, 3L, 4L),
                column(VARCHAR, "ALGERIA", "ARGENTINA", "CHINA", "FRANCE", "EGYPT"));
        check("select hive.default.std(nationkey) from tpch.tiny.nation",
                column(DOUBLE, 7.2111025));
        check("select regionkey, hive.default.std(nationkey) from tpch.tiny.nation group by regionkey order by 1",
                column(BIGINT, 0L, 1L, 2L, 3L, 4L),
                column(DOUBLE, 6.35609943, 9.35093578, 5.08330601, 7.39188744, 5.16139516));
        check("select hive.default.sum(nationkey) from tpch.tiny.nation",
                column(BIGINT, 300L));
        check("select regionkey, hive.default.sum(nationkey) from tpch.tiny.nation group by regionkey order by 1",
                column(BIGINT, 0L, 1L, 2L, 3L, 4L),
                column(BIGINT, 50L, 47L, 68L, 77L, 58L));
        check("select hive.default.variance(nationkey) from tpch.tiny.nation",
                column(DOUBLE, 52.0));
        check("select hive.default.var_samp(nationkey) from tpch.tiny.nation",
                column(DOUBLE, 54.1666666));
    }

    /**
     * Executes {@code query} and asserts that its result matches the expected
     * columns in row count, column count, column types and cell values.
     *
     * <p>Cells of {@code DOUBLE} columns are compared with an absolute
     * tolerance of {@link #DOUBLE_TOLERANCE}; all other cells are compared
     * with {@code assertEquals}.
     *
     * @param query the SQL text to execute through the test client
     * @param expectedColumns the expected result, one entry per output column,
     *        each holding the same number of values (one per row)
     * @throws IllegalArgumentException if no columns are given, a column is
     *         {@code null}, or the columns differ in length
     */
    @SuppressWarnings("UnknownLanguage")
    public void check(@Language("SQL") String query, Column... expectedColumns)
    {
        checkArgument(expectedColumns != null && expectedColumns.length > 0, "at least one expected column is required");
        int numColumns = expectedColumns.length;
        checkArgument(expectedColumns[0] != null, "expected column 0 is null");
        int numRows = expectedColumns[0].values.length;
        checkArgument(
                Stream.of(expectedColumns).allMatch(c -> c != null && c.values.length == numRows),
                "all expected columns must be non-null and have the same number of values");

        MaterializedResult result = client.execute(query).getResult();
        assertEquals(result.getRowCount(), numRows, "row count for query: " + query);
        // Without this, extra result columns would go unnoticed and missing ones
        // would surface as an IndexOutOfBoundsException instead of a test failure.
        assertEquals(result.getTypes().size(), numColumns, "column count for query: " + query);

        for (int i = 0; i < numColumns; i++) {
            assertEquals(result.getTypes().get(i), expectedColumns[i].type, "type of column " + i + " for query: " + query);
        }
        List<MaterializedRow> rows = result.getMaterializedRows();
        for (int i = 0; i < numRows; i++) {
            for (int j = 0; j < numColumns; j++) {
                Object actual = rows.get(i).getField(j);
                Object expected = expectedColumns[j].values[i];
                String message = "row " + i + ", column " + j + " for query: " + query;
                // Nulls fall through to the generic comparison so they fail (or pass)
                // as an assertion rather than as a NullPointerException on unboxing.
                if (expectedColumns[j].type.equals(DOUBLE) && actual != null && expected != null) {
                    assertEquals(((Number) actual).doubleValue(), ((Number) expected).doubleValue(), DOUBLE_TOLERANCE, message);
                }
                else {
                    assertEquals(actual, expected, message);
                }
            }
        }
    }

    /**
     * Builds a test server with the TPCH plugin and the Hive function
     * namespace plugin installed, a {@code tpch} catalog, and the
     * {@code hive-functions} function namespace manager loaded under the
     * name {@code hive}.
     *
     * @return the configured server; the caller is responsible for closing it
     * @throws Exception if the server cannot be constructed
     */
    private static TestingPrestoServer createServer()
            throws Exception
    {
        TestingPrestoServer server = new TestingPrestoServer();
        server.installPlugin(new TpchPlugin());
        server.installPlugin(new HiveFunctionNamespacePlugin());
        server.createCatalog("tpch", "tpch");
        FunctionAndTypeManager functionAndTypeManager = server.getInstance(Key.get(FunctionAndTypeManager.class));
        functionAndTypeManager.loadFunctionNamespaceManager(
                "hive-functions",
                "hive",
                Collections.emptyMap(),
                server.getPluginNodeManager());
        server.refreshNodes();
        return server;
    }

    /**
     * Creates an expected result column.
     *
     * @param type the expected Presto type of the column
     * @param values the expected cell values, one per result row, in row order
     * @return a column holding {@code type} and {@code values}
     */
    public static Column column(Type type, Object... values)
    {
        return new Column(type, values);
    }

    /**
     * Expected type and per-row values of one result column.
     * The values array is stored as given (not copied).
     */
    private static class Column
    {
        private final Type type;

        private final Object[] values;

        private Column(Type type, Object[] values)
        {
            this.type = type;
            this.values = values;
        }
    }
}