TableFilteringTest.java

/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package tech.tablesaw;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static tech.tablesaw.api.QuerySupport.numberColumn;

import java.time.LocalDate;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import tech.tablesaw.api.DateColumn;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import tech.tablesaw.columns.dates.PackedLocalDate;
import tech.tablesaw.io.csv.CsvReadOptions;

/**
 * Tests for the row-filtering operations of the {@link Table} class: predicate-based selection
 * and rejection, positional row/range selection and removal, and sampling.
 *
 * <p>Every test runs against a fresh copy of {@code ../data/bush.csv}, read with minimized column
 * sizes; the tests therefore access the {@code approval} column as a {@link ShortColumn}.
 */
public class TableFilteringTest {

  /** The table under test, reloaded from the CSV file before each test method. */
  private Table table;

  /** Reads {@code ../data/bush.csv} into {@link #table} with minimized column sizes. */
  @BeforeEach
  public void setUp() throws Exception {
    table = Table.read().csv(CsvReadOptions.builder("../data/bush.csv").minimizeColumnSizes());
  }

  /** {@code where} with a less-than predicate keeps only rows whose approval is below 53. */
  @Test
  public void testFilter1() {
    Table result = table.where(table.numberColumn("approval").isLessThan(53));
    ShortColumn a = result.shortColumn("approval");
    for (double v : a) {
      assertTrue(v < 53);
    }
  }

  /** {@code dropWhere} removes every row whose approval is below 70. */
  @Test
  public void testReject() {
    Table result = table.dropWhere(table.numberColumn("approval").isLessThan(70));
    ShortColumn a = result.shortColumn("approval");
    for (double v : a) {
      assertFalse(v < 70);
    }
  }

  /**
   * Same check as {@link #testReject()}, but the predicate is built from the statically imported
   * {@code QuerySupport.numberColumn} (column referenced by name) instead of a column instance.
   */
  @Test
  public void testReject2() {
    Table result = table.dropWhere(numberColumn("approval").isLessThan(70));
    ShortColumn a = result.shortColumn("approval");
    for (double v : a) {
      assertFalse(v < 70);
    }
  }

  /**
   * {@code dropRowsWithMissingValues} removes the row holding an empty string and the row holding
   * {@code NaN}, leaving only the two fully populated rows.
   */
  @Test
  public void testRejectWithMissingValues() {

    String[] values = {"a", "b", "", "d"};
    double[] values2 = {1, Double.NaN, 3, 4};
    StringColumn sc = StringColumn.create("s", values);
    DoubleColumn nc = DoubleColumn.create("n", values2);
    Table test = Table.create("test", sc, nc);
    Table result = test.dropRowsWithMissingValues();
    assertEquals(2, result.rowCount());
    assertEquals("a", result.stringColumn("s").get(0));
    assertEquals("d", result.stringColumn("s").get(1));
  }

  /** {@code inRange(20, 30)} yields ten rows; result row {@code r} equals source row {@code r + 20}. */
  @Test
  public void testSelectRange() {
    Table result = table.inRange(20, 30);
    assertEquals(10, result.rowCount());
    for (Column<?> c : result.columns()) {
      for (int r = 0; r < result.rowCount(); r++) {
        assertEquals(table.getString(r + 20, c.name()), result.getString(r, c.name()));
      }
    }
  }

  /** {@code rows(20, 30)} yields exactly the two source rows at indexes 20 and 30, in that order. */
  @Test
  public void testSelectRows() {
    Table result = table.rows(20, 30);
    assertEquals(2, result.rowCount());
    for (Column<?> c : result.columns()) {
      assertEquals(table.getString(20, c.name()), result.getString(0, c.name()));
      assertEquals(table.getString(30, c.name()), result.getString(1, c.name()));
    }
  }

  /** {@code sampleN(20)} returns a table with exactly 20 rows. */
  @Test
  public void testSampleRows() {
    Table result = table.sampleN(20);
    assertEquals(20, result.rowCount());
  }

  /**
   * {@code sampleX(.1)} returns 32 rows for this data set.
   *
   * <p>NOTE(review): the expected count is hard-coded and presumably tied to the number of rows in
   * bush.csv; it must be updated if the data file changes.
   */
  @Test
  public void testSampleProportion() {
    Table result = table.sampleX(.1);
    assertEquals(32, result.rowCount());
  }

  /**
   * {@code dropRows(20, 30)} removes the two rows at indexes 20 and 30: result row 20 is source row
   * 21 (shifted by one) and result row 30 is source row 32 (shifted by two).
   */
  @Test
  public void testRejectRows() {
    Table result = table.dropRows(20, 30);
    assertEquals(table.rowCount() - 2, result.rowCount());
    for (Column<?> c : result.columns()) {
      assertEquals(table.getString(21, c.name()), result.getString(20, c.name()));
      assertEquals(table.getString(32, c.name()), result.getString(30, c.name()));
    }
  }

  /**
   * {@code dropRange(20, 30)} removes ten rows starting at index 20. Rows before the range keep
   * their positions, and every later row moves up by ten.
   */
  @Test
  public void testRejectRange() {
    Table result = table.dropRange(20, 30);
    assertEquals(table.rowCount() - 10, result.rowCount());
    for (Column<?> c : result.columns()) {
      // Rows ahead of the dropped range must be untouched.
      for (int r = 0; r < 20; r++) {
        assertEquals(table.getString(r, c.name()), result.getString(r, c.name()));
      }
      // Start at 20, the first index after the cut, so the boundary rows are verified as well.
      for (int r = 20; r < result.rowCount(); r++) {
        assertEquals(table.getString(r + 10, c.name()), result.getString(r, c.name()));
      }
    }
  }

  /** {@code where} with a date predicate keeps only rows whose date falls in April. */
  @Test
  public void testFilter2() {
    Table result = table.where(table.dateColumn("date").isInApril());
    DateColumn d = result.dateColumn("date");
    for (LocalDate v : d) {
      assertTrue(PackedLocalDate.isInApril(PackedLocalDate.pack(v)));
    }
  }

  /** Two predicates joined with {@code and} must both hold for every surviving row. */
  @Test
  public void testFilter3() {
    Table result =
        table.where(
            table
                .dateColumn("date")
                .isInApril()
                .and(table.numberColumn("approval").isGreaterThan(70)));

    DateColumn dates = result.dateColumn("date");
    ShortColumn approval = result.shortColumn("approval");
    for (int row = 0; row < result.rowCount(); row++) {
      // getIntInternal exposes the packed date value that PackedLocalDate operates on.
      assertTrue(PackedLocalDate.isInApril(dates.getIntInternal(row)));
      assertTrue(approval.get(row) > 70);
    }
  }

  /** Filtering followed by {@code retainColumns} leaves exactly the two named columns. */
  @Test
  public void testFilter4() {
    Table result =
        table
            .where(
                table
                    .dateColumn("date")
                    .isInApril()
                    .and(table.numberColumn("approval").isGreaterThan(70)))
            .retainColumns("who", "approval");
    assertEquals(2, result.columnCount());
    assertTrue(result.columnNames().contains("who"));
    assertTrue(result.columnNames().contains("approval"));
  }
}