15
15
* @version 1.0
16
16
*/
17
17
public class KDTree implements Serializable {
18
-
18
+
19
19
/**
20
20
* Random number generator
21
21
*/
22
22
private static final Random random = new Random ();
23
23
24
- /**
24
+ /**
25
25
* The head node of the kd tree
26
26
*/
27
27
private KDTreeNode head ;
@@ -30,7 +30,7 @@ public class KDTree implements Serializable {
30
30
* The dimensionality of the tree (k)
31
31
*/
32
32
private int dimensions ;
33
-
33
+
34
34
/**
35
35
* The distance measure to use
36
36
*/
@@ -51,7 +51,7 @@ public KDTree(DataSet keys, DistanceMeasure distance) {
51
51
}
52
52
head = buildTree (nodes , 0 , nodes .length );
53
53
}
54
-
54
+
55
55
/**
56
56
* Build a kd tree from the given parallel arrays
57
57
* of keys and data
@@ -63,12 +63,23 @@ public KDTree(DataSet keys) {
63
63
this (keys , new EuclideanDistance ());
64
64
}
65
65
66
+ /**
67
+ * Builds a tree from a list of nodes
68
+ * @param nodes the list of nodes
69
+ * @param start the starting index
70
+ * @param end the ending index
71
+ * @return the head node of the build tree
72
+ */
73
+ private KDTreeNode buildTree (KDTreeNode [] nodes , int start , int end ) {
74
+ return buildTree (nodes , start , end , 0 );
75
+ }
76
+
66
77
/**
67
78
* Build a tree from a list of nodes
68
79
* @param nodes the list of nodes
69
80
* @return the head node of the built tree
70
81
*/
71
- private KDTreeNode buildTree (KDTreeNode [] nodes , int start , int end ) {
82
+ private KDTreeNode buildTree (KDTreeNode [] nodes , int start , int end , int depth ) {
72
83
if (start >= end ) {
73
84
// if we're done return null
74
85
return null ;
@@ -77,16 +88,16 @@ private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
77
88
return nodes [start ];
78
89
}
79
90
// choose splitter
80
- int splitterIndex = chooseSplitterRandom (nodes , start , end );
91
+ int splitterIndex = chooseApproxBestSplitter (nodes , start , end , depth );
81
92
KDTreeNode splitter = nodes [splitterIndex ];
82
- // patition based on splitter
93
+ // partition based on splitter
83
94
splitterIndex = partition (nodes , start , end , splitterIndex );
84
95
// recursively build tree
85
- splitter .setLeft (buildTree (nodes , start , splitterIndex ));
86
- splitter .setRight (buildTree (nodes , splitterIndex + 1 , end ));
96
+ splitter .setLeft (buildTree (nodes , start , splitterIndex , depth + 1 ));
97
+ splitter .setRight (buildTree (nodes , splitterIndex + 1 , end , depth + 1 ));
87
98
return splitter ;
88
99
}
89
-
100
+
90
101
/**
91
102
* Partition an array based on a splitter
92
103
* @param comparables the array
@@ -95,7 +106,7 @@ private KDTreeNode buildTree(KDTreeNode[] nodes, int start, int end) {
95
106
* @param splitterIndex the splitter's index
96
107
* @return the new splitter index
97
108
*/
98
- private int partition (Comparable [] comparables , int start , int end ,
109
+ private int partition (Comparable [] comparables , int start , int end ,
99
110
int splitterIndex ) {
100
111
swap (comparables , splitterIndex , end - 1 );
101
112
splitterIndex = end - 1 ;
@@ -110,7 +121,7 @@ private int partition(Comparable[] comparables, int start, int end,
110
121
swap (comparables , splitterIndex , i + 1 );
111
122
return i + 1 ;
112
123
}
113
-
124
+
114
125
/**
115
126
* Swap two elements in an array
116
127
* @param objects the array
@@ -122,7 +133,7 @@ private void swap(Object[] objects, int i, int j) {
122
133
objects [i ] = objects [j ];
123
134
objects [j ] = temp ;
124
135
}
125
-
136
+
126
137
/**
127
138
* Choose a random splitter
128
139
* @param nodes the nodes to choose from
@@ -136,7 +147,81 @@ private int chooseSplitterRandom(KDTreeNode[] nodes, int start, int end) {
136
147
nodes [splitter ].setDimension (dimension );
137
148
return splitter ;
138
149
}
139
-
150
+
151
+
152
+ /**
153
+ * Use quickSelect and medianOfMedians to select a item
154
+ * guarenteed to be in the middle 50% of the data. Along the chosen dimension
155
+ * This function perpares the data to be processed by the quickSelect function
156
+ * It selects dimensions in order, so that each level of the final KDTree
157
+ * splits on a different dimension
158
+ * @param nodes the nodes to choose from
159
+ * @param start the starting index
160
+ * @param end the ending index
161
+ * @param depth the tree depth that the splitter will be placed at
162
+ * @return the index of the splitter
163
+ */
164
+ private int chooseApproxBestSplitter (KDTreeNode [] nodes , int start , int end , int depth ){
165
+ int dimension = depth % dimensions ;
166
+ for (int k = start ; k < end ; k ++) nodes [k ].setDimension (dimension );
167
+ return medianOfMedians (nodes ,start ,end );
168
+ }
169
+ /**
170
+ * This function implements quickSelect on the passed in KDTreeNode Array
171
+ * quickSelect returns the n's element of a list, defined by a predefined
172
+ * ordering of elements (implemented in KDTreeNode)
173
+ *
174
+ * @param n the element, by number, we want to return
175
+ * @param nodes the list of KDTreeNodes to search through
176
+ * @param start the lower bound of the search, inclusive
177
+ * @param end the upper bound of the search, exclusive
178
+ * @return the index of element n
179
+ */
180
+ private int quickSelect (int n , KDTreeNode [] nodes , int start , int end ){
181
+ while (start != end ){
182
+ // if (n < start || n >=end){
183
+ // System.out.println("ERROR");
184
+ // return start;
185
+ // }
186
+ int pivot = medianOfMedians (nodes ,start ,end );
187
+ pivot = partition (nodes ,start ,end ,pivot );
188
+ if (n == pivot ) return n ;
189
+ else if (n > pivot ) start = pivot +1 ;
190
+ else end = pivot ;
191
+ }
192
+ return start ;
193
+ }
194
+
195
+ /**
196
+ * This function implements medianOfMedians on the passed in KDTreeNode Array
197
+ * this fn returns an acceptable splitter, guarenteed to be greater than 25%
198
+ * of the data and less than 25% of the data.
199
+ *
200
+ * @param nodes the list of KDTreeNodes to search through
201
+ * @param start the lower bound of the search, inclusive
202
+ * @param end the upper bound of the search, exclusive
203
+ * @return the index of the splitter
204
+ */
205
+ private int medianOfMedians (KDTreeNode [] nodes , int start , int end ){
206
+ int MEDIAN_SIZE = 5 ;
207
+ int partitions ;
208
+ int length = end -start ;
209
+ if (length < 10 ) {
210
+ Arrays .sort (nodes ,start ,end );
211
+ return start + length /2 ;
212
+ }
213
+ partitions = length /MEDIAN_SIZE ;
214
+ if (length % MEDIAN_SIZE != 0 ) partitions ++;
215
+ for (int i = 0 ; i < partitions ; i ++){
216
+ int pstart = start + (i *MEDIAN_SIZE );
217
+ int pend = Math .min (pstart +MEDIAN_SIZE , end );
218
+ int pmiddle = pstart + (pend -pstart )/2 ;
219
+ Arrays .sort (nodes ,pstart ,pend );
220
+ swap (nodes ,start +i ,pmiddle );
221
+ }
222
+ return quickSelect (start +(partitions /2 ),nodes ,start ,start +partitions );
223
+
224
+ }
140
225
141
226
/**
142
227
* Choose a splitter from a list of nodes
@@ -173,7 +258,7 @@ private int chooseSplitterSmart(KDTreeNode[] nodes, int start, int end) {
173
258
double median = (max [widestDimension ] - min [widestDimension ]) / 2 ;
174
259
// find the best splitter
175
260
double bestDifference = Double .POSITIVE_INFINITY ;
176
- int splitterIndex = -1 ;
261
+ int splitterIndex = -1 ;
177
262
for (int i = start ; i < end ; i ++) {
178
263
KDTreeNode node = nodes [i ];
179
264
if (Math .abs (node .getInstance ().getContinuous (widestDimension ) - median )
@@ -198,7 +283,7 @@ public Instance[] knn(Instance target, int k) {
198
283
knn (head , target , new HyperRectangle (dimensions ), results );
199
284
return results .getNearest ();
200
285
}
201
-
286
+
202
287
/**
203
288
* Perform a nearest neighbor search
204
289
* @param target the target
@@ -207,7 +292,7 @@ public Instance[] knn(Instance target, int k) {
207
292
public Instance [] nn (Instance target ) {
208
293
NearestNeighborQueue results = new NearestNeighborQueue ();
209
294
knn (head , target , new HyperRectangle (dimensions ), results );
210
- return results .getNearest ();
295
+ return results .getNearest ();
211
296
}
212
297
213
298
/**
@@ -219,9 +304,9 @@ public Instance[] nn(Instance target) {
219
304
public Instance [] range (Instance target , double range ) {
220
305
NearestNeighborQueue results = new NearestNeighborQueue (range );
221
306
knn (head , target , new HyperRectangle (dimensions ), results );
222
- return results .getNearest ();
307
+ return results .getNearest ();
223
308
}
224
-
309
+
225
310
/**
226
311
* Perform a k nearest neighbor range search
227
312
* @param target the target
@@ -232,17 +317,17 @@ public Instance[] range(Instance target, double range) {
232
317
public Instance [] knnrange (Instance target , int k , double range ) {
233
318
NearestNeighborQueue results = new NearestNeighborQueue (k , range );
234
319
knn (head , target , new HyperRectangle (dimensions ), results );
235
- return results .getNearest ();
320
+ return results .getNearest ();
236
321
}
237
-
322
+
238
323
/**
239
324
* Perform a nearest neighbor search
240
325
* @param node the node to search on
241
326
* @param target the target
242
327
* @param hr the hyper rectangle
243
328
* @param results the current results
244
329
*/
245
- private void knn (KDTreeNode node , Instance target , HyperRectangle hr ,
330
+ private void knn (KDTreeNode node , Instance target , HyperRectangle hr ,
246
331
NearestNeighborQueue results ) {
247
332
if (node == null ) {
248
333
return ;
@@ -262,11 +347,12 @@ private void knn(KDTreeNode node, Instance target, HyperRectangle hr,
262
347
}
263
348
knn (nearNode , target , nearHR , results );
264
349
if (distanceMeasure .value (
265
- farHR .pointNearestTo (target ), target )
350
+ farHR .pointNearestTo (target ), target )
266
351
<= results .getMaxDistance ()) {
267
- results .add (node .getInstance (),
352
+ results .add (node .getInstance (),
268
353
distanceMeasure .value (node .getInstance (), target ));
269
354
knn (farNode , target , farHR , results );
270
355
}
271
356
}
357
+
272
358
}
0 commit comments