ArrowWriter.java
package tech.tablesaw.io.arrow;
import static org.apache.arrow.vector.types.FloatingPointPrecision.DOUBLE;
import static org.apache.arrow.vector.types.FloatingPointPrecision.SINGLE;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.nio.channels.Channels;
import java.nio.charset.StandardCharsets;
import java.time.ZoneOffset;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.*;
import org.apache.arrow.vector.ipc.ArrowStreamWriter;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import tech.tablesaw.api.*;
import tech.tablesaw.columns.Column;
import tech.tablesaw.io.RuntimeIOException;
/**
 * Writer for persisting a Tablesaw table in Apache Arrow Streaming Format.
 *
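 * <p>A minimal usage sketch (the file names here are illustrative):
 *
 * <pre>{@code
 * Table table = Table.read().csv("data.csv");
 * new ArrowWriter().write(table, new File("data.arrows"));
 * }</pre>
 */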
public class ArrowWriter {
  /**
   * Returns an Arrow Schema object containing a field for each column in the given Table. Columns
   * of an unsupported type cause an {@link IllegalArgumentException}.
   */
private Schema tableSchema(Table table) {
List<Field> fields = new ArrayList<>();
for (Column<?> column : table.columns()) {
final String typeName = column.type().name();
switch (typeName) {
case "STRING":
fields.add(new Field(column.name(), FieldType.nullable(new ArrowType.Utf8()), null));
break;
case "LONG":
fields.add(
new Field(column.name(), FieldType.nullable(new ArrowType.Int(64, true)), null));
break;
case "INTEGER":
fields.add(
new Field(column.name(), FieldType.nullable(new ArrowType.Int(32, true)), null));
break;
case "SHORT":
fields.add(
new Field(column.name(), FieldType.nullable(new ArrowType.Int(16, true)), null));
break;
case "LOCAL_DATE":
fields.add(
new Field(
column.name(), FieldType.notNullable(new ArrowType.Date(DateUnit.DAY)), null));
break;
case "LOCAL_DATE_TIME":
fields.add(
new Field(
column.name(),
FieldType.notNullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, null)),
null));
break;
case "LOCAL_TIME":
fields.add(
new Field(
column.name(),
FieldType.notNullable(new ArrowType.Time(TimeUnit.MILLISECOND, 32)),
null));
break;
case "INSTANT":
fields.add(
new Field(
column.name(),
FieldType.notNullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")),
null));
break;
case "BOOLEAN":
fields.add(
new Field(column.name(), FieldType.nullable(Types.MinorType.BIT.getType()), null));
break;
case "FLOAT":
fields.add(
new Field(
column.name(), FieldType.notNullable(new ArrowType.FloatingPoint(SINGLE)), null));
break;
case "DOUBLE":
fields.add(
new Field(
column.name(), FieldType.notNullable(new ArrowType.FloatingPoint(DOUBLE)), null));
break;
default:
throw new IllegalArgumentException(
"Unhandled Column type " + typeName + " in exported data");
}
}
return new Schema(fields);
}
  /**
   * Writes the data from the given column into the corresponding vector in the given
   * VectorSchemaRoot. Values are copied with {@code setSafe}, which grows the vector's buffers as
   * needed.
   */
private void setBytes(VectorSchemaRoot schemaRoot, Column<?> column) {
final String typeName = column.type().name();
switch (typeName) {
case "STRING":
VarCharVector sv = ((VarCharVector) schemaRoot.getVector(column.name()));
StringColumn sc = (StringColumn) column;
for (int i = 0; i < sc.size(); i++) {
sv.setSafe(i, sc.get(i).getBytes(StandardCharsets.UTF_8));
}
sv.setValueCount(sc.size());
break;
case "LONG":
BigIntVector lv = ((BigIntVector) schemaRoot.getVector(column.name()));
LongColumn lc = (LongColumn) column;
for (int i = 0; i < lc.size(); i++) {
lv.setSafe(i, lc.getLong(i));
}
lv.setValueCount(lc.size());
break;
case "INTEGER":
IntVector iv = ((IntVector) schemaRoot.getVector(column.name()));
IntColumn ic = (IntColumn) column;
for (int i = 0; i < ic.size(); i++) {
iv.setSafe(i, ic.getInt(i));
}
iv.setValueCount(ic.size());
break;
case "SHORT":
SmallIntVector shortv = ((SmallIntVector) schemaRoot.getVector(column.name()));
ShortColumn shortc = (ShortColumn) column;
for (int i = 0; i < shortc.size(); i++) {
shortv.setSafe(i, shortc.getInt(i));
}
shortv.setValueCount(shortc.size());
break;
case "LOCAL_DATE":
DateDayVector dv = ((DateDayVector) schemaRoot.getVector(column.name()));
DateColumn dc = (DateColumn) column;
for (int i = 0; i < dc.size(); i++) {
dv.setSafe(i, (int) dc.get(i).toEpochDay());
}
dv.setValueCount(dc.size());
break;
case "LOCAL_DATE_TIME":
TimeStampMilliVector dtv = ((TimeStampMilliVector) schemaRoot.getVector(column.name()));
DateTimeColumn dtc = (DateTimeColumn) column;
for (int i = 0; i < dtc.size(); i++) {
dtv.setSafe(i, dtc.get(i).toInstant(ZoneOffset.UTC).toEpochMilli());
}
dtv.setValueCount(dtc.size());
break;
case "LOCAL_TIME":
TimeMilliVector tv = ((TimeMilliVector) schemaRoot.getVector(column.name()));
TimeColumn tc = (TimeColumn) column;
for (int i = 0; i < tc.size(); i++) {
tv.setSafe(i, (int) ((tc.get(i).toNanoOfDay()) / 1_000_000));
}
tv.setValueCount(tc.size());
break;
case "INSTANT":
TimeStampMilliTZVector instv =
((TimeStampMilliTZVector) schemaRoot.getVector(column.name()));
InstantColumn instc = (InstantColumn) column;
for (int i = 0; i < instc.size(); i++) {
instv.setSafe(i, instc.get(i).toEpochMilli());
}
instv.setValueCount(instc.size());
break;
case "BOOLEAN":
BitVector bv = ((BitVector) schemaRoot.getVector(column.name()));
BooleanColumn bc = (BooleanColumn) column;
for (int i = 0; i < bc.size(); i++) {
bv.setSafe(i, bc.getByte(i));
}
bv.setValueCount(bc.size());
break;
case "FLOAT":
Float4Vector fv = ((Float4Vector) schemaRoot.getVector(column.name()));
FloatColumn fc = (FloatColumn) column;
for (int i = 0; i < fc.size(); i++) {
fv.setSafe(i, fc.getFloat(i));
}
fv.setValueCount(fc.size());
break;
case "DOUBLE":
Float8Vector f8v = ((Float8Vector) schemaRoot.getVector(column.name()));
DoubleColumn f8c = (DoubleColumn) column;
for (int i = 0; i < f8c.size(); i++) {
f8v.setSafe(i, f8c.getDouble(i));
}
f8v.setValueCount(f8c.size());
break;
default:
throw new IllegalArgumentException(
"Unhandled Column type " + typeName + " in exported data");
}
}
  /**
   * Writes table to an arrow-formatted file. The Arrow Stream format is used, meaning that there
   * is no sparse index into the individual data blocks, and only sequential access is supported.
   *
   * <p>Note that for Arrow Streaming Format files, the extension ".arrows" is recommended. The
   * ".arrow" extension is intended for use by the Arrow File Format, which provides random access
   * to the individual blocks.
   *
   * <p>The arrow format specifies writing tables in record batches, along with any
   * DictionaryProviders that will be used in encoding the data.
   *
   * <p>The process for writing record batches is as follows:
   *
   * <ul>
   *   <li>create a VectorSchemaRoot (VSR)
   *   <li>populate the vectors in the VSR with the first batch of rows from the table, using some
   *       arbitrary number of rows for the batch size
   *   <li>write the batch to the output stream
   *   <li>reset the vectors in the VSR
   *   <li>repopulate the vectors with the next batch
   * </ul>
   *
   * <p>The cycle of reset, repopulate, and write continues until all the data has been written.
   * This implementation currently writes the entire table as a single record batch.
   *
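   * <p>A file written this way can be read back with Arrow's {@code ArrowStreamReader}. A minimal
   * sketch (error handling elided):
   *
   * <pre>{@code
   * try (BufferAllocator allocator = new RootAllocator();
   *     ArrowStreamReader reader = new ArrowStreamReader(new FileInputStream(file), allocator)) {
   *   while (reader.loadNextBatch()) {
   *     VectorSchemaRoot root = reader.getVectorSchemaRoot();
   *     // process the vectors in root here
   *   }
   * }
   * }</pre>
   *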
* @param table The table to write
* @param file The file we're writing to
*/
  public void write(Table table, File file) {
    // Create a RootAllocator to allocate memory for our vectors, and close it (along with the
    // VectorSchemaRoot) when the write completes so the direct memory is released
    try (BufferAllocator allocator = new RootAllocator()) {
      Schema schema = tableSchema(table);
      List<FieldVector> fieldVectors = createFieldVectors(schema, allocator);
      try (VectorSchemaRoot schemaRoot = new VectorSchemaRoot(fieldVectors);
          FileOutputStream out = new FileOutputStream(file);
          ArrowStreamWriter writer =
              new ArrowStreamWriter(
                  schemaRoot, /* DictionaryProvider= */ null, Channels.newChannel(out))) {
        writer.start();
        for (FieldVector v : schemaRoot.getFieldVectors()) {
          v.reset();
        }
        for (Column<?> column : table.columns()) {
          setBytes(schemaRoot, column);
        }
        // All rows are written as a single record batch
        schemaRoot.setRowCount(table.rowCount());
        writer.writeBatch();
        writer.end();
      }
    } catch (IOException e) {
      throw new RuntimeIOException(e);
    }
  }
private List<FieldVector> createFieldVectors(Schema schema, BufferAllocator allocator) {
return schema.getFields().stream()
.map(field -> field.createVector(allocator))
.collect(Collectors.toList());
}
}