SortMergeJoin.java
package tech.tablesaw.joining;
import com.google.common.collect.Streams;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import tech.tablesaw.api.ColumnType;
import tech.tablesaw.api.IntColumn;
import tech.tablesaw.api.Row;
import tech.tablesaw.api.Table;
import tech.tablesaw.columns.Column;
import tech.tablesaw.selection.Selection;
/**
 * Joins two tables with a sort-merge strategy: each side is sorted on its join columns, matching
 * rows are paired in a single merge pass, and unmatched rows are appended afterwards for the outer
 * join types.
 *
 * <p>{@link #performJoin} stores the join column positions in instance fields, so a single instance
 * should not run joins concurrently.
 */
class SortMergeJoin implements JoinStrategy {

  private static final String LEFT_RECORD_ID_NAME = "_left_record_id_";
  private static final String RIGHT_RECORD_ID_NAME = "_right_record_id_";
  private static final String TABLE_ALIAS = "T";
  public static final String PLACEHOLDER_COL_PREFIX = "Placeholder_";

  /** Join column names in the left table; used to sort the left table before merging. */
  private final String[] leftJoinColumnNames;

  private int[] leftJoinColumnPositions;
  private int[] rightJoinColumnPositions;
  /** Source of the numeric suffix of the "T&lt;n&gt;." alias given to renamed right columns. */
  private final AtomicInteger joinTableId = new AtomicInteger(1);

  /**
   * Constructor.
   *
   * @param table The table to join on.
   * @param joinColumnNames The join column names to join on.
   */
  public SortMergeJoin(Table table, String... joinColumnNames) {
    this.leftJoinColumnPositions = getJoinIndexes(table, joinColumnNames);
    this.leftJoinColumnNames = joinColumnNames;
  }

  /**
   * Finds the index of the columns corresponding to the columnNames. E.G. The column named "ID" is
   * located at index 5 in table.
   *
   * @param table the table that contains the columns.
   * @param columnNames the column names to find indexes of.
   * @return the column indexes within the table, in the same order as {@code columnNames}.
   */
  private int[] getJoinIndexes(Table table, String[] columnNames) {
    int[] results = new int[columnNames.length];
    for (int i = 0; i < columnNames.length; i++) {
      String nm = columnNames[i];
      results[i] = table.columnIndex(nm);
    }
    return results;
  }

  /**
   * Joins two tables.
   *
   * <p>Note: {@code t1} is sorted on the column names passed to the constructor, so those names
   * must identify the same columns as {@code leftJoinColumnIndexes}.
   *
   * @param t1 the table on the left side of the join.
   * @param t2 the table on the right side of the join.
   * @param joinType the type of join.
   * @param allowDuplicates if {@code false} the join will fail if any columns other than the join
   *     column have the same name if {@code true} the join will succeed and duplicate columns are
   *     renamed
   * @param keepAllJoinKeyColumns if {@code false} the join will only keep join key columns in
   *     table1 if {@code true} the join will return all join key columns in both table, which may
   *     have difference when there are null values
   * @param leftJoinColumnIndexes the positions of the join columns in {@code t1}.
   * @param table2JoinColumnNames The names of the columns in table2 to join on.
   * @return the joined table
   * @throws IllegalArgumentException if the two sides use a different number of join columns, or a
   *     pair of join columns are of different column classes
   */
  public Table performJoin(
      Table t1,
      Table t2,
      JoinType joinType,
      boolean allowDuplicates,
      boolean keepAllJoinKeyColumns,
      int[] leftJoinColumnIndexes,
      String... table2JoinColumnNames) {
    this.leftJoinColumnPositions = leftJoinColumnIndexes;
    this.rightJoinColumnPositions = getJoinIndexes(t2, table2JoinColumnNames);
    // Validate before anything below walks the left and right position arrays in parallel.
    validateJoinColumns(t1, t2);

    Table table1 = t1.sortAscendingOn(leftJoinColumnNames);
    Table table2 = t2.sortAscendingOn(table2JoinColumnNames);

    Column<?>[] cols =
        Streams.concat(table1.columns().stream(), table2.columns().stream())
            .map(Column::emptyCopy)
            .toArray(Column[]::new);

    // Positions in the result table of the redundant (duplicate) join key columns. They are kept
    // during the merge and removed at the end. Must be computed before the record id columns are
    // added so that positions refer to the original column layout.
    int[] resultIgnoreColIndexes =
        keepAllJoinKeyColumns ? new int[0] : getIgnoredColumns(table1, joinType, cols);
    Table result = emptyTableFromColumns(table1, allowDuplicates, cols);

    // Record id columns let the outer joins find rows that had no match; dropped before returning.
    IntColumn indexLeft = IntColumn.indexColumn(LEFT_RECORD_ID_NAME, table1.rowCount(), 0);
    table1.addColumns(indexLeft);
    result.addColumns(IntColumn.create(LEFT_RECORD_ID_NAME));
    IntColumn indexRight = IntColumn.indexColumn(RIGHT_RECORD_ID_NAME, table2.rowCount(), 0);
    table2.addColumns(indexRight);
    result.addColumns(IntColumn.create(RIGHT_RECORD_ID_NAME));

    // An empty left table contributes no rows to an inner or left outer join, so the merge is
    // skipped; right and full outer joins must still run to emit the right-only rows.
    boolean leftContributesNothing =
        table1.rowCount() == 0 && (joinType == JoinType.LEFT_OUTER || joinType == JoinType.INNER);
    if (!leftContributesNothing) {
      dispatchJoin(joinType, result, table1, table2, resultIgnoreColIndexes);
    }

    // Cleanup is shared by every path, including the empty-left shortcut above, so the record id
    // columns never leak into the returned table.
    result.removeColumns(LEFT_RECORD_ID_NAME, RIGHT_RECORD_ID_NAME);
    if (!keepAllJoinKeyColumns) {
      result = result.removeColumns(resultIgnoreColIndexes);
    } else {
      renameJoinColumns(result, table1, resultIgnoreColIndexes);
    }
    return result;
  }

  /**
   * Runs the merge routine for {@code joinType}. Any other value appends no rows.
   *
   * @param ignoreColumns result positions of the redundant join key columns
   */
  private void dispatchJoin(
      JoinType joinType, Table result, Table table1, Table table2, int[] ignoreColumns) {
    if (joinType == JoinType.INNER) {
      joinInner(result, table1, table2, ignoreColumns);
    } else if (joinType == JoinType.LEFT_OUTER) {
      joinLeft(result, table1, table2, ignoreColumns);
    } else if (joinType == JoinType.RIGHT_OUTER) {
      joinRight(result, table1, table2, ignoreColumns);
    } else if (joinType == JoinType.FULL_OUTER) {
      joinFull(result, table1, table2, ignoreColumns);
    }
  }

  /**
   * Renames the column indexes for the second table from Placeholder_X to their original names.
   * Right-side columns whose original name already exists in the result get the current table
   * alias prefix instead.
   *
   * <p>NOTE(review): {@link #performJoin} only calls this when {@code keepAllJoinKeyColumns} is
   * {@code true}, and in that case the ignored-column array is always empty, so this is currently a
   * no-op.
   *
   * @param resultIgnoreColIndexes The positions of the secondary join columns
   */
  private void renameJoinColumns(Table result, Table left, int[] resultIgnoreColIndexes) {
    String table2Alias = TABLE_ALIAS + joinTableId.get();
    for (int position : resultIgnoreColIndexes) {
      String realName = result.column(position).name().replace(PLACEHOLDER_COL_PREFIX, "");
      if (position >= left.columnCount()) {
        if (result.containsColumn(realName.toLowerCase())) {
          result.column(position).setName(newName(table2Alias, realName));
        } else {
          result.column(position).setName(realName);
        }
      } else {
        result.column(position).setName(realName);
      }
    }
  }

  /** Builds the aliased name {@code "<alias>.<columnName>"} used for clashing right-side names. */
  private String newName(String table2Alias, String columnName) {
    return table2Alias + "." + columnName;
  }

  /**
   * Adds empty columns to the destination table with the same type as columns in table1 and table2.
   *
   * <p>For inner, left and full outer join types the join columns in table2 are not needed and will
   * be marked as placeholders. The indexes of those columns will be returned. The downstream logic
   * is easier if we wait to remove the redundant columns until the last step.
   *
   * @param table1 the table on left side of the join.
   * @param allowDuplicates whether to allow duplicates. If yes rename columns in table2 that have
   *     the same name as columns in table1 (case-insensitively) by prefixing a fresh "T&lt;n&gt;."
   *     alias, with the exception of join columns in table2 when performing a right join.
   * @param cols An array of columns from both join tables; table2 columns may be renamed in place
   * @return the table to use for the join results
   */
  Table emptyTableFromColumns(Table table1, boolean allowDuplicates, Column<?>[] cols) {
    Table destination = Table.create(table1.name());
    // Rename duplicate columns in second table
    if (allowDuplicates) {
      Set<String> table1ColNames =
          Arrays.stream(cols)
              .map(Column::name)
              .map(String::toLowerCase)
              .limit(table1.columnCount())
              .collect(Collectors.toSet());
      String table2Alias = TABLE_ALIAS + joinTableId.incrementAndGet();
      for (int c = table1.columnCount(); c < cols.length; c++) {
        String columnName = cols[c].name();
        if (table1ColNames.contains(columnName.toLowerCase())) {
          cols[c].setName(newName(table2Alias, columnName));
        }
      }
    }
    destination.addColumns(cols);
    return destination;
  }

  /**
   * Returns the positions of columns that can be ignored in the result table, and prefixes their
   * names in {@code cols} with {@link #PLACEHOLDER_COL_PREFIX}.
   *
   * <p>For inner join, left join and full outer join the join columns in table2 are placeholders.
   * For right join, the join columns in table1 are placeholders.
   *
   * <p>The returned array is in join-key order: element {@code i} is the result position of the
   * redundant column for join key {@code i}. {@link #joinFull} relies on this to pair each right
   * key with {@code leftJoinColumnPositions[i]}.
   *
   * @param table1 the left table, before its record id column is added
   * @param joinType the type of join
   * @param cols the concatenated left and right columns of the result
   * @return result-table positions of the placeholder join columns, one per join key
   */
  private int[] getIgnoredColumns(Table table1, JoinType joinType, Column<?>[] cols) {
    int keyCount = leftJoinColumnPositions.length;
    int[] ignoreColumns = new int[keyCount];
    for (int i = 0; i < keyCount; i++) {
      int position =
          joinType == JoinType.RIGHT_OUTER
              ? leftJoinColumnPositions[i]
              // right columns follow all table1 columns in the result layout
              : table1.columnCount() + rightJoinColumnPositions[i];
      ignoreColumns[i] = position;
      cols[position].setName(PLACEHOLDER_COL_PREFIX + cols[position].name());
    }
    return ignoreColumns;
  }

  /**
   * Appends one destination row for every pair of left/right rows whose join keys compare equal.
   * Both tables must already be sorted on their join columns.
   *
   * <p>{@code mark} records the first right row of the current run of equal keys, so the right
   * cursor can rewind to it when the next left row may carry the same key.
   */
  private void joinInner(Table destination, Table left, Table right, int[] ignoreColumns) {
    if (left.rowCount() == 0 || right.rowCount() == 0) {
      // Nothing can match; also avoids positioning a Row at index 0 of a table with no rows.
      return;
    }
    Comparator<Row> comparator = getRowComparator(left, rightJoinColumnPositions);

    Row leftRow = left.row(0);
    Row rightRow = right.row(0);

    // Marks the position of the first record in right table that matches a specific join value
    int mark = -1;

    while (leftRow.hasNext() || rightRow.hasNext()) {
      if (mark == -1) {
        while (comparator.compare(leftRow, rightRow) < 0 && leftRow.hasNext()) leftRow.next();
        while (comparator.compare(leftRow, rightRow) > 0 && rightRow.hasNext()) rightRow.next();
        // set the position of the first matching record on the right side
        mark = rightRow.getRowNumber();
      }
      if (comparator.compare(leftRow, rightRow) == 0 && (leftRow.hasNext() || rightRow.hasNext())) {
        addValues(destination, leftRow, rightRow);
        if (rightRow.hasNext()) {
          rightRow.next();
        } else {
          rightRow.at(mark);
          if (leftRow.hasNext()) {
            leftRow.next();
          }
          mark = -1;
        }
      } else {
        if (rightRow.hasNext() && leftRow.hasNext()) {
          rightRow.at(mark);
          leftRow.next();
          mark = -1;
        } else {
          if (leftRow.hasNext()) leftRow.next();
          if (!leftRow.hasNext()) {
            break;
          }
        }
      }
    }
    // add the last value if you end on a match
    if (comparator.compare(leftRow, rightRow) == 0) {
      addValues(destination, leftRow, rightRow);
    }
  }

  /** Inner join, then appends every left row whose record id never appeared in the result. */
  private void joinLeft(Table destination, Table left, Table right, int[] ignoreColumns) {
    joinInner(destination, left, right, ignoreColumns);
    Selection unmatched =
        left.intColumn(LEFT_RECORD_ID_NAME)
            .isNotIn(destination.intColumn(LEFT_RECORD_ID_NAME).unique());
    addLeftOnlyValues(destination, left, unmatched);
  }

  /** Inner join, then appends every right row whose record id never appeared in the result. */
  private void joinRight(Table destination, Table left, Table right, int[] ignoreColumns) {
    joinInner(destination, left, right, ignoreColumns);
    Selection unmatched =
        right
            .intColumn(RIGHT_RECORD_ID_NAME)
            .isNotIn(destination.intColumn(RIGHT_RECORD_ID_NAME).unique());
    addRightOnlyValues(destination, left, right, unmatched);
  }

  /**
   * Inner join plus the left-only and right-only rows.
   *
   * <p>Right-only rows are staged in a separate table so that, when the right key columns will be
   * dropped ({@code ignoreColumns} non-empty), their values can first be copied into the left key
   * columns. {@code ignoreColumns[i]} is the right-side key paired with {@code
   * leftJoinColumnPositions[i]}.
   */
  private void joinFull(Table destination, Table left, Table right, int[] ignoreColumns) {
    Table tempDestination = destination.emptyCopy();
    joinInner(destination, left, right, ignoreColumns);
    Selection unmatchedLeft =
        left.intColumn(LEFT_RECORD_ID_NAME)
            .isNotIn(destination.intColumn(LEFT_RECORD_ID_NAME).unique());
    addLeftOnlyValues(destination, left, unmatchedLeft);
    Selection unmatchedRight =
        right
            .intColumn(RIGHT_RECORD_ID_NAME)
            .isNotIn(destination.intColumn(RIGHT_RECORD_ID_NAME).unique());
    addRightOnlyValues(tempDestination, left, right, unmatchedRight);
    for (int i = 0; i < ignoreColumns.length; i++) {
      String name = tempDestination.columnNames().get(leftJoinColumnPositions[i]);
      tempDestination.replaceColumn(
          leftJoinColumnPositions[i],
          tempDestination.column(ignoreColumns[i]).copy().setName(name));
    }
    destination.append(tempDestination);
  }

  /**
   * Appends each selected left row to the destination; right-side cells are not set. The left
   * row's last column is its record id, which goes to the destination's left record id column.
   */
  private void addLeftOnlyValues(Table destination, Table left, Selection unmatched) {
    for (Row leftRow : left.where(unmatched)) {
      Row destRow = destination.appendRow();
      for (int c = 0; c < leftRow.columnCount() - 1; c++) {
        updateDestinationRow(destRow, leftRow, c, c);
      }
      // update the index column putting it at the end of the destination table
      updateDestinationRow(destRow, leftRow, destRow.columnCount() - 2, leftRow.columnCount() - 1);
    }
  }

  /**
   * Appends each selected right row to the destination; left-side cells are not set. The right
   * row's last column is its record id, which goes to the destination's last column.
   */
  private void addRightOnlyValues(Table destination, Table left, Table right, Selection unmatched) {
    // left already carries its record id column, so its data columns number leftColumnCount - 1
    int leftColumnCount = left.columnCount();
    for (Row rightRow : right.where(unmatched)) {
      Row destRow = destination.appendRow();
      for (int c = 0; c < rightRow.columnCount() - 1; c++) {
        updateDestinationRow(destRow, rightRow, c + leftColumnCount - 1, c);
      }
      // update the index column putting it at the end of the destination table
      updateDestinationRow(
          destRow, rightRow, destRow.columnCount() - 1, rightRow.columnCount() - 1);
    }
  }

  /** Builds a comparator that orders a left row against a right row by their join keys. */
  private Comparator<Row> getRowComparator(Table left, int[] rightJoinColumnIndexes) {
    List<ColumnIndexPair> pairs = createJoinColumnPairs(left, rightJoinColumnIndexes);
    return SortKey.getChain(SortKey.create(pairs));
  }

  /** Pairs each left join column position (with its type) with the matching right position. */
  private List<ColumnIndexPair> createJoinColumnPairs(Table left, int[] rightJoinColumnIndexes) {
    List<ColumnIndexPair> pairs = new ArrayList<>();
    for (int i = 0; i < leftJoinColumnPositions.length; i++) {
      ColumnIndexPair columnIndexPair =
          new ColumnIndexPair(
              left.column(leftJoinColumnPositions[i]).type(),
              leftJoinColumnPositions[i],
              rightJoinColumnIndexes[i]);
      pairs.add(columnIndexPair);
    }
    return pairs;
  }

  /**
   * Copies one cell from {@code sourceRow} into {@code destRow}, dispatching on the destination
   * column's type. Column types not handled below are silently skipped, leaving the destination
   * cell unset.
   */
  private void updateDestinationRow(
      Row destRow, Row sourceRow, int destColumnPosition, int sourceColumnPosition) {
    ColumnType type = destRow.getColumnType(destColumnPosition);
    if (type.equals(ColumnType.INTEGER)) {
      destRow.setInt(destColumnPosition, sourceRow.getInt(sourceColumnPosition));
    } else if (type.equals(ColumnType.LONG)) {
      destRow.setLong(destColumnPosition, sourceRow.getLong(sourceColumnPosition));
    } else if (type.equals(ColumnType.SHORT)) {
      destRow.setShort(destColumnPosition, sourceRow.getShort(sourceColumnPosition));
    } else if (type.equals(ColumnType.STRING)) {
      destRow.setString(destColumnPosition, sourceRow.getString(sourceColumnPosition));
    } else if (type.equals(ColumnType.LOCAL_DATE)) {
      destRow.setPackedDate(destColumnPosition, sourceRow.getPackedDate(sourceColumnPosition));
    } else if (type.equals(ColumnType.LOCAL_TIME)) {
      destRow.setPackedTime(destColumnPosition, sourceRow.getPackedTime(sourceColumnPosition));
    } else if (type.equals(ColumnType.LOCAL_DATE_TIME)) {
      destRow.setPackedDateTime(
          destColumnPosition, sourceRow.getPackedDateTime(sourceColumnPosition));
    } else if (type.equals(ColumnType.INSTANT)) {
      destRow.setPackedInstant(
          destColumnPosition, sourceRow.getPackedInstant(sourceColumnPosition));
    } else if (type.equals(ColumnType.DOUBLE)) {
      destRow.setDouble(destColumnPosition, sourceRow.getDouble(sourceColumnPosition));
    } else if (type.equals(ColumnType.FLOAT)) {
      destRow.setFloat(destColumnPosition, sourceRow.getFloat(sourceColumnPosition));
    } else if (type.equals(ColumnType.BOOLEAN)) {
      destRow.setBooleanAsByte(
          destColumnPosition, sourceRow.getBooleanAsByte(sourceColumnPosition));
    }
  }

  /** Appends one destination row built from a matching left/right pair. */
  private void addValues(Table destination, Row leftRow, Row rightRow) {
    Row destRow = destination.appendRow();
    // update positionally, but take into account the RECORD_ID COLUMNS at the end of the dest table
    int leftColumnCount = leftRow.columnCount();
    int rightColumnCount = rightRow.columnCount();
    // update from the left table first (everything but the RECORD_ID column)
    for (int destIdx1 = 0; destIdx1 < leftColumnCount - 1; destIdx1++) {
      updateDestinationRow(destRow, leftRow, destIdx1, destIdx1);
    }
    // update from the right table (everything but the RECORD_ID column)
    for (int destIdx2 = (leftColumnCount - 1);
        destIdx2 < (leftColumnCount + rightColumnCount) - 2;
        destIdx2++) {
      int rightIndex = destIdx2 - (leftColumnCount - 1);
      updateDestinationRow(destRow, rightRow, destIdx2, rightIndex);
    }
    // update the RECORD_ID columns
    updateDestinationRow(destRow, leftRow, destRow.columnCount() - 2, leftColumnCount - 1);
    updateDestinationRow(destRow, rightRow, destRow.columnCount() - 1, rightColumnCount - 1);
  }

  /**
   * Checks that both sides use the same number of join columns and that each pair of join columns
   * has the same column class.
   *
   * @throws IllegalArgumentException if either check fails
   */
  private void validateJoinColumns(Table table1, Table table2) {
    if (leftJoinColumnPositions.length != rightJoinColumnPositions.length) {
      throw new IllegalArgumentException(
          "Cannot join using a different number of indices on each table: "
              + Arrays.toString(leftJoinColumnPositions)
              + " and "
              + Arrays.toString(rightJoinColumnPositions));
    }
    for (int i = 0; i < leftJoinColumnPositions.length; i++) {
      if (!table1
          .column(leftJoinColumnPositions[i])
          .getClass()
          .equals(table2.column(rightJoinColumnPositions[i]).getClass())) {
        throw new IllegalArgumentException(
            "Cannot join using different index types: "
                + Arrays.toString(leftJoinColumnPositions)
                + " and "
                + Arrays.toString(rightJoinColumnPositions));
      }
    }
  }

  @Override
  public String toString() {
    return "SortMergeJoin";
  }
}