Skip to content

Commit 1a5162d

Browse files
authored
Merge pull request #49 from rob2001/KDTree
K-D Tree Build Improvements
2 parents 4a99340 + b81722e commit 1a5162d

File tree

2 files changed

+195
-24
lines changed

2 files changed

+195
-24
lines changed

src/func/inst/KDTree.java

Lines changed: 110 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@
1515
* @version 1.0
1616
*/
1717
public class KDTree implements Serializable {
18-
18+
1919
/**
2020
* Random number generator
2121
*/
2222
private static final Random random = new Random();
2323

24-
/**
24+
/**
2525
* The head node of the kd tree
2626
*/
2727
private KDTreeNode head;
@@ -30,7 +30,7 @@ public class KDTree implements Serializable {
3030
* The dimensionality of the tree (k)
3131
*/
3232
private int dimensions;
33-
33+
3434
/**
3535
* The distance measure to use
3636
*/
@@ -51,7 +51,7 @@ public KDTree(DataSet keys, DistanceMeasure distance) {
5151
}
5252
head = buildTree(nodes, 0, nodes.length);
5353
}
54-
54+
5555
/**
5656
* Build a kd tree from the given parallel arrays
5757
* of keys and data
@@ -63,12 +63,23 @@ public KDTree(DataSet keys) {
6363
this(keys, new EuclideanDistance());
6464
}
6565

66+
/**
67+
* Builds a tree from a list of nodes
68+
* @param nodes the list of nodes
69+
* @param start the starting index
70+
* @param end the ending index
71+
* @return the head node of the build tree
72+
*/
73+
private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
74+
return buildTree(nodes, start, end, 0);
75+
}
76+
6677
/**
6778
* Build a tree from a list of nodes
6879
* @param nodes the list of nodes
6980
* @return the head node of the built tree
7081
*/
71-
private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
82+
private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end, int depth) {
7283
if (start >= end) {
7384
// if we're done return null
7485
return null;
@@ -77,16 +88,16 @@ private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
7788
return nodes[start];
7889
}
7990
// choose splitter
80-
int splitterIndex = chooseSplitterRandom(nodes, start, end);
91+
int splitterIndex = chooseApproxBestSplitter(nodes, start, end,depth);
8192
KDTreeNode splitter = nodes[splitterIndex];
82-
// patition based on splitter
93+
// partition based on splitter
8394
splitterIndex = partition(nodes, start, end, splitterIndex);
8495
// recursively build tree
85-
splitter.setLeft(buildTree(nodes, start, splitterIndex));
86-
splitter.setRight(buildTree(nodes, splitterIndex + 1, end));
96+
splitter.setLeft(buildTree(nodes, start, splitterIndex, depth+1));
97+
splitter.setRight(buildTree(nodes, splitterIndex + 1, end, depth+1));
8798
return splitter;
8899
}
89-
100+
90101
/**
91102
* Partition an array based on a splitter
92103
* @param comparables the array
@@ -95,7 +106,7 @@ private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
95106
* @param splitterIndex the splitter's index
96107
* @return the new splitter index
97108
*/
98-
private int partition(Comparable[] comparables, int start, int end,
109+
private int partition(Comparable[] comparables, int start, int end,
99110
int splitterIndex) {
100111
swap(comparables, splitterIndex, end - 1);
101112
splitterIndex = end - 1;
@@ -110,7 +121,7 @@ private int partition(Comparable[] comparables, int start, int end,
110121
swap(comparables, splitterIndex, i + 1);
111122
return i + 1;
112123
}
113-
124+
114125
/**
115126
* Swap two elements in an array
116127
* @param objects the array
@@ -122,7 +133,7 @@ private void swap(Object[] objects, int i, int j) {
122133
objects[i] = objects[j];
123134
objects[j] = temp;
124135
}
125-
136+
126137
/**
127138
* Choose a random splitter
128139
* @param nodes the nodes to choose from
@@ -136,7 +147,81 @@ private int chooseSplitterRandom(KDTreeNode[] nodes, int start, int end) {
136147
nodes[splitter].setDimension(dimension);
137148
return splitter;
138149
}
139-
150+
151+
152+
/**
153+
* Use quickSelect and medianOfMedians to select a item
154+
* guarenteed to be in the middle 50% of the data. Along the chosen dimension
155+
* This function perpares the data to be processed by the quickSelect function
156+
* It selects dimensions in order, so that each level of the final KDTree
157+
* splits on a different dimension
158+
* @param nodes the nodes to choose from
159+
* @param start the starting index
160+
* @param end the ending index
161+
* @param depth the tree depth that the splitter will be placed at
162+
* @return the index of the splitter
163+
*/
164+
private int chooseApproxBestSplitter(KDTreeNode[] nodes, int start, int end, int depth){
165+
int dimension = depth % dimensions;
166+
for (int k = start; k < end; k++) nodes[k].setDimension(dimension);
167+
return medianOfMedians(nodes,start,end);
168+
}
169+
/**
170+
* This function implements quickSelect on the passed in KDTreeNode Array
171+
* quickSelect returns the n's element of a list, defined by a predefined
172+
* ordering of elements (implemented in KDTreeNode)
173+
*
174+
* @param n the element, by number, we want to return
175+
* @param nodes the list of KDTreeNodes to search through
176+
* @param start the lower bound of the search, inclusive
177+
* @param end the upper bound of the search, exclusive
178+
* @return the index of element n
179+
*/
180+
private int quickSelect(int n, KDTreeNode[] nodes, int start, int end){
181+
while (start != end){
182+
// if (n < start || n >=end){
183+
// System.out.println("ERROR");
184+
// return start;
185+
// }
186+
int pivot = medianOfMedians(nodes,start,end);
187+
pivot = partition(nodes,start,end,pivot);
188+
if (n == pivot) return n;
189+
else if (n > pivot) start = pivot+1;
190+
else end = pivot;
191+
}
192+
return start;
193+
}
194+
195+
/**
196+
* This function implements medianOfMedians on the passed in KDTreeNode Array
197+
* this fn returns an acceptable splitter, guarenteed to be greater than 25%
198+
* of the data and less than 25% of the data.
199+
*
200+
* @param nodes the list of KDTreeNodes to search through
201+
* @param start the lower bound of the search, inclusive
202+
* @param end the upper bound of the search, exclusive
203+
* @return the index of the splitter
204+
*/
205+
private int medianOfMedians(KDTreeNode[] nodes, int start, int end){
206+
int MEDIAN_SIZE = 5;
207+
int partitions;
208+
int length = end-start;
209+
if (length < 10) {
210+
Arrays.sort(nodes,start,end);
211+
return start + length/2;
212+
}
213+
partitions = length/MEDIAN_SIZE;
214+
if (length % MEDIAN_SIZE != 0) partitions++;
215+
for (int i = 0; i < partitions; i++){
216+
int pstart = start + (i*MEDIAN_SIZE);
217+
int pend = Math.min(pstart+MEDIAN_SIZE, end);
218+
int pmiddle = pstart + (pend-pstart)/2;
219+
Arrays.sort(nodes,pstart,pend);
220+
swap(nodes,start+i,pmiddle);
221+
}
222+
return quickSelect(start+(partitions/2),nodes,start,start+partitions);
223+
224+
}
140225

141226
/**
142227
* Choose a splitter from a list of nodes
@@ -173,7 +258,7 @@ private int chooseSplitterSmart(KDTreeNode[] nodes, int start, int end) {
173258
double median = (max[widestDimension] - min[widestDimension]) / 2;
174259
// find the best splitter
175260
double bestDifference = Double.POSITIVE_INFINITY;
176-
int splitterIndex = -1;
261+
int splitterIndex = -1;
177262
for (int i = start; i < end; i++) {
178263
KDTreeNode node = nodes[i];
179264
if (Math.abs(node.getInstance().getContinuous(widestDimension) - median)
@@ -198,7 +283,7 @@ public Instance[] knn(Instance target, int k) {
198283
knn(head, target, new HyperRectangle(dimensions), results);
199284
return results.getNearest();
200285
}
201-
286+
202287
/**
203288
* Perform a nearest neighbor search
204289
* @param target the target
@@ -207,7 +292,7 @@ public Instance[] knn(Instance target, int k) {
207292
public Instance[] nn(Instance target) {
208293
NearestNeighborQueue results = new NearestNeighborQueue();
209294
knn(head, target, new HyperRectangle(dimensions), results);
210-
return results.getNearest();
295+
return results.getNearest();
211296
}
212297

213298
/**
@@ -219,9 +304,9 @@ public Instance[] nn(Instance target) {
219304
public Instance[] range(Instance target, double range) {
220305
NearestNeighborQueue results = new NearestNeighborQueue(range);
221306
knn(head, target, new HyperRectangle(dimensions), results);
222-
return results.getNearest();
307+
return results.getNearest();
223308
}
224-
309+
225310
/**
226311
* Perform a k nearest neighbor range search
227312
* @param target the target
@@ -232,17 +317,17 @@ public Instance[] range(Instance target, double range) {
232317
public Instance[] knnrange(Instance target, int k, double range) {
233318
NearestNeighborQueue results = new NearestNeighborQueue(k, range);
234319
knn(head, target, new HyperRectangle(dimensions), results);
235-
return results.getNearest();
320+
return results.getNearest();
236321
}
237-
322+
238323
/**
239324
* Perform a nearest neighbor search
240325
* @param node the node to search on
241326
* @param target the target
242327
* @param hr the hyper rectangle
243328
* @param results the current results
244329
*/
245-
private void knn(KDTreeNode node, Instance target, HyperRectangle hr,
330+
private void knn(KDTreeNode node, Instance target, HyperRectangle hr,
246331
NearestNeighborQueue results) {
247332
if (node == null) {
248333
return;
@@ -262,11 +347,12 @@ private void knn(KDTreeNode node, Instance target, HyperRectangle hr,
262347
}
263348
knn(nearNode, target, nearHR, results);
264349
if (distanceMeasure.value(
265-
farHR.pointNearestTo(target), target)
350+
farHR.pointNearestTo(target), target)
266351
<= results.getMaxDistance()) {
267-
results.add(node.getInstance(),
352+
results.add(node.getInstance(),
268353
distanceMeasure.value(node.getInstance(), target));
269354
knn(farNode, target, farHR, results);
270355
}
271356
}
357+
272358
}
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package func.test;
2+
3+
import java.io.BufferedReader;
4+
import java.io.File;
5+
import java.io.FileReader;
6+
import java.util.Scanner;
7+
8+
import func.inst.KDTree;
9+
import opt.test.AbaloneTest;
10+
import shared.DataSet;
11+
import shared.Instance;
12+
13+
/**
14+
* KNN test built to test the speed of splitter selectors in func.inst.KDTree
15+
* @author Robert Smith [email protected]
16+
* @version 1.0
17+
*/
18+
public class KNNClassifierAbaloneTest {
19+
20+
public static void main(String[] args){
21+
Instance[] instances = initializeAbaloneInstances();
22+
Instance[] training = new Instance[3000];
23+
Instance[] testing = new Instance[1177];
24+
for (int i = 0; i < training.length; i++){
25+
training[i] = instances[i];
26+
}
27+
for (int i = 0; i < testing.length; i++){
28+
testing[i] = instances[training.length+i];
29+
}
30+
long buildTime = System.nanoTime();
31+
KDTree tree = new KDTree(new DataSet(training));
32+
buildTime = System.nanoTime() - buildTime;
33+
System.out.println("BuildTime = " + buildTime);
34+
35+
for (int k = 1; k < 10; k++){
36+
long testTime = 0;
37+
for (int i = 0; i < testing.length; i++){
38+
long searchTime = System.nanoTime();
39+
tree.knn(testing[i], k);
40+
searchTime = System.nanoTime() - searchTime;
41+
testTime += searchTime;
42+
}
43+
testTime /= testing.length;
44+
System.out.println("K = " + k + ", average search time = " + testTime);
45+
}
46+
47+
48+
}
49+
50+
private static Instance[] initializeAbaloneInstances() {
51+
52+
double[][][] attributes = new double[4177][][];
53+
54+
try {
55+
BufferedReader br = new BufferedReader(new FileReader(new File("src/opt/test/abalone.txt")));
56+
57+
for(int i = 0; i < attributes.length; i++) {
58+
Scanner scan = new Scanner(br.readLine());
59+
scan.useDelimiter(",");
60+
61+
attributes[i] = new double[2][];
62+
attributes[i][0] = new double[7]; // 7 attributes
63+
attributes[i][1] = new double[1];
64+
65+
for(int j = 0; j < 7; j++)
66+
attributes[i][0][j] = Double.parseDouble(scan.next());
67+
68+
attributes[i][1][0] = Double.parseDouble(scan.next());
69+
}
70+
}
71+
catch(Exception e) {
72+
e.printStackTrace();
73+
}
74+
75+
Instance[] instances = new Instance[attributes.length];
76+
77+
for(int i = 0; i < instances.length; i++) {
78+
instances[i] = new Instance(attributes[i][0]);
79+
// classifications range from 0 to 30; split into 0 - 14 and 15 - 30
80+
instances[i].setLabel(new Instance(attributes[i][1][0] < 15 ? 0 : 1));
81+
}
82+
83+
return instances;
84+
}
85+
}

0 commit comments

Comments
 (0)