TestVariableWidthSorting.java
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.arrow.algorithm.sort;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Stream;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BaseVariableWidthVector;
import org.apache.arrow.vector.ValueVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.util.Text;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
/** Test sorting variable width vectors with random data. */
public class TestVariableWidthSorting<V extends BaseVariableWidthVector, U extends Comparable<U>> {
static final int[] VECTOR_LENGTHS = new int[] {2, 5, 10, 50, 100, 1000, 3000};
static final double[] NULL_FRACTIONS = {0, 0.1, 0.3, 0.5, 0.7, 0.9, 1};
private BufferAllocator allocator;
@BeforeEach
public void prepare() {
allocator = new RootAllocator(Integer.MAX_VALUE);
}
@AfterEach
public void shutdown() {
allocator.close();
}
@ParameterizedTest
@MethodSource("getParameters")
public void testSort(
int length,
double nullFraction,
Function<BufferAllocator, V> vectorGenerator,
TestSortingUtil.DataGenerator<V, U> dataGenerator) {
sortOutOfPlace(length, nullFraction, vectorGenerator, dataGenerator);
}
void sortOutOfPlace(
int length,
double nullFraction,
Function<BufferAllocator, V> vectorGenerator,
TestSortingUtil.DataGenerator<V, U> dataGenerator) {
try (V vector = vectorGenerator.apply(allocator)) {
U[] array = dataGenerator.populate(vector, length, nullFraction);
Arrays.sort(array, (Comparator<? super U>) new StringComparator());
// sort the vector
VariableWidthOutOfPlaceVectorSorter sorter = new VariableWidthOutOfPlaceVectorSorter();
VectorValueComparator<V> comparator =
DefaultVectorComparators.createDefaultComparator(vector);
try (V sortedVec =
(V) vector.getField().getFieldType().createNewSingleVector("", allocator, null)) {
int dataSize = vector.getOffsetBuffer().getInt(vector.getValueCount() * 4L);
sortedVec.allocateNew(dataSize, vector.getValueCount());
sortedVec.setValueCount(vector.getValueCount());
sorter.sortOutOfPlace(vector, sortedVec, comparator);
// verify results
verifyResults(sortedVec, (String[]) array);
}
}
}
public static Stream<Arguments> getParameters() {
List<Arguments> params = new ArrayList<>();
for (int length : VECTOR_LENGTHS) {
for (double nullFrac : NULL_FRACTIONS) {
params.add(
Arguments.of(
length,
nullFrac,
(Function<BufferAllocator, VarCharVector>)
allocator -> new VarCharVector("vector", allocator),
TestSortingUtil.STRING_GENERATOR));
}
}
return params.stream();
}
/** Verify results as byte arrays. */
public static <V extends ValueVector> void verifyResults(V vector, String[] expected) {
assertEquals(vector.getValueCount(), expected.length);
for (int i = 0; i < expected.length; i++) {
if (expected[i] == null) {
assertTrue(vector.isNull(i));
} else {
assertArrayEquals(
((Text) vector.getObject(i)).getBytes(), expected[i].getBytes(StandardCharsets.UTF_8));
}
}
}
/**
* String comparator with the same behavior as that of {@link
* DefaultVectorComparators.VariableWidthComparator}.
*/
static class StringComparator implements Comparator<String> {
@Override
public int compare(String str1, String str2) {
if (str1 == null || str2 == null) {
if (str1 == null && str2 == null) {
return 0;
}
return str1 == null ? -1 : 1;
}
byte[] bytes1 = str1.getBytes(StandardCharsets.UTF_8);
byte[] bytes2 = str2.getBytes(StandardCharsets.UTF_8);
for (int i = 0; i < bytes1.length && i < bytes2.length; i++) {
if (bytes1[i] != bytes2[i]) {
return (bytes1[i] & 0xff) < (bytes2[i] & 0xff) ? -1 : 1;
}
}
return bytes1.length - bytes2.length;
}
}
}