@@ -12,17 +12,6 @@ struct BallTree{V <: AbstractVector,N,T,M <: Metric} <: NNTree{V,M}
12
12
reordered:: Bool # If the data has been reordered
13
13
end
14
14
15
- # When we create the bounding spheres we need some temporary arrays.
16
- # We create a type to hold them to not allocate these arrays at every
17
- # function call and to reduce the number of parameters in the tree builder.
18
- struct ArrayBuffers{N,T <: AbstractFloat }
19
- center:: MVector{N,T}
20
- end
21
-
22
- function ArrayBuffers (:: Type{Val{N}} , :: Type{T} ) where {N, T}
23
- ArrayBuffers (zeros (MVector{N,T}))
24
- end
25
-
26
15
"""
27
16
BallTree(data [, metric = Euclidean(), leafsize = 10]) -> balltree
28
17
@@ -33,14 +22,14 @@ function BallTree(data::AbstractVector{V},
33
22
leafsize:: Int = 10 ,
34
23
reorder:: Bool = true ,
35
24
storedata:: Bool = true ,
25
+ parallel:: Bool = true ,
36
26
reorderbuffer:: Vector{V} = Vector {V} ()) where {V <: AbstractArray , M <: Metric }
37
27
reorder = ! isempty (reorderbuffer) || (storedata ? reorder : false )
38
28
39
29
tree_data = TreeData (data, leafsize)
40
30
n_d = length (V)
41
31
n_p = length (data)
42
32
43
- array_buffs = ArrayBuffers (Val{length (V)}, get_T (eltype (V)))
44
33
indices = collect (1 : n_p)
45
34
46
35
# Bottom up creation of hyper spheres so need spheres even for leafs)
@@ -70,7 +59,8 @@ function BallTree(data::AbstractVector{V},
70
59
if n_p > 0
71
60
# Call the recursive BallTree builder
72
61
build_BallTree (1 , data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
73
- 1 , length (data), tree_data, array_buffs, reorder)
62
+ 1 , length (data), tree_data, reorder, Val (parallel))
63
+
74
64
end
75
65
76
66
if reorder
@@ -86,6 +76,7 @@ function BallTree(data::AbstractVecOrMat{T},
86
76
leafsize:: Int = 10 ,
87
77
storedata:: Bool = true ,
88
78
reorder:: Bool = true ,
79
+ parallel:: Bool = true ,
89
80
reorderbuffer:: Matrix{T} = Matrix {T} (undef, 0 , 0 )) where {T <: AbstractFloat , M <: Metric }
90
81
dim = size (data, 1 )
91
82
npoints = size (data, 2 )
@@ -96,7 +87,7 @@ function BallTree(data::AbstractVecOrMat{T},
96
87
reorderbuffer_points = copy_svec (T, reorderbuffer, Val (dim))
97
88
end
98
89
BallTree (points, metric, leafsize = leafsize, storedata = storedata, reorder = reorder,
99
- reorderbuffer = reorderbuffer_points)
90
+ parallel = parallel, reorderbuffer = reorderbuffer_points)
100
91
end
101
92
102
93
# Recursive function to build the tree.
@@ -110,16 +101,16 @@ function build_BallTree(index::Int,
110
101
low:: Int ,
111
102
high:: Int ,
112
103
tree_data:: TreeData ,
113
- array_buffs :: ArrayBuffers{N,T} ,
114
- reorder :: Bool ) where {V <: AbstractVector , N, T}
104
+ reorder :: Bool ,
105
+ parallel :: Val{false} ) where {V <: AbstractVector , N, T}
115
106
116
107
n_points = high - low + 1 # Points left
117
108
if n_points <= tree_data. leafsize
118
109
if reorder
119
110
reorder_data! (data_reordered, data, index, indices, indices_reordered, tree_data)
120
111
end
121
112
# Create bounding sphere of points in leaf node by brute force
122
- hyper_spheres[index] = create_bsphere (data, metric, indices, low, high, array_buffs )
113
+ hyper_spheres[index] = create_bsphere (data, metric, indices, low, high)
123
114
return
124
115
end
125
116
@@ -132,22 +123,74 @@ function build_BallTree(index::Int,
132
123
133
124
# Sort the data at the mid_idx boundary using the split_dim
134
125
# to compare
135
- select_spec! (indices, mid_idx, low, high, data, split_dim)
126
+ select_spec! (indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads
136
127
137
128
build_BallTree (getleft (index), data, data_reordered, hyper_spheres, metric,
138
- indices, indices_reordered, low, mid_idx - 1 ,
139
- tree_data, array_buffs, reorder )
129
+ indices, indices_reordered, low, mid_idx - 1 ,
130
+ tree_data, reorder, parallel )
140
131
141
132
build_BallTree (getright (index), data, data_reordered, hyper_spheres, metric,
142
- indices, indices_reordered, mid_idx, high,
143
- tree_data, array_buffs, reorder )
133
+ indices, indices_reordered, mid_idx, high,
134
+ tree_data, reorder, parallel )
144
135
145
136
# Finally create bounding hyper sphere from the two children's hyper spheres
146
137
hyper_spheres[index] = create_bsphere (metric, hyper_spheres[getleft (index)],
147
- hyper_spheres[getright (index)],
148
- array_buffs)
138
+ hyper_spheres[getright (index)])
139
+ return
149
140
end
150
141
142
+ # Parallelized recursive function to build the tree.
143
+ function build_BallTree (index:: Int ,
144
+ data:: Vector{V} ,
145
+ data_reordered:: Vector{V} ,
146
+ hyper_spheres:: Vector{HyperSphere{N,T}} ,
147
+ metric:: Metric ,
148
+ indices:: Vector{Int} ,
149
+ indices_reordered:: Vector{Int} ,
150
+ low:: Int ,
151
+ high:: Int ,
152
+ tree_data:: TreeData ,
153
+ reorder:: Bool ,
154
+ parallel:: Val{true} ) where {V <: AbstractVector , N, T}
155
+
156
+ n_points = high - low + 1 # Points left
157
+ if n_points <= tree_data. leafsize
158
+ if reorder
159
+ reorder_data! (data_reordered, data, index, indices, indices_reordered, tree_data)
160
+ end
161
+ # Create bounding sphere of points in leaf node by brute force
162
+ hyper_spheres[index] = create_bsphere (data, metric, indices, low, high)
163
+ return
164
+ end
165
+
166
+ # Find split such that one of the sub trees has 2^p points
167
+ # and the left sub tree has more points
168
+ mid_idx = find_split (low, tree_data. leafsize, n_points)
169
+
170
+ # Brute force to find the dimension with the largest spread
171
+ split_dim = find_largest_spread (data, indices, low, high)
172
+
173
+ # Sort the data at the mid_idx boundary using the split_dim
174
+ # to compare
175
+ select_spec! (indices, mid_idx, low, high, data, split_dim) # culprit? technically, low and high should be disjoint for different threads
176
+
177
+ @sync begin
178
+ @spawn build_BallTree (getleft (index), data, data_reordered, hyper_spheres, metric,
179
+ indices, indices_reordered, low, mid_idx - 1 ,
180
+ tree_data, reorder, parallel)
181
+
182
+ @spawn build_BallTree (getright (index), data, data_reordered, hyper_spheres, metric,
183
+ indices, indices_reordered, mid_idx, high,
184
+ tree_data, reorder, parallel)
185
+ end
186
+
187
+ # Finally create bounding hyper sphere from the two children's hyper spheres
188
+ hyper_spheres[index] = create_bsphere (metric, hyper_spheres[getleft (index)],
189
+ hyper_spheres[getright (index)])
190
+ return
191
+ end
192
+
193
+
151
194
function _knn (tree:: BallTree ,
152
195
point:: AbstractVector ,
153
196
best_idxs:: AbstractVector{Int} ,
0 commit comments