TestIndexSorter.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.algorithm.sort;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/** Test cases for {@link IndexSorter}. */
public class TestIndexSorter {
private BufferAllocator allocator;
@BeforeEach
public void prepare() {
allocator = new RootAllocator(1024 * 1024);
}
@AfterEach
public void shutdown() {
allocator.close();
}
@Test
public void testIndexSort() {
try (IntVector vec = new IntVector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
// fill data to sort
ValueVectorDataPopulator.setVector(vec, 11, 8, 33, 10, 12, 17, null, 23, 35, 2);
// sort the index
IndexSorter<IntVector> indexSorter = new IndexSorter<>();
DefaultVectorComparators.IntComparator intComparator =
new DefaultVectorComparators.IntComparator();
intComparator.attachVector(vec);
IntVector indices = new IntVector("", allocator);
indices.setValueCount(10);
indexSorter.sort(vec, indices, intComparator);
int[] expected = new int[] {6, 9, 1, 3, 0, 4, 5, 7, 2, 8};
for (int i = 0; i < expected.length; i++) {
assertTrue(!indices.isNull(i));
assertEquals(expected[i], indices.get(i));
}
indices.close();
}
}
/**
* Tests the worst case for quick sort. It may cause stack overflow if the algorithm is
* implemented as a recursive algorithm.
*/
@Test
public void testSortLargeIncreasingInt() {
final int vectorLength = 20000;
try (IntVector vec = new IntVector("", allocator)) {
vec.allocateNew(vectorLength);
vec.setValueCount(vectorLength);
// fill data to sort
for (int i = 0; i < vectorLength; i++) {
vec.set(i, i);
}
// sort the vector
IndexSorter<IntVector> indexSorter = new IndexSorter<>();
DefaultVectorComparators.IntComparator intComparator =
new DefaultVectorComparators.IntComparator();
intComparator.attachVector(vec);
try (IntVector indices = new IntVector("", allocator)) {
indices.setValueCount(vectorLength);
indexSorter.sort(vec, indices, intComparator);
for (int i = 0; i < vectorLength; i++) {
assertTrue(!indices.isNull(i));
assertEquals(i, indices.get(i));
}
}
}
}
@Test
public void testChoosePivot() {
final int vectorLength = 100;
try (IntVector vec = new IntVector("vector", allocator);
IntVector indices = new IntVector("indices", allocator)) {
vec.allocateNew(vectorLength);
indices.allocateNew(vectorLength);
// the vector is sorted, so the pivot should be in the middle
for (int i = 0; i < vectorLength; i++) {
vec.set(i, i * 100);
indices.set(i, i);
}
vec.setValueCount(vectorLength);
indices.setValueCount(vectorLength);
VectorValueComparator<IntVector> comparator =
DefaultVectorComparators.createDefaultComparator(vec);
// setup internal data structures
comparator.attachVector(vec);
int low = 5;
int high = 6;
assertTrue(high - low + 1 < FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD);
// the range is small enough, so the pivot is simply selected as the low value
int pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(pivotIndex, low);
assertEquals(pivotIndex, indices.get(low));
low = 30;
high = 80;
assertTrue(high - low + 1 >= FixedWidthInPlaceVectorSorter.STOP_CHOOSING_PIVOT_THRESHOLD);
// the range is large enough, so the median is selected as the pivot
pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(pivotIndex, (low + high) / 2);
assertEquals(pivotIndex, indices.get(low));
}
}
/** Evaluates choosing pivot for all possible permutations of 3 numbers. */
@Test
public void testChoosePivotAllPermutes() {
try (IntVector vec = new IntVector("vector", allocator);
IntVector indices = new IntVector("indices", allocator)) {
vec.allocateNew();
indices.allocateNew();
VectorValueComparator<IntVector> comparator =
DefaultVectorComparators.createDefaultComparator(vec);
// setup internal data structures
comparator.attachVector(vec);
int low = 0;
int high = 2;
// test all the 6 permutations of 3 numbers
ValueVectorDataPopulator.setVector(indices, 0, 1, 2);
ValueVectorDataPopulator.setVector(vec, 11, 22, 33);
int pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(1, pivotIndex);
assertEquals(1, indices.get(low));
ValueVectorDataPopulator.setVector(indices, 0, 1, 2);
ValueVectorDataPopulator.setVector(vec, 11, 33, 22);
pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(2, pivotIndex);
assertEquals(2, indices.get(low));
ValueVectorDataPopulator.setVector(indices, 0, 1, 2);
ValueVectorDataPopulator.setVector(vec, 22, 11, 33);
pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(0, pivotIndex);
assertEquals(0, indices.get(low));
ValueVectorDataPopulator.setVector(indices, 0, 1, 2);
ValueVectorDataPopulator.setVector(vec, 22, 33, 11);
pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(0, pivotIndex);
assertEquals(0, indices.get(low));
ValueVectorDataPopulator.setVector(indices, 0, 1, 2);
ValueVectorDataPopulator.setVector(vec, 33, 11, 22);
pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(2, pivotIndex);
assertEquals(2, indices.get(low));
ValueVectorDataPopulator.setVector(indices, 0, 1, 2);
ValueVectorDataPopulator.setVector(vec, 33, 22, 11);
pivotIndex = IndexSorter.choosePivot(low, high, indices, comparator);
assertEquals(1, pivotIndex);
assertEquals(1, indices.get(low));
}
}
}