StrideIndex.java
/*
* Copyright (C) 2020 Matteo Di Giovinazzo, Samuel Audet
*
* Licensed either under the Apache License, Version 2.0, or (at your option)
* under the terms of the GNU General Public License as published by
* the Free Software Foundation (subject to the "Classpath" exception),
* either version 2, or any later version (collectively, the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.gnu.org/licenses/
* http://www.gnu.org/software/classpath/license.html
*
* or as provided in the LICENSE.txt file that accompanied this code.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.bytedeco.javacpp.indexer;
/**
 * An Index that computes a linear index from given array sizes and strides.
 * <p>
 * The linear index is the dot product of the per-dimension indices with {@link #strides}.
 *
 * @author Matteo Di Giovinazzo
 */
public class StrideIndex extends Index {
    /**
     * Returns default (row-major contiguous) strides for the given sizes: the last
     * dimension gets stride 1, and every earlier dimension gets the product of the
     * sizes of all dimensions after it.
     *
     * @param sizes the size of each dimension
     * @return a new array of strides with the same length as {@code sizes};
     *         an empty array when {@code sizes} is empty
     */
    public static long[] defaultStrides(long... sizes) {
        long[] strides = new long[sizes.length];
        // A zero-dimensional shape has no strides. Without this guard the assignment
        // below would write to strides[-1] and throw ArrayIndexOutOfBoundsException.
        if (strides.length == 0) {
            return strides;
        }
        strides[sizes.length - 1] = 1;
        for (int i = sizes.length - 2; i >= 0; i--) {
            strides[i] = strides[i + 1] * sizes[i + 1];
        }
        return strides;
    }

    /**
     * The number of elements to skip to reach the next element in a given dimension.
     * {@code strides[i] > strides[i + 1] && strides[strides.length - 1] == 1} preferred.
     */
    protected final long[] strides;

    /** Calls {@code StrideIndex(sizes, defaultStrides(sizes))}. */
    public StrideIndex(long... sizes) {
        this(sizes, defaultStrides(sizes));
    }

    /**
     * Constructor to set the {@link #sizes} and {@link #strides}.
     * <p>
     * The {@code strides} array is stored by reference, not copied.
     *
     * @param sizes   the size of each dimension, passed to the superclass
     * @param strides the stride of each dimension
     */
    public StrideIndex(long[] sizes, long[] strides) {
        super(sizes);
        this.strides = strides;
    }

    /**
     * Returns {@link #strides}.
     *
     * @return the internal strides array itself (not a copy)
     */
    public long[] strides() {
        return strides;
    }

    /** Returns {@code i * strides[0]}. */
    @Override public long index(long i) {
        return i * strides[0];
    }

    /** Returns {@code i * strides[0] + j * strides[1]}. */
    @Override public long index(long i, long j) {
        return i * strides[0] + j * strides[1];
    }

    /** Returns {@code i * strides[0] + j * strides[1] + k * strides[2]}. */
    @Override public long index(long i, long j, long k) {
        return i * strides[0] + j * strides[1] + k * strides[2];
    }

    /**
     * Computes the linear index as the dot product of indices and strides.
     * <p>
     * Only the first {@code min(indices.length, strides.length)} dimensions contribute;
     * any extra indices or strides are ignored, and an empty input yields 0.
     *
     * @param indices of each dimension
     * @return index to access array or buffer
     */
    @Override public long index(long... indices) {
        long index = 0;
        for (int i = 0; i < indices.length && i < strides.length; i++) {
            index += indices[i] * strides[i];
        }
        return index;
    }
}