TestArrayCombinationsFunction.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.presto.operator.scalar;

import com.facebook.presto.common.type.ArrayType;
import com.google.common.collect.ContiguousSet;
import com.google.common.collect.ImmutableList;
import org.testng.annotations.Test;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.UnknownType.UNKNOWN;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.operator.scalar.ArrayCombinationsFunction.combinationCount;
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.google.common.math.LongMath.factorial;
import static java.lang.String.format;
import static java.util.Arrays.asList;
import static org.testng.Assert.assertEquals;

public class TestArrayCombinationsFunction
        extends AbstractTestFunctions
{
    @Test
    public void testCombinationCount()
    {
        for (int n = 0; n < 5; n++) {
            for (int k = 0; k <= n; k++) {
                assertEquals(combinationCount(n, k), factorial(n) / factorial(n - k) / factorial(k));
            }
        }

        assertEquals(combinationCount(42, 7), 26978328);
        assertEquals(combinationCount(100, 4), 3921225);
    }

    @Test
    public void testBasic()
    {
        assertFunction("combinations(ARRAY['bar', 'foo', 'baz', 'foo'], 0)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(ImmutableList.of()));
        assertFunction("combinations(ARRAY['bar', 'foo', 'baz', 'foo'], 1)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(
                ImmutableList.of("bar"),
                ImmutableList.of("foo"),
                ImmutableList.of("baz"),
                ImmutableList.of("foo")));
        assertFunction("combinations(ARRAY['bar', 'foo', 'baz', 'foo'], 2)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(
                ImmutableList.of("bar", "foo"),
                ImmutableList.of("bar", "baz"),
                ImmutableList.of("foo", "baz"),
                ImmutableList.of("bar", "foo"),
                ImmutableList.of("foo", "foo"),
                ImmutableList.of("baz", "foo")));
        assertFunction("combinations(ARRAY['bar', 'foo', 'baz', 'foo'], 3)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(
                ImmutableList.of("bar", "foo", "baz"),
                ImmutableList.of("bar", "foo", "foo"),
                ImmutableList.of("bar", "baz", "foo"),
                ImmutableList.of("foo", "baz", "foo")));
        assertFunction("combinations(ARRAY['bar', 'foo', 'baz', 'foo'], 4)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(
                ImmutableList.of("bar", "foo", "baz", "foo")));
        assertFunction("combinations(ARRAY['bar', 'foo', 'baz', 'foo'], 5)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of());
        assertFunction("combinations(ARRAY['a', 'bb', 'ccc', 'dddd'], 2)", new ArrayType(new ArrayType(createVarcharType(4))), ImmutableList.of(
                ImmutableList.of("a", "bb"),
                ImmutableList.of("a", "ccc"),
                ImmutableList.of("bb", "ccc"),
                ImmutableList.of("a", "dddd"),
                ImmutableList.of("bb", "dddd"),
                ImmutableList.of("ccc", "dddd")));
    }

    @Test
    public void testLimits()
    {
        assertInvalidFunction("combinations(sequence(1, 40), -1)", INVALID_FUNCTION_ARGUMENT, "combination size must not be negative: -1");
        assertInvalidFunction("combinations(sequence(1, 40), 10)", INVALID_FUNCTION_ARGUMENT, "combination size must not exceed 5: 10");
        assertInvalidFunction("combinations(sequence(1, 100), 5)", INVALID_FUNCTION_ARGUMENT, "combinations exceed max size");
    }

    @Test
    public void testCardinality()
    {
        for (int n = 0; n < 5; n++) {
            for (int k = 0; k <= n; k++) {
                String array = "ARRAY" + ContiguousSet.closedOpen(0, n).asList();
                assertFunction(format("cardinality(combinations(%s, %s))", array, k), BIGINT, factorial(n) / factorial(n - k) / factorial(k));
            }
        }
    }

    @Test
    public void testNull()
    {
        assertFunction("combinations(CAST(NULL AS array(bigint)), 2)", new ArrayType(new ArrayType(BIGINT)), null);

        assertFunction("combinations(ARRAY['foo', NULL, 'bar'], 2)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(
                asList("foo", null),
                asList("foo", "bar"),
                asList(null, "bar")));

        assertFunction("combinations(ARRAY [NULL, NULL, NULL], 2)", new ArrayType(new ArrayType(UNKNOWN)), ImmutableList.of(
                asList(null, null),
                asList(null, null),
                asList(null, null)));

        assertFunction("combinations(ARRAY [NULL, 3, NULL], 2)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(
                asList(null, 3),
                asList(null, null),
                asList(3, null)));
    }

    @Test
    public void testTypeCombinations()
    {
        assertFunction("combinations(ARRAY[1, 2, 3], 2)", new ArrayType(new ArrayType(INTEGER)), ImmutableList.of(
                ImmutableList.of(1, 2),
                ImmutableList.of(1, 3),
                ImmutableList.of(2, 3)));
        assertFunction("combinations(ARRAY[1.1E0, 2.1E0, 3.1E0], 2)", new ArrayType(new ArrayType(DOUBLE)), ImmutableList.of(
                ImmutableList.of(1.1, 2.1),
                ImmutableList.of(1.1, 3.1),
                ImmutableList.of(2.1, 3.1)));
        assertFunction("combinations(ARRAY[true, false, true], 2)", new ArrayType(new ArrayType(BOOLEAN)), ImmutableList.of(
                ImmutableList.of(true, false),
                ImmutableList.of(true, true),
                ImmutableList.of(false, true)));

        assertFunction("combinations(ARRAY[ARRAY['A1', 'A2'], ARRAY['B1'], ARRAY['C1', 'C2']], 2)", new ArrayType(new ArrayType(new ArrayType(createVarcharType(2)))), ImmutableList.of(
                ImmutableList.of(ImmutableList.of("A1", "A2"), ImmutableList.of("B1")),
                ImmutableList.of(ImmutableList.of("A1", "A2"), ImmutableList.of("C1", "C2")),
                ImmutableList.of(ImmutableList.of("B1"), ImmutableList.of("C1", "C2"))));

        assertFunction("combinations(ARRAY['\u4FE1\u5FF5\u7231', '\u5E0C\u671B', '\u671B'], 2)", new ArrayType(new ArrayType(createVarcharType(3))), ImmutableList.of(
                ImmutableList.of("\u4FE1\u5FF5\u7231", "\u5E0C\u671B"),
                ImmutableList.of("\u4FE1\u5FF5\u7231", "\u671B"),
                ImmutableList.of("\u5E0C\u671B", "\u671B")));

        assertFunction("combinations(ARRAY[], 2)", new ArrayType(new ArrayType(UNKNOWN)), ImmutableList.of());
        assertFunction("combinations(ARRAY[''], 2)", new ArrayType(new ArrayType(createVarcharType(0))), ImmutableList.of());
        assertFunction("combinations(ARRAY['', ''], 2)", new ArrayType(new ArrayType(createVarcharType(0))), ImmutableList.of(ImmutableList.of("", "")));
    }
}