AggregateFunctionsTest.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package tech.tablesaw.aggregate;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static tech.tablesaw.aggregate.AggregateFunctions.allTrue;
import static tech.tablesaw.aggregate.AggregateFunctions.anyTrue;
import static tech.tablesaw.aggregate.AggregateFunctions.countFalse;
import static tech.tablesaw.aggregate.AggregateFunctions.countMissing;
import static tech.tablesaw.aggregate.AggregateFunctions.countTrue;
import static tech.tablesaw.aggregate.AggregateFunctions.countUnique;
import static tech.tablesaw.aggregate.AggregateFunctions.countWithMissing;
import static tech.tablesaw.aggregate.AggregateFunctions.earliestDate;
import static tech.tablesaw.aggregate.AggregateFunctions.latestDate;
import static tech.tablesaw.aggregate.AggregateFunctions.mean;
import static tech.tablesaw.aggregate.AggregateFunctions.noneTrue;
import static tech.tablesaw.aggregate.AggregateFunctions.percentile90;
import static tech.tablesaw.aggregate.AggregateFunctions.percentile95;
import static tech.tablesaw.aggregate.AggregateFunctions.percentile99;
import static tech.tablesaw.aggregate.AggregateFunctions.proportionFalse;
import static tech.tablesaw.aggregate.AggregateFunctions.proportionTrue;
import static tech.tablesaw.aggregate.AggregateFunctions.stdDev;
import static tech.tablesaw.aggregate.AggregateFunctions.sum;
import static tech.tablesaw.api.QuerySupport.and;
import static tech.tablesaw.api.QuerySupport.date;
import static tech.tablesaw.api.QuerySupport.num;
import static tech.tablesaw.api.QuerySupport.str;

import java.time.Instant;
import java.time.LocalDate;
import org.apache.commons.math3.stat.StatUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import tech.tablesaw.api.BooleanColumn;
import tech.tablesaw.api.ColumnType;
import tech.tablesaw.api.DateColumn;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.InstantColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.io.csv.CsvReadOptions;
import tech.tablesaw.table.SelectionTableSliceGroup;
import tech.tablesaw.table.StandardTableSliceGroup;
import tech.tablesaw.table.TableSliceGroup;

class AggregateFunctionsTest {

  private Table table;

  @BeforeEach
  void setUp() throws Exception {
    table = Table.read().csv(CsvReadOptions.builder("../data/bush.csv"));
  }

  @Test
  void testGroupMean() {
    StringColumn byColumn = table.stringColumn("who");
    TableSliceGroup group = StandardTableSliceGroup.create(table, byColumn);
    Table result = group.aggregate("approval", mean, stdDev);
    assertEquals(3, result.columnCount());
    assertEquals("who", result.column(0).name());
    assertEquals(6, result.rowCount());
    assertEquals("65.671875", result.getUnformatted(0, 1));
    assertEquals("10.648876067826901", result.getUnformatted(0, 2));
  }

  @Test
  void testDateMin() {
    StringColumn byColumn = table.dateColumn("date").yearQuarter();
    Table result = table.summarize("approval", "date", mean, earliestDate).by(byColumn);
    assertEquals(3, result.columnCount());
    assertEquals(13, result.rowCount());
  }

  @Test
  void testInstantMinMax() {
    Instant i1 = Instant.ofEpochMilli(10_000L);
    Instant i2 = Instant.ofEpochMilli(20_000L);
    Instant i3 = Instant.ofEpochMilli(30_000L);
    Instant i4 = null;

    // Explicitly test having a first value of missing
    InstantColumn ic = InstantColumn.create("instants", 5);
    ic.appendMissing();
    ic.append(i3);
    ic.append(i1);
    ic.append(i2);
    ic.appendMissing();
    ic.append(i4);
    Table test = Table.create("testInstantMath", ic);
    Table minI = test.summarize("instants", AggregateFunctions.minInstant).apply();
    Table maxI = test.summarize("instants", AggregateFunctions.maxInstant).apply();
    assertEquals(i1, minI.get(0, 0));
    assertEquals(i3, maxI.get(0, 0));
  }

  @Test
  void testInstantMinWorksWithLeadingNull() {
    Instant i1 = null;
    Instant i2 = Instant.ofEpochMilli(20_000L);
    Instant i3 = Instant.ofEpochMilli(30_000L);

    InstantColumn ic1 = InstantColumn.create("instants", 3);
    ic1.append(i1);
    ic1.append(i2);
    ic1.append(i3);
    assertEquals(i2, ic1.min());

    InstantColumn ic2 = InstantColumn.create("instants", 3);
    ic2.appendMissing();
    ic2.append(i2);
    ic2.append(i3);
    assertEquals(i2, ic2.min());
  }

  @Test
  void testInstantMaxWorksWithLeadingNull() {
    Instant i1 = null;
    Instant i2 = Instant.ofEpochMilli(20_000L);
    Instant i3 = Instant.ofEpochMilli(30_000L);

    InstantColumn ic1 = InstantColumn.create("instants", 3);
    ic1.append(i1);
    ic1.append(i2);
    ic1.append(i3);
    assertEquals(i3, ic1.max());

    InstantColumn ic2 = InstantColumn.create("instants", 3);
    ic2.appendMissing();
    ic2.append(i2);
    ic2.append(i3);
    assertEquals(i3, ic2.max());
  }

  @Test
  void testHaving() {
    StringColumn byColumn = table.dateColumn("date").yearQuarter();
    Table result =
        table
            .summarize("approval", mean, AggregateFunctions.count)
            .groupBy(byColumn)
            .having(num("Mean [approval]").isGreaterThan(60));
    assertEquals(7, result.rowCount());

    result = table.summarize("approval", mean, AggregateFunctions.count).by(byColumn);
    assertEquals(13, result.rowCount());
  }

  @Test
  void testGroupBy() {
    StringColumn byColumn = table.dateColumn("date").yearQuarter();
    Table result = table.summarize("approval", mean, AggregateFunctions.count).by(byColumn);
    assertEquals(13, result.rowCount());

    result = table.summarize("approval", mean, AggregateFunctions.count).groupBy(byColumn).apply();
    assertEquals(13, result.rowCount());
  }

  @Test
  void testBooleanAggregateFunctions() {
    boolean[] values = {true, false};
    BooleanColumn bc = BooleanColumn.create("test", values);
    assertTrue(anyTrue.summarize(bc));
    assertFalse(noneTrue.summarize(bc));
    assertFalse(allTrue.summarize(bc));
  }

  @Test
  void testBooleanNumericAggregateFunctions() {
    boolean[] values = {true, false, false, false};
    BooleanColumn bc = BooleanColumn.create("test", values);
    assertEquals(0.25, proportionTrue.summarize(bc));
    assertEquals(0.75, proportionFalse.summarize(bc));
  }

  @Test
  void testBooleanNumericFunctionGroup() {
    boolean[] values = {true, false, false, false, true, true, true, false};
    String[] group = {"a", "a", "a", "a", "b", "b", "b", "b"};
    BooleanColumn bc = BooleanColumn.create("test", values);
    StringColumn sc = StringColumn.create("group_key", group);

    Table table = Table.create(sc, bc);
    Table summarized = table.summarize("test", proportionTrue, proportionFalse).by("group_key");

    assertEquals(2, summarized.rowCount());
    assertEquals(
        1, summarized.where(summarized.stringColumn("group_key").isEqualTo("a")).rowCount());
    assertEquals(
        1, summarized.where(summarized.stringColumn("group_key").isEqualTo("b")).rowCount());
    assertEquals(
        ColumnType.DOUBLE,
        summarized.where(summarized.stringColumn(0).isEqualTo("a")).column(1).type());
    assertEquals(
        ColumnType.DOUBLE,
        summarized.where(summarized.stringColumn(0).isEqualTo("a")).column(2).type());
    assertEquals(
        ColumnType.DOUBLE,
        summarized.where(summarized.stringColumn(0).isEqualTo("b")).column(1).type());
    assertEquals(
        ColumnType.DOUBLE,
        summarized.where(summarized.stringColumn(0).isEqualTo("b")).column(2).type());
    assertEquals(
        0.25, summarized.where(summarized.stringColumn(0).isEqualTo("a")).doubleColumn(1).get(0));
    assertEquals(
        0.75, summarized.where(summarized.stringColumn(0).isEqualTo("a")).doubleColumn(2).get(0));
    assertEquals(
        0.75, summarized.where(summarized.stringColumn(0).isEqualTo("b")).doubleColumn(1).get(0));
    assertEquals(
        0.25, summarized.where(summarized.stringColumn(0).isEqualTo("b")).doubleColumn(2).get(0));
  }

  @Test
  void testGroupMean2() {
    Table result = table.summarize("approval", mean, stdDev).apply();
    assertEquals(2, result.columnCount());
  }

  @Test
  void testApplyWithNonNumericResults() {
    Table result = table.summarize("date", earliestDate, latestDate).apply();
    assertEquals(2, result.columnCount());
  }

  @Test
  void testGroupMean3a() {
    Summarizer function = table.summarize("approval", mean, stdDev);
    Table result = function.by(10);
    assertEquals(32, result.rowCount());
  }

  @Test
  void testGroupMean3b() {
    Summarizer function = table.summarize("approval", mean, stdDev);
    Table result = function.groupBy(10).apply();
    assertEquals(32, result.rowCount());
  }

  @Test
  void testGroupMean3c() {
    Summarizer function = table.summarize("approval", mean, stdDev);
    Table result = function.groupBy(10).having(num("mean [approval]").isGreaterThan(60));
    assertEquals(21, result.rowCount());
  }

  @Test
  void testGroupMean4() {
    table.addColumns(table.numberColumn("approval").cube());
    table.column(3).setName("cubed");
    Table result = table.summarize("approval", "cubed", mean, stdDev).apply();
    assertEquals(4, result.columnCount());
  }

  @Test
  void testGroupMeanByStep() {
    TableSliceGroup group = SelectionTableSliceGroup.create(table, "Step", 5);
    Table result = group.aggregate("approval", mean, stdDev);
    assertEquals(3, result.columnCount());
    assertEquals("53.6", result.getUnformatted(0, 1));
    assertEquals("2.5099800796022267", result.getUnformatted(0, 2));
  }

  @Test
  void testSummaryWithACalculatedColumn() {
    Summarizer summarizer = new Summarizer(table, table.dateColumn("date").year(), mean);
    Table t = summarizer.apply();
    double avg = t.doubleColumn(0).get(0);
    assertTrue(avg > 2002 && avg < 2003);
  }

  @Test
  void test2ColumnGroupMean() {
    StringColumn byColumn1 = table.stringColumn("who");
    DateColumn byColumn2 = table.dateColumn("date");
    Table result = table.summarize("approval", mean, sum).by(byColumn1, byColumn2);
    assertEquals(4, result.columnCount());
    assertEquals("who", result.column(0).name());
    assertEquals(323, result.rowCount());
    assertEquals(
        "46.0",
        result
            .where(
                and(str("who").isEqualTo("fox"), date("date").isEqualTo(LocalDate.of(2001, 1, 24))))
            .getUnformatted(0, 2));
  }

  @Test
  void testComplexSummarizing() {
    table.addColumns(table.numberColumn("approval").cube());
    table.column(3).setName("cubed");
    StringColumn byColumn1 = table.stringColumn("who");
    StringColumn byColumn2 = table.dateColumn("date").yearMonth();
    Table result = table.summarize("approval", "cubed", mean, sum).by(byColumn1, byColumn2);
    assertEquals(6, result.columnCount());
    assertEquals("who", result.column(0).name());
    assertEquals("date year & month", result.column(1).name());
  }

  @Test
  void testMultipleColumnTypes() {

    boolean[] args = {true, false, true, false};
    BooleanColumn booleanColumn = BooleanColumn.create("b", args);

    double[] numbers = {1, 2, 3, 4};
    DoubleColumn numberColumn = DoubleColumn.create("n", numbers);

    String[] strings = {"M", "F", "M", "F"};
    StringColumn stringColumn = StringColumn.create("s", strings);

    Table table = Table.create("test", booleanColumn, numberColumn);
    Table summary =
        table.summarize(booleanColumn, numberColumn, countTrue, stdDev).by(stringColumn);

    assertEquals(2.0, summary.doubleColumn(1).get(0), 1e-6);
    assertEquals(0.0, summary.doubleColumn(1).get(1), 1e-6);
    assertEquals(Math.sqrt(2), summary.doubleColumn(2).get(0), 1e-6);
    assertEquals(Math.sqrt(2), summary.doubleColumn(2).get(1), 1e-6);
  }

  @Test
  void testMultipleColumnTypesWithApply() {

    boolean[] args = {true, false, true, false};
    BooleanColumn booleanColumn = BooleanColumn.create("b", args);

    double[] numbers = {1, 2, 3, 4};
    DoubleColumn numberColumn = DoubleColumn.create("n", numbers);

    String[] strings = {"M", "F", "M", "F"};
    StringColumn stringColumn = StringColumn.create("s", strings);

    Table table = Table.create("test", booleanColumn, numberColumn, stringColumn);
    Table summarized = table.summarize(booleanColumn, numberColumn, countTrue, stdDev).apply();
    assertEquals(1.2909944487358056, summarized.doubleColumn(1).get(0), 0.00001);
  }

  @Test
  void testBooleanFunctions() {
    BooleanColumn c = BooleanColumn.create("test");
    c.append(true);
    c.appendCell("");
    c.append(false);
    assertEquals(1, countTrue.summarize(c), 0.0001);
    assertEquals(1, countFalse.summarize(c), 0.0001);
    assertEquals(0.5, proportionFalse.summarize(c), 0.0001);
    assertEquals(0.5, proportionTrue.summarize(c), 0.0001);
    assertEquals(1, countMissing.summarize(c), 0.0001);
    assertEquals(3, countWithMissing.summarize(c), 0.0001);
    assertEquals(2, countUnique.summarize(c), 0.0001);
  }

  @Test
  void testPercentileFunctions() {
    double[] values = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
    DoubleColumn c = DoubleColumn.create("test", values);
    c.appendCell("");

    assertEquals(1, countMissing.summarize(c), 0.0001);
    assertEquals(11, countWithMissing.summarize(c), 0.0001);

    assertEquals(StatUtils.percentile(values, 90), percentile90.summarize(c), 0.0001);
    assertEquals(StatUtils.percentile(values, 95), percentile95.summarize(c), 0.0001);
    assertEquals(StatUtils.percentile(values, 99), percentile99.summarize(c), 0.0001);

    assertEquals(10, countUnique.summarize(c), 0.0001);
  }
}