Art.java
package org.roaringbitmap.art;
import org.roaringbitmap.ArraysShim;
import org.roaringbitmap.longlong.LongUtils;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.Arrays;
/**
* See: https://db.in.tum.de/~leis/papers/ART.pdf a cpu cache friendly main memory data structure.
* At our case, the LeafNode's key is always 48 bit size. The high 48 bit keys here are compared
* using the byte dictionary comparison.
*/
public class Art {
private Node root;
private long keySize = 0;
final static byte[] EMPTY_BYTES = new byte[0];
public Art() {
root = null;
}
public boolean isEmpty() {
return root == null;
}
/**
* insert the 48 bit key and the corresponding containerIdx
*
* @param key the high 48 bit of the long data
* @param containerIdx the container index
*/
public void insert(byte[] key, long containerIdx) {
Node freshRoot = insert(root, key, 0, containerIdx);
if (freshRoot != root) {
this.root = freshRoot;
}
keySize++;
}
/**
* @param key the high 48 bit of the long data
* @return the key's corresponding containerIdx
*/
public long findByKey(byte[] key) {
Node node = findByKey(root, key, 0);
if (node != null) {
LeafNode leafNode = (LeafNode) node;
return leafNode.containerIdx;
}
return BranchNode.ILLEGAL_IDX;
}
/**
* @param key the high 48 bit of the long data
* @return the key's corresponding containerIdx
*/
public long findByKey(long key) {
LeafNode node = findByKey(root, key);
if (node != null) {
return node.containerIdx;
}
return BranchNode.ILLEGAL_IDX;
}
private Node findByKey(Node node, byte[] key, int depth) {
while (node != null) {
if (node instanceof LeafNode) {
LeafNode leafNode = (LeafNode) node;
byte[] leafNodeKeyBytes = leafNode.getKeyBytes();
if (depth == LeafNode.LEAF_NODE_KEY_LENGTH_IN_BYTES) {
return leafNode;
}
int mismatchIndex =
ArraysShim.mismatch(
leafNodeKeyBytes,
depth,
LeafNode.LEAF_NODE_KEY_LENGTH_IN_BYTES,
key,
depth,
LeafNode.LEAF_NODE_KEY_LENGTH_IN_BYTES);
if (mismatchIndex != -1) {
return null;
}
return leafNode;
}
BranchNode branchNode = (BranchNode) node;
byte branchNodePrefixLength = branchNode.prefixLength();
if (branchNodePrefixLength > 0) {
int commonLength =
commonPrefixLength(key, depth, key.length, branchNode.prefix, 0, branchNodePrefixLength);
if (commonLength != branchNodePrefixLength) {
return null;
}
// common prefix is the same ,then increase the depth
depth += branchNodePrefixLength;
}
int pos = branchNode.getChildPos(key[depth]);
if (pos == BranchNode.ILLEGAL_IDX) {
return null;
}
node = branchNode.getChild(pos);
depth++;
}
return null;
}
private LeafNode findByKey(Node node, long key) {
int depth = 0;
while (node != null) {
//compare branch node first, its most common case
if (node instanceof BranchNode) {
BranchNode branchNode = (BranchNode) node;
byte branchNodePrefixLength = branchNode.prefixLength();
if (branchNodePrefixLength > 0) {
//TODO - we should expose a prefix() that is a long. So much time spend looping here
// when this could be a O(1) long mask & compare
byte[] prefix = branchNode.prefix;
for (int i = 0; i < branchNodePrefixLength; i++) {
// compare the prefix byte with the key byte
if (prefix[i] != LongUtils.getByte(key, depth + i)) {
return null;
}
}
// common prefix is the same ,then increase the depth
depth += branchNodePrefixLength;
}
//TODO - expose an API that avoids this double dipping
int pos = branchNode.getChildPos(LongUtils.getByte(key, depth));
if (pos == BranchNode.ILLEGAL_IDX) {
return null;
}
node = branchNode.getChild(pos);
depth++;
} else {
LeafNode leafNode = (LeafNode) node;
long leafNodeKey = leafNode.getKey();
return leafNodeKey == LongUtils.rightShiftHighPart(key)? leafNode: null;
}
}
return null;
}
/**
* a convenient method to traverse the key space in ascending order.
* @param containers input containers
* @return the key iterator
*/
public KeyIterator iterator(Containers containers) {
return new KeyIterator(this, containers);
}
/**
* remove the key from the art if it's there.
* @param key the high 48 bit key
* @return the corresponding containerIdx or -1 indicating not exist
*/
public long remove(byte[] key) {
Toolkit toolkit = removeSpecifyKey(root, key, 0);
if (toolkit != null) {
return toolkit.matchedContainerId;
}
return BranchNode.ILLEGAL_IDX;
}
protected Toolkit removeSpecifyKey(Node node, byte[] key, int dep) {
if (node == null) {
return null;
}
if (node instanceof LeafNode) {
// root is null
LeafNode leafNode = (LeafNode) node;
if (leafMatch(leafNode, key, dep)) {
// remove this node
if (leafNode == this.root) {
this.root = null;
}
keySize--;
return new Toolkit(null, leafNode.getContainerIdx(), null);
} else {
return null;
}
}
BranchNode branchNode = (BranchNode) node;
byte branchNodePrefixLength = branchNode.prefixLength();
if (branchNodePrefixLength > 0) {
int commonLength =
commonPrefixLength(key, dep, key.length, branchNode.prefix, 0, branchNodePrefixLength);
if (commonLength != branchNodePrefixLength) {
return null;
}
dep += branchNodePrefixLength;
}
int pos = branchNode.getChildPos(key[dep]);
if (pos != BranchNode.ILLEGAL_IDX) {
Node child = branchNode.getChild(pos);
if (child instanceof LeafNode && leafMatch((LeafNode) child, key, dep)) {
// found matched leaf node from the current node.
Node freshNode = branchNode.remove(pos);
keySize--;
if (branchNode == this.root && freshNode != branchNode) {
this.root = freshNode;
}
long matchedContainerIdx = ((LeafNode) child).getContainerIdx();
Toolkit toolkit = new Toolkit(freshNode, matchedContainerIdx, branchNode);
toolkit.needToVerifyReplacing = true;
return toolkit;
} else {
Toolkit toolkit = removeSpecifyKey(child, key, dep + 1);
if (toolkit != null
&& toolkit.needToVerifyReplacing
&& toolkit.freshMatchedParentNode != null
&& toolkit.freshMatchedParentNode != toolkit.originalMatchedParentNode) {
// meaning find the matched key and the shrinking happened
branchNode.replaceNode(pos, toolkit.freshMatchedParentNode);
toolkit.needToVerifyReplacing = false;
return toolkit;
}
if (toolkit != null) {
return toolkit;
}
}
}
return null;
}
class Toolkit {
Node freshMatchedParentNode; // indicating a fresh parent node while the original
// parent node shrunk and changed
long matchedContainerId; // holding the matched key's corresponding container index id
Node
originalMatchedParentNode; // holding the matched key's leaf node's original old parent node
boolean needToVerifyReplacing = false; // indicate whether the shrinking node's parent
// node has replaced its corresponding child node
Toolkit(Node freshMatchedParentNode, long matchedContainerId, Node originalMatchedParentNode) {
this.freshMatchedParentNode = freshMatchedParentNode;
this.matchedContainerId = matchedContainerId;
this.originalMatchedParentNode = originalMatchedParentNode;
}
}
private boolean leafMatch(LeafNode leafNode, byte[] key, int dep) {
byte[] leafNodeKeyBytes = leafNode.getKeyBytes();
int mismatchIndex =
ArraysShim.mismatch(
leafNodeKeyBytes,
dep,
LeafNode.LEAF_NODE_KEY_LENGTH_IN_BYTES,
key,
dep,
LeafNode.LEAF_NODE_KEY_LENGTH_IN_BYTES);
if (mismatchIndex == -1) {
return true;
} else {
return false;
}
}
private Node insert(Node node, byte[] key, int depth, long containerIdx) {
if (node == null) {
LeafNode leafNode = new LeafNode(key, containerIdx);
return leafNode;
}
if (node instanceof LeafNode) {
LeafNode leafNode = (LeafNode) node;
byte[] prefix = leafNode.getKeyBytes();
int commonPrefix = commonPrefixLength(prefix, depth, prefix.length, key, depth, key.length);
Node4 node4 = new Node4(commonPrefix);
// copy common prefix
System.arraycopy(key, depth, node4.prefix, 0, commonPrefix);
// generate two leaf nodes as the children of the fresh node4
node4.insert(leafNode, prefix[depth + commonPrefix]);
LeafNode anotherLeaf = new LeafNode(key, containerIdx);
node4.insert(anotherLeaf, key[depth + commonPrefix]);
// replace the current node with this internal node4
return node4;
}
BranchNode branchNode = (BranchNode) node;
byte branchNodePrefixLength = branchNode.prefixLength();
// to a inner node case
if (branchNodePrefixLength > 0) {
// find the mismatch position
int mismatchPos =
ArraysShim.mismatch(branchNode.prefix, 0, branchNodePrefixLength, key, depth, key.length);
if (mismatchPos != branchNodePrefixLength) {
Node4 node4 = new Node4(mismatchPos);
// copy prefix
System.arraycopy(branchNode.prefix, 0, node4.prefix, 0, mismatchPos);
// split the current internal node, spawn a fresh node4 and let the
// current internal node as its children.
node4.insert(branchNode, branchNode.prefix[mismatchPos]);
int newPrefixLength = (int) branchNodePrefixLength - (mismatchPos + 1);
// move the remained common prefix of the initial internal node
// as the new prefix is always > 0, we just allocate and fill the new prefix
branchNode.prefix = Arrays.copyOfRange(branchNode.prefix,mismatchPos + 1, branchNodePrefixLength);
LeafNode leafNode = new LeafNode(key, containerIdx);
node4.insert(leafNode, key[mismatchPos + depth]);
return node4;
}
depth += branchNodePrefixLength;
}
int pos = branchNode.getChildPos(key[depth]);
if (pos != BranchNode.ILLEGAL_IDX) {
// insert the key as current internal node's children's child node.
Node child = branchNode.getChild(pos);
Node freshOne = insert(child, key, depth + 1, containerIdx);
if (freshOne != child) {
branchNode.replaceNode(pos, freshOne);
}
return branchNode;
}
// insert the key as a child leaf node of the current internal node
LeafNode leafNode = new LeafNode(key, containerIdx);
return branchNode.insert(leafNode, key[depth]);
}
// find common prefix length
static int commonPrefixLength(
byte[] key1, int aFromIndex, int aToIndex, byte[] key2, int bFromIndex, int bToIndex) {
int aLength = aToIndex - aFromIndex;
int bLength = bToIndex - bFromIndex;
int minLength = Math.min(aLength, bLength);
int mismatchIndex = ArraysShim.mismatch(key1, aFromIndex, aToIndex, key2, bFromIndex, bToIndex);
if (aLength != bLength && mismatchIndex >= minLength) {
return minLength;
}
return mismatchIndex;
}
public Node getRoot() {
return root;
}
private LeafNode getExtremeLeaf(boolean reverse) {
Node parent = getRoot();
for (int depth = 0; depth < AbstractShuttle.MAX_DEPTH; depth++) {
if (parent instanceof BranchNode) {
BranchNode branchNode = (BranchNode) parent;
int childIndex = reverse ? branchNode.getMaxPos() : branchNode.getMinPos();
parent = branchNode.getChild(childIndex);
}
}
return (LeafNode) parent;
}
public LeafNode first() {
return getExtremeLeaf(false);
}
public LeafNode last() {
return getExtremeLeaf(true);
}
public void serializeArt(DataOutput dataOutput) throws IOException {
dataOutput.writeLong(Long.reverseBytes(keySize));
serialize(root, dataOutput);
}
public void deserializeArt(DataInput dataInput) throws IOException {
keySize = Long.reverseBytes(dataInput.readLong());
root = deserialize(dataInput);
}
public void serializeArt(ByteBuffer byteBuffer) throws IOException {
byteBuffer.putLong(keySize);
serialize(root, byteBuffer);
}
public void deserializeArt(ByteBuffer byteBuffer) throws IOException {
keySize = byteBuffer.getLong();
root = deserialize(byteBuffer);
}
public LeafNodeIterator leafNodeIterator(boolean reverse, Containers containers) {
return new LeafNodeIterator(this, reverse, containers);
}
public LeafNodeIterator leafNodeIteratorFrom(long bound, boolean reverse, Containers containers) {
return new LeafNodeIterator(this, reverse, containers, bound);
}
private void serialize(Node node, DataOutput dataOutput) throws IOException {
if (node instanceof BranchNode) {
BranchNode branchNode = (BranchNode)node;
// serialize the internal node itself first
branchNode.serialize(dataOutput);
// then all the internal node's children
int nexPos = branchNode.getNextLargerPos(BranchNode.ILLEGAL_IDX);
while (nexPos != BranchNode.ILLEGAL_IDX) {
// serialize all the not null child node
Node child = branchNode.getChild(nexPos);
serialize(child, dataOutput);
nexPos = branchNode.getNextLargerPos(nexPos);
}
} else {
// serialize the leaf node
node.serialize(dataOutput);
}
}
private void serialize(Node node, ByteBuffer byteBuffer) throws IOException {
if (node instanceof BranchNode) {
BranchNode branchNode = (BranchNode)node;
// serialize the internal node itself first
branchNode.serialize(byteBuffer);
// then all the internal node's children
int nexPos = branchNode.getNextLargerPos(BranchNode.ILLEGAL_IDX);
while (nexPos != BranchNode.ILLEGAL_IDX) {
// serialize all the not null child node
Node child = branchNode.getChild(nexPos);
serialize(child, byteBuffer);
nexPos = branchNode.getNextLargerPos(nexPos);
}
} else {
// serialize the leaf node
node.serialize(byteBuffer);
}
}
private Node deserialize(DataInput dataInput) throws IOException {
Node oneNode = Node.deserialize(dataInput);
if (oneNode == null) {
return null;
}
if (oneNode instanceof LeafNode) {
return oneNode;
} else {
BranchNode branch = (BranchNode) oneNode;
// internal node
int count = branch.count;
// all the not null child nodes
Node[] children = new Node[count];
for (int i = 0; i < count; i++) {
Node child = deserialize(dataInput);
children[i] = child;
}
branch.replaceChildren(children);
return branch;
}
}
private Node deserialize(ByteBuffer byteBuffer) throws IOException {
Node oneNode = Node.deserialize(byteBuffer);
if (oneNode == null) {
return null;
}
if (oneNode instanceof LeafNode) {
return oneNode;
} else {
BranchNode branchNode = (BranchNode) oneNode;
// internal node
int count = branchNode.count;
// all the not null child nodes
Node[] children = new Node[count];
for (int i = 0; i < count; i++) {
Node child = deserialize(byteBuffer);
children[i] = child;
}
branchNode.replaceChildren(children);
return branchNode;
}
}
public long serializeSizeInBytes() {
return serializeSizeInBytes(this.root) + 8;
}
public long getKeySize() {
return keySize;
}
private long serializeSizeInBytes(Node node) {
if (node instanceof BranchNode) {
BranchNode branchNode = (BranchNode) node;
// serialize the internal node itself first
int currentNodeSize = branchNode.serializeSizeInBytes();
// then all the internal node's children
long childrenTotalSize = 0L;
int nexPos = branchNode.getNextLargerPos(BranchNode.ILLEGAL_IDX);
while (nexPos != BranchNode.ILLEGAL_IDX) {
// serialize all the not null child node
Node child = branchNode.getChild(nexPos);
long childSize = serializeSizeInBytes(child);
nexPos = branchNode.getNextLargerPos(nexPos);
childrenTotalSize += childSize;
}
return currentNodeSize + childrenTotalSize;
} else {
// serialize the leaf node
int nodeSize = node.serializeSizeInBytes();
return nodeSize;
}
}
}