Created
November 3, 2015 22:49
-
-
Save hoffrocket/d4ee2b805ae55634222c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package j.nettytest; | |
/* | |
** JkKdTree.java by Julian Kent | |
** | |
** Licenced under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License | |
** | |
** Licence summary: | |
** Under this licence you are free to: | |
** Share — copy and redistribute the material in any medium or format | |
** Adapt — remix, transform, and build upon the material | |
** The licensor cannot revoke these freedoms as long as you follow the license terms. | |
** | |
** Under the following terms: | |
** Attribution — You must give appropriate credit, provide a link to the license, and indicate | |
** if changes were made. You may do so in any reasonable manner, but not in any | |
** way that suggests the licensor endorses you or your use. | |
** NonCommercial — You may not use the material for commercial purposes. | |
** ShareAlike — If you remix, transform, or build upon the material, you must distribute your | |
** contributions under the same license as the original. | |
** No additional restrictions | |
** — You may not apply legal terms or technological measures that legally restrict | |
** others from doing anything the license permits. | |
** | |
** See full licencing details here: http://creativecommons.org/licenses/by-nc-sa/3.0/ | |
** | |
** For additional licencing rights please contact [email protected] | |
** | |
*/ | |
import java.util.ArrayList; | |
import java.util.Arrays; | |
/**
 * A bucketed k-d tree that maps fixed-length {@code float} coordinate vectors to {@code long}
 * payloads and supports k-nearest-neighbour, ball and axis-aligned rectangle queries.
 *
 * <p>The distance metric is supplied by the concrete subclass: {@link Euclidean} (which reports
 * <em>squared</em> Euclidean distances), {@link Manhattan} and {@link WeightedManhattan}. Points
 * live in leaf buckets; a leaf tries to split every time its size reaches a multiple of the bucket
 * size. Every node's bounding box is stored in one shared float array.
 *
 * <p>The class uses no internal synchronization.
 */
public abstract class JkKdTree {
    // Use a big bucket size so that there are fewer node bounding boxes (more cache hits) and each
    // split has enough points for a good variance estimate.
    private static final int _bucketSize = 50;
    private final int _dimensions;
    private int _nodes;
    private final Node root;
    private final ArrayList<Node> nodeList = new ArrayList<Node>();
    // Point buffer handed to the 'less' child of the next split, so that splitting a leaf does not
    // have to allocate _bucketSize * _dimensions floats for that child.
    private float[] mem_recycle;
    // Bounding box of a freshly created node: {+inf, -inf} for every (min, max) pair.
    private final float[] bounds_template;
    // One self-expanding array holding every node's bounding box so that they stay close in memory.
    // Node bounds are at:
    //   low:  2 * _dimensions * node.index + 2 * dim
    //   high: 2 * _dimensions * node.index + 2 * dim + 1
    // Node.split() keeps nodeMinMaxBounds.size == 2 * _dimensions * _nodes, so every new node
    // appends its template exactly at its own slot.
    private final ContiguousFloatArrayList nodeMinMaxBounds;

    private JkKdTree(int dimensions) {
        if (dimensions < 1)
            throw new IllegalArgumentException("dimensions must be at least 1, got " + dimensions);
        _dimensions = dimensions;
        // Initialised big so that it rarely has to grow (author's intent: let it end up in 'old' memory).
        nodeMinMaxBounds = new ContiguousFloatArrayList(512 * 1024 / 8 + 2 * _dimensions);
        mem_recycle = new float[_bucketSize * dimensions];
        bounds_template = new float[2 * _dimensions];
        Arrays.fill(bounds_template, Float.NEGATIVE_INFINITY);
        for (int i = 0, max = 2 * _dimensions; i < max; i += 2)
            bounds_template[i] = Float.POSITIVE_INFINITY;
        // and.... start!
        root = new Node();
    }

    /** Returns the number of live nodes (stems and leaves) in the tree. */
    public int nodes() {
        return _nodes;
    }

    /** Returns the number of points stored in the tree. */
    public int size() {
        return root.entries;
    }

    /**
     * Inserts a point.
     *
     * @param location coordinates of the point; must have exactly {@code dimensions} elements. The
     *                 values are copied, so the caller may reuse the array.
     * @param payload  value reported by searches that match this point
     * @return the number of points in the tree after the insert
     * @throws IllegalArgumentException if {@code location.length} differs from the tree's dimensions
     */
    public int addPoint(float[] location, long payload) {
        checkDimensions(location, "location");
        Node addNode = root;
        // Walk down to the leaf that should store 'location', growing the bounding box and entry
        // count of every stem on the way.
        while (addNode.pointLocations == null) {
            addNode.expandBounds(location);
            if (location[addNode.splitDim] < addNode.splitVal)
                addNode = nodeList.get(addNode.lessIndex);
            else
                addNode = nodeList.get(addNode.moreIndex);
        }
        addNode.expandBounds(location);
        int nodeSize = addNode.add(location, payload);
        // Try splitting again every time the leaf passes a _bucketSize multiple, in case an earlier
        // attempt failed because the leaf is full of points at the same location.
        if (nodeSize % _bucketSize == 0)
            addNode.split();
        return root.entries;
    }

    /**
     * Finds the {@code K} points closest to {@code searchLocation}.
     *
     * @param searchLocation query point; must have exactly {@code dimensions} elements
     * @param K              maximum number of results
     * @return up to {@code K} results, closest first. Fewer are returned when the tree holds fewer
     *         than {@code K} points. Distances are in the subclass's metric (squared for {@link Euclidean}).
     * @throws IllegalArgumentException if {@code K} is negative or the query has the wrong dimensions
     */
    public ArrayList<SearchResult> nearestNeighbours(float[] searchLocation, int K) {
        checkDimensions(searchLocation, "searchLocation");
        if (K < 0)
            throw new IllegalArgumentException("K must not be negative, got " + K);
        if (K == 0)
            return new ArrayList<SearchResult>();
        IntStack stack = new IntStack();
        PrioQueue results = new PrioQueue(K, true);
        stack.push(root.index);
        int added = 0;
        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            // Visit a node while the queue still has free slots, or if its box could hold a point
            // closer than the current K-th best.
            if (added < K || results.peekPrio() > pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    node.search(searchLocation, stack);
                else
                    added += node.search(searchLocation, results);
            }
        }
        // The queue is prefilled with K +infinity placeholders. Each insert pushes one entry out
        // of the end, so exactly min(K, added) slots hold real points.
        int found = Math.min(K, added);
        ArrayList<SearchResult> returnResults = new ArrayList<SearchResult>(found);
        float[] priorities = results.priorities;
        long[] elements = results.elements;
        for (int i = 0; i < found; i++) // forward (closest first)
            returnResults.add(new SearchResult(priorities[i], elements[i]));
        return returnResults;
    }

    /**
     * Returns the payloads of all points whose distance to {@code searchLocation} is at most
     * {@code radius}, boundary included. Distances use the subclass's metric, so for
     * {@link Euclidean} {@code radius} is a squared radius.
     *
     * @throws IllegalArgumentException if the query has the wrong dimensions
     */
    public ArrayList<Long> ballSearch(float[] searchLocation, double radius) {
        checkDimensions(searchLocation, "searchLocation");
        IntStack stack = new IntStack();
        ArrayList<Long> results = new ArrayList<Long>();
        stack.push(root.index);
        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            // Inclusive, to match Node.searchBall: a box that only touches the ball boundary can
            // still contain a point at exactly 'radius' (e.g. exact-match lookups with radius 0).
            if (radius >= pointRectDist(nodeIndex, searchLocation)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchBall(searchLocation, radius, results);
            }
        }
        return results;
    }

    /**
     * Returns the payloads of all points inside the axis-aligned box [{@code mins}, {@code maxs}]
     * (both ends inclusive).
     *
     * @throws IllegalArgumentException if {@code mins} or {@code maxs} has the wrong dimensions
     */
    public ArrayList<Long> rectSearch(float[] mins, float[] maxs) {
        checkDimensions(mins, "mins");
        checkDimensions(maxs, "maxs");
        IntStack stack = new IntStack();
        ArrayList<Long> results = new ArrayList<Long>();
        stack.push(root.index);
        while (stack.size() > 0) {
            int nodeIndex = stack.pop();
            if (overlaps(mins, maxs, nodeIndex)) {
                Node node = nodeList.get(nodeIndex);
                if (node.pointLocations == null)
                    stack.push(node.moreIndex).push(node.lessIndex);
                else
                    node.searchRect(mins, maxs, results);
            }
        }
        return results;
    }

    /** Rejects coordinate arrays whose length does not match the tree's dimensionality. */
    private void checkDimensions(float[] point, String name) {
        if (point.length != _dimensions)
            throw new IllegalArgumentException(
                    name + " has " + point.length + " dimensions, expected " + _dimensions);
    }

    /**
     * Lower bound of the distance from {@code location} to any point inside the bounding box of node
     * {@code offset} (a node index).
     */
    abstract float pointRectDist(int offset, final float[] location);

    /** Distance from {@code location} to the {@code index}-th point stored back-to-back in {@code arr}. */
    abstract float pointDist(float[] arr, float[] location, int index);

    /**
     * Tests whether the {@code index}-th point in {@code arr} lies inside [{@code mins},
     * {@code maxs}]. Points are stored back-to-back with {@code mins.length} floats each.
     */
    boolean contains(float[] arr, float[] mins, float[] maxs, int index) {
        int offset = (index + 1) * mins.length;
        for (int i = mins.length; i-- > 0; ) {
            float d = arr[--offset];
            // non-short-circuit '|' keeps the check branch-free
            if (mins[i] > d | d > maxs[i])
                return false;
        }
        return true;
    }

    /** Tests whether the box [{@code mins}, {@code maxs}] intersects the bounding box of node {@code offset}. */
    boolean overlaps(float[] mins, float[] maxs, int offset) {
        offset *= (2 * maxs.length);
        final float[] array = nodeMinMaxBounds.array;
        for (int i = 0; i < maxs.length; i++, offset += 2) {
            double bmin = array[offset], bmax = array[offset + 1];
            if (mins[i] > bmax | maxs[i] < bmin)
                return false;
        }
        return true;
    }

    /** Euclidean metric. All reported distances (and ball-search radii) are SQUARED. */
    public static class Euclidean extends JkKdTree {
        public Euclidean(int dims) {
            super(dims);
        }

        @Override
        float pointRectDist(int offset, final float[] location) {
            offset *= (2 * super._dimensions);
            float distance = 0;
            final float[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {
                float diff = 0;
                float bv = array[offset];
                float lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += sqr(diff);
            }
            return distance;
        }

        @Override
        float pointDist(float[] arr, float[] location, int index) {
            float distance = 0;
            int offset = (index + 1) * super._dimensions;
            for (int i = super._dimensions; i-- > 0; ) {
                distance += sqr(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    /** Manhattan (L1) metric. */
    public static class Manhattan extends JkKdTree {
        public Manhattan(int dims) {
            super(dims);
        }

        @Override
        float pointRectDist(int offset, final float[] location) {
            offset *= (2 * super._dimensions);
            float distance = 0;
            final float[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {
                float diff = 0;
                float bv = array[offset];
                float lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += diff;
            }
            return distance;
        }

        @Override
        float pointDist(float[] arr, float[] location, int index) {
            float distance = 0;
            int offset = (index + 1) * super._dimensions;
            for (int i = super._dimensions; i-- > 0; ) {
                distance += Math.abs(arr[--offset] - location[i]);
            }
            return distance;
        }
    }

    /**
     * Manhattan metric with a per-dimension weight. Weights default to 1 (plain Manhattan) until
     * {@link #setWeights} is called. Weights are read at query time, so they may be changed between
     * searches. Negative weights would break the lower-bound pruning.
     */
    public static class WeightedManhattan extends JkKdTree {
        float[] weights;

        public WeightedManhattan(int dims) {
            super(dims);
            weights = new float[dims];
            Arrays.fill(weights, 1f);
        }

        /**
         * Replaces the weights. The array is used directly, not copied.
         *
         * @throws IllegalArgumentException if fewer weights than dimensions are given
         */
        public void setWeights(float[] newWeights) {
            if (newWeights.length < super._dimensions)
                throw new IllegalArgumentException("expected " + super._dimensions
                        + " weights, got " + newWeights.length);
            weights = newWeights;
        }

        @Override
        float pointRectDist(int offset, final float[] location) {
            offset *= (2 * super._dimensions);
            float distance = 0;
            final float[] array = super.nodeMinMaxBounds.array;
            for (int i = 0; i < location.length; i++, offset += 2) {
                double diff = 0;
                double bv = array[offset];
                double lv = location[i];
                if (bv > lv)
                    diff = bv - lv;
                else {
                    bv = array[offset + 1];
                    if (lv > bv)
                        diff = lv - bv;
                }
                distance += diff * weights[i];
            }
            return distance;
        }

        @Override
        float pointDist(float[] arr, float[] location, int index) {
            float distance = 0;
            int offset = (index + 1) * super._dimensions;
            for (int i = super._dimensions; i-- > 0; ) {
                distance += Math.abs(arr[--offset] - location[i]) * weights[i];
            }
            return distance;
        }
    }

    // NB! This Priority Queue keeps things with the LOWEST priority.
    // If you want highest priority items kept, negate your values.
    // Entries are kept sorted ascending by priority in fixed-size parallel arrays.
    private static class PrioQueue {
        long[] elements;
        float[] priorities;
        // priority of the last (worst) slot; anything not below it is rejected by callers
        private double minPrio;
        private int size;

        PrioQueue(int size, boolean prefill) {
            elements = new long[size];
            priorities = new float[size];
            Arrays.fill(priorities, Float.POSITIVE_INFINITY);
            if (prefill) {
                minPrio = Float.POSITIVE_INFINITY;
                this.size = size;
            }
        }

        // Uses O(log(n)) comparisons and one shift of size O(n), and is MUCH simpler than a heap
        // --> faster on small sets, faster JIT. Callers must only add priorities below peekPrio(),
        // otherwise the insertion index would be 'size' and the shift length negative.
        void addNoGrow(long value, float priority) {
            int index = searchFor(priority);
            int nextIndex = index + 1;
            int length = size - index - 1;
            System.arraycopy(elements, index, elements, nextIndex, length);
            System.arraycopy(priorities, index, priorities, nextIndex, length);
            elements[index] = value;
            priorities[index] = priority;
            minPrio = priorities[size - 1];
        }

        // Binary search: first index whose priority is not less than 'priority'.
        int searchFor(float priority) {
            int i = size - 1;
            int j = 0;
            while (i >= j) {
                int index = (i + j) >>> 1;
                if (priorities[index] < priority)
                    j = index + 1;
                else
                    i = index - 1;
            }
            return j;
        }

        double peekPrio() {
            return minPrio;
        }
    }

    /** One nearest-neighbour hit: its distance (in the tree's metric) and payload. */
    public static class SearchResult {
        public float distance;
        public long payload;

        SearchResult(float dist, long load) {
            distance = dist;
            payload = load;
        }
    }

    private class Node {
        // For accessing bounding box data in nodeMinMaxBounds; also the position in nodeList.
        // - if trees weren't so unbalanced it might be better to use an implicit heap?
        int index;
        // number of points in this subtree
        int entries;
        // leaf data (both null once the node has become a stem)
        ContiguousFloatArrayList pointLocations;
        LongList pointPayloads = new LongList();
        // stem data: children are referenced by index into nodeList
        int lessIndex, moreIndex;
        int splitDim;
        double splitVal;

        Node() {
            this(new float[_bucketSize * _dimensions]);
        }

        Node(float[] pointMemory) {
            pointLocations = new ContiguousFloatArrayList(pointMemory);
            index = _nodes++;
            nodeList.add(this);
            // lands at offset index * 2 * _dimensions as long as the size invariant holds
            nodeMinMaxBounds.add(bounds_template);
        }

        // Pushes both children, arranged so the side containing searchLocation is popped first.
        void search(float[] searchLocation, IntStack stack) {
            if (searchLocation[splitDim] < splitVal)
                stack.push(moreIndex).push(lessIndex); // less will be popped first
            else
                stack.push(lessIndex).push(moreIndex); // more will be popped first
        }

        // Offers every leaf point to 'results'; returns the number of points inserted.
        int search(float[] searchLocation, PrioQueue results) {
            int updated = 0;
            for (int j = entries; j-- > 0; ) {
                float distance = pointDist(pointLocations.array, searchLocation, j);
                if (results.peekPrio() > distance) {
                    updated++;
                    results.addNoGrow(pointPayloads.get(j), distance);
                }
            }
            return updated;
        }

        void searchBall(float[] searchLocation, double radius, ArrayList<Long> results) {
            for (int j = entries; j-- > 0; ) {
                double distance = pointDist(pointLocations.array, searchLocation, j);
                if (radius >= distance) {
                    results.add(pointPayloads.get(j));
                }
            }
        }

        void searchRect(float[] mins, float[] maxs, ArrayList<Long> results) {
            for (int j = entries; j-- > 0; )
                if (contains(pointLocations.array, mins, maxs, j))
                    results.add(pointPayloads.get(j));
        }

        // Counts one more point in this subtree and grows the bounding box to include it.
        void expandBounds(float[] location) {
            entries++;
            final float[] bounds = nodeMinMaxBounds.array;
            int mio = index * 2 * _dimensions;
            for (int i = 0; i < _dimensions; i++, mio += 2) {
                bounds[mio] = Math.min(bounds[mio], location[i]);
                bounds[mio + 1] = Math.max(bounds[mio + 1], location[i]);
            }
        }

        // Stores the point in this leaf. expandBounds must already have counted it.
        int add(float[] location, long load) {
            pointLocations.add(location);
            pointPayloads.add(load);
            return entries;
        }

        // Tries to turn this leaf into a stem with two leaf children. If every point would land on
        // the same side, the attempt is rolled back and the node stays a leaf.
        void split() {
            final float[] bounds = nodeMinMaxBounds.array;
            final int base = index * 2 * _dimensions;
            // Pick the dimension with the largest variance (a dimension is only examined when its
            // extent exceeds the best variance so far) and split at its mean.
            double diff = 0;
            int offset = base;
            for (int i = 0; i < _dimensions; i++, offset += 2) {
                double min = bounds[offset];
                double max = bounds[offset + 1];
                if (max - min > diff) {
                    double mean = 0;
                    for (int j = 0; j < entries; j++)
                        mean += pointLocations.array[i + _dimensions * j];
                    mean = mean / entries;
                    double varianceSum = 0;
                    for (int j = 0; j < entries; j++)
                        varianceSum += sqr(mean - pointLocations.array[i + _dimensions * j]);
                    if (varianceSum > diff * entries) {
                        diff = varianceSum / entries;
                        splitVal = mean;
                        splitDim = i;
                    }
                }
            }
            // Kill all the nasties: keep splitVal finite; if it equals the box maximum, use the minimum.
            if (splitVal == Double.POSITIVE_INFINITY)
                splitVal = Double.MAX_VALUE;
            else if (splitVal == Double.NEGATIVE_INFINITY)
                splitVal = -Double.MAX_VALUE; // Double.MIN_VALUE would be the smallest POSITIVE double
            else if (splitVal == bounds[base + 2 * splitDim + 1])
                splitVal = bounds[base + 2 * splitDim];

            Node less = new Node(mem_recycle); // recycle that memory!
            Node more = new Node();
            lessIndex = less.index;
            moreIndex = more.index;
            // One scratch array for the whole redistribution instead of one per point.
            float[] pointLocation = new float[_dimensions];
            for (int i = 0; i < entries; i++) {
                System.arraycopy(pointLocations.array, i * _dimensions, pointLocation, 0, _dimensions);
                Node target = pointLocation[splitDim] < splitVal ? less : more;
                target.expandBounds(pointLocation);
                target.add(pointLocation, pointPayloads.get(i));
            }
            if (less.entries == 0 || more.entries == 0) {
                // The split was worthless: throw both children away and release their bounds slots
                // too, so node indices and bounds offsets stay in step for the next nodes created.
                _nodes -= 2;
                nodeList.remove(moreIndex);
                nodeList.remove(lessIndex);
                nodeMinMaxBounds.size -= 4 * _dimensions;
            } else {
                // We won't be needing this now, so keep it for the next split to reduce garbage.
                mem_recycle = pointLocations.array;
                pointLocations = null;
                pointPayloads.clear();
                pointPayloads = null;
            }
        }
    }

    // Growable float array; add() appends a whole chunk at once.
    private static class ContiguousFloatArrayList {
        float[] array;
        int size;

        ContiguousFloatArrayList() {
            this(300);
        }

        ContiguousFloatArrayList(int size) {
            this(new float[size]);
        }

        ContiguousFloatArrayList(float[] data) {
            array = data;
        }

        ContiguousFloatArrayList add(float[] da) {
            if (size + da.length > array.length)
                array = Arrays.copyOf(array, (array.length + da.length) * 2);
            System.arraycopy(da, 0, array, size, da.length);
            size += da.length;
            return this;
        }
    }

    // Growable long array.
    private static class LongList {
        long[] array;
        int size;

        LongList() {
            this(16);
        }

        LongList(int size) {
            array = new long[size];
        }

        void add(long l) {
            // Double the capacity (growing by one made filling a leaf quadratic).
            if (size + 1 > array.length)
                array = Arrays.copyOf(array, Math.max(1, array.length * 2));
            array[size] = l;
            size++;
        }

        long get(int index) {
            return array[index];
        }

        void clear() {
            size = 0;
        }
    }

    // Growable int stack of node indices used by the searches.
    private static class IntStack {
        int[] array;
        int size;

        IntStack() {
            this(64);
        }

        IntStack(int size) {
            this(new int[size]);
        }

        IntStack(int[] data) {
            array = data;
        }

        IntStack push(int i) {
            if (size >= array.length)
                array = Arrays.copyOf(array, (array.length + 1) * 2);
            array[size++] = i;
            return this;
        }

        int pop() {
            return array[--size];
        }

        int size() {
            return size;
        }
    }

    static final double sqr(double d) {
        return d * d;
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment