SmileConverterTest.java
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package tech.tablesaw.conversion.smile;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import java.io.IOException;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.LocalTime;
import java.time.ZoneOffset;
import org.junit.jupiter.api.Test;
import smile.data.DataFrame;
import smile.data.formula.Formula;
import smile.regression.LinearModel;
import smile.regression.OLS;
import tech.tablesaw.api.BooleanColumn;
import tech.tablesaw.api.DateColumn;
import tech.tablesaw.api.DateTimeColumn;
import tech.tablesaw.api.DoubleColumn;
import tech.tablesaw.api.FloatColumn;
import tech.tablesaw.api.InstantColumn;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.LongColumn;
import tech.tablesaw.api.ShortColumn;
import tech.tablesaw.api.StringColumn;
import tech.tablesaw.api.Table;
import tech.tablesaw.api.TimeColumn;
/** Tests for converting Tablesaw tables to Smile data frames via {@code Table.smile()}. */
public class SmileConverterTest {

  /**
   * Fits an ordinary-least-squares model on the baseball data set. Wins ({@code W}) are predicted
   * from run differential ({@code RD = RS - RA}), using a Smile {@link DataFrame} produced by the
   * converter.
   */
  @Test
  public void regression() throws IOException {
    Table moneyball = Table.read().csv("../data/baseball.csv");
    moneyball.addColumns(
        moneyball.numberColumn("RS").subtract(moneyball.numberColumn("RA")).setName("RD"));

    // W is the response. With only a left-hand side, the formula uses the remaining selected
    // column (RD) as the predictor.
    LinearModel winsModel =
        OLS.fit(Formula.lhs("W"), moneyball.selectColumns("W", "RD").smile().toDataFrame());
    assertNotNull(winsModel.toString());
  }

  /**
   * Builds a two-row table with one column each of boolean, double, float, instant, int, date,
   * date-time, time, long, short and two string columns. Checks that the converted Smile {@link
   * DataFrame} keeps every row and every column.
   */
  @Test
  public void allColumnTypes() {
    Table table = Table.create();
    table.addColumns(BooleanColumn.create("boolean", new boolean[] {true, false}));
    table.addColumns(DoubleColumn.create("double", new double[] {1.2, 3.4}));
    table.addColumns(FloatColumn.create("float", new float[] {5.6f, 7.8f}));
    table.addColumns(
        InstantColumn.create(
            "instant",
            new Instant[] {
              Instant.ofEpochMilli(1578452479123L), Instant.ofEpochMilli(1578451111111L)
            }));
    table.addColumns(IntColumn.create("int", new int[] {8, 9}));
    // Plain decimal literals: zero-padded ints like 01/07 are octal in Java (08/09 won't compile).
    table.addColumns(
        DateColumn.create(
            "date", new LocalDate[] {LocalDate.of(2020, 1, 1), LocalDate.of(2020, 1, 7)}));
    table.addColumns(
        DateTimeColumn.create(
            "datetime",
            new LocalDateTime[] {
              LocalDateTime.ofInstant(Instant.ofEpochMilli(1333352479123L), ZoneOffset.UTC),
              LocalDateTime.ofInstant(Instant.ofEpochMilli(1333333333333L), ZoneOffset.UTC)
            }));
    table.addColumns(
        TimeColumn.create(
            "time", new LocalTime[] {LocalTime.of(8, 37, 48), LocalTime.of(8, 59, 6)}));
    table.addColumns(LongColumn.create("long", new long[] {3L, 4L}));
    table.addColumns(ShortColumn.create("short", new short[] {1, 2}));
    table.addColumns(StringColumn.create("string", new String[] {"james", "bond"}));
    table.addColumns(StringColumn.create("text", new String[] {"foo", "bar"}));

    DataFrame dataframe = table.smile().toDataFrame();
    assertEquals(2, dataframe.nrows());
    // Guard against a converter that silently drops an unsupported column type.
    assertEquals(table.columnCount(), dataframe.ncols());
  }
}