Skip to content

Commit f6acba9

Browse files
parallelizing BallTree construction
1 parent 33ccb17 commit f6acba9

File tree

7 files changed

+116
-81
lines changed

7 files changed

+116
-81
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
fail-fast: false
1616
matrix:
1717
version:
18-
- '1.0'
18+
- '1.3'
1919
- '1'
2020
- 'nightly'
2121
os:

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
name = "NearestNeighbors"
22
uuid = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
3-
version = "0.4.10"
3+
version = "0.5.0"
44

55
[deps]
66
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
77
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
8+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
89

910
[compat]
1011
Distances = "0.9, 0.10"
1112
StaticArrays = "0.9, 0.10, 0.11, 0.12, 1.0"
12-
julia = "1.0"
13+
julia = "1.3"
1314

1415
[extras]
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/NearestNeighbors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import Distances: Metric, result_type, eval_reduce, eval_end, eval_op, eval_star
55

66
using StaticArrays
77
import Base.show
8-
using Base.Threads: @threads
8+
using Base.Threads
99

1010
export NNTree, BruteTree, KDTree, BallTree, DataFreeTree
1111
export knn, nn, inrange # TODOs? , allpairs, distmat, npairs

src/ball_tree.jl

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,6 @@ struct BallTree{V <: AbstractVector,N,T,M <: Metric} <: NNTree{V,M}
1212
reordered::Bool # If the data has been reordered
1313
end
1414

15-
# When we create the bounding spheres we need some temporary arrays.
16-
# We create a type to hold them to not allocate these arrays at every
17-
# function call and to reduce the number of parameters in the tree builder.
18-
struct ArrayBuffers{N,T <: AbstractFloat}
19-
center::MVector{N,T}
20-
end
21-
22-
function ArrayBuffers(::Type{Val{N}}, ::Type{T}) where {N, T}
23-
ArrayBuffers(zeros(MVector{N,T}))
24-
end
25-
2615
"""
2716
BallTree(data [, metric = Euclidean(), leafsize = 10]) -> balltree
2817
@@ -33,14 +22,14 @@ function BallTree(data::AbstractVector{V},
3322
leafsize::Int = 10,
3423
reorder::Bool = true,
3524
storedata::Bool = true,
25+
parallel::Bool = true,
3626
reorderbuffer::Vector{V} = Vector{V}()) where {V <: AbstractArray, M <: Metric}
3727
reorder = !isempty(reorderbuffer) || (storedata ? reorder : false)
3828

3929
tree_data = TreeData(data, leafsize)
4030
n_d = length(V)
4131
n_p = length(data)
4232

43-
array_buffs = ArrayBuffers(Val{length(V)}, get_T(eltype(V)))
4433
indices = collect(1:n_p)
4534

4635
# Bottom-up creation of hyper spheres, so we need spheres even for the leaves
@@ -70,7 +59,8 @@ function BallTree(data::AbstractVector{V},
7059
if n_p > 0
7160
# Call the recursive BallTree builder
7261
build_BallTree(1, data, data_reordered, hyper_spheres, metric, indices, indices_reordered,
73-
1, length(data), tree_data, array_buffs, reorder)
62+
1, length(data), tree_data, reorder, Val(parallel))
63+
7464
end
7565

7666
if reorder
@@ -86,6 +76,7 @@ function BallTree(data::AbstractVecOrMat{T},
8676
leafsize::Int = 10,
8777
storedata::Bool = true,
8878
reorder::Bool = true,
79+
parallel::Bool = true,
8980
reorderbuffer::Matrix{T} = Matrix{T}(undef, 0, 0)) where {T <: AbstractFloat, M <: Metric}
9081
dim = size(data, 1)
9182
npoints = size(data, 2)
@@ -96,7 +87,7 @@ function BallTree(data::AbstractVecOrMat{T},
9687
reorderbuffer_points = copy_svec(T, reorderbuffer, Val(dim))
9788
end
9889
BallTree(points, metric, leafsize = leafsize, storedata = storedata, reorder = reorder,
99-
reorderbuffer = reorderbuffer_points)
90+
parallel = parallel, reorderbuffer = reorderbuffer_points)
10091
end
10192

10293
# Recursive function to build the tree.
@@ -110,16 +101,16 @@ function build_BallTree(index::Int,
110101
low::Int,
111102
high::Int,
112103
tree_data::TreeData,
113-
array_buffs::ArrayBuffers{N,T},
114-
reorder::Bool) where {V <: AbstractVector, N, T}
104+
reorder::Bool,
105+
parallel::Val{false}) where {V <: AbstractVector, N, T}
115106

116107
n_points = high - low + 1 # Points left
117108
if n_points <= tree_data.leafsize
118109
if reorder
119110
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
120111
end
121112
# Create bounding sphere of points in leaf node by brute force
122-
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high, array_buffs)
113+
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high)
123114
return
124115
end
125116

@@ -132,22 +123,74 @@ function build_BallTree(index::Int,
132123

133124
# Sort the data at the mid_idx boundary using the split_dim
134125
# to compare
135-
select_spec!(indices, mid_idx, low, high, data, split_dim)
126+
select_spec!(indices, mid_idx, low, high, data, split_dim) # NOTE(review): once suspected as a race culprit; the low:high ranges passed to the two child calls are disjoint
136127

137128
build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
138-
indices, indices_reordered, low, mid_idx - 1,
139-
tree_data, array_buffs, reorder)
129+
indices, indices_reordered, low, mid_idx - 1,
130+
tree_data, reorder, parallel)
140131

141132
build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
142-
indices, indices_reordered, mid_idx, high,
143-
tree_data, array_buffs, reorder)
133+
indices, indices_reordered, mid_idx, high,
134+
tree_data, reorder, parallel)
144135

145136
# Finally create bounding hyper sphere from the two children's hyper spheres
146137
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
147-
hyper_spheres[getright(index)],
148-
array_buffs)
138+
hyper_spheres[getright(index)])
139+
return
149140
end
150141

142+
# Parallelized recursive function to build the tree.
143+
function build_BallTree(index::Int,
144+
data::Vector{V},
145+
data_reordered::Vector{V},
146+
hyper_spheres::Vector{HyperSphere{N,T}},
147+
metric::Metric,
148+
indices::Vector{Int},
149+
indices_reordered::Vector{Int},
150+
low::Int,
151+
high::Int,
152+
tree_data::TreeData,
153+
reorder::Bool,
154+
parallel::Val{true}) where {V <: AbstractVector, N, T}
155+
156+
n_points = high - low + 1 # Points left
157+
if n_points <= tree_data.leafsize
158+
if reorder
159+
reorder_data!(data_reordered, data, index, indices, indices_reordered, tree_data)
160+
end
161+
# Create bounding sphere of points in leaf node by brute force
162+
hyper_spheres[index] = create_bsphere(data, metric, indices, low, high)
163+
return
164+
end
165+
166+
# Find split such that one of the sub trees has 2^p points
167+
# and the left sub tree has more points
168+
mid_idx = find_split(low, tree_data.leafsize, n_points)
169+
170+
# Brute force to find the dimension with the largest spread
171+
split_dim = find_largest_spread(data, indices, low, high)
172+
173+
# Sort the data at the mid_idx boundary using the split_dim
174+
# to compare
175+
select_spec!(indices, mid_idx, low, high, data, split_dim) # NOTE(review): once suspected as a race culprit; the low:high ranges passed to the two spawned child calls are disjoint
176+
177+
@sync begin
178+
@spawn build_BallTree(getleft(index), data, data_reordered, hyper_spheres, metric,
179+
indices, indices_reordered, low, mid_idx - 1,
180+
tree_data, reorder, parallel)
181+
182+
@spawn build_BallTree(getright(index), data, data_reordered, hyper_spheres, metric,
183+
indices, indices_reordered, mid_idx, high,
184+
tree_data, reorder, parallel)
185+
end
186+
187+
# Finally create bounding hyper sphere from the two children's hyper spheres
188+
hyper_spheres[index] = create_bsphere(metric, hyper_spheres[getleft(index)],
189+
hyper_spheres[getright(index)])
190+
return
191+
end
192+
193+
151194
function _knn(tree::BallTree,
152195
point::AbstractVector,
153196
best_idxs::AbstractVector{Int},

src/hyperspheres.jl

Lines changed: 22 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ end
77

88
HyperSphere(center::SVector{N,T1}, r::T2) where {N, T1, T2} = HyperSphere(center, convert(T1, r))
99

10+
Base.:(==)(A::HyperSphere, B::HyperSphere) = A.center == B.center && A.r == B.r
11+
1012
@inline function intersects(m::M,
1113
s1::HyperSphere{N,T},
1214
s2::HyperSphere{N,T}) where {T <: AbstractFloat, N, M <: Metric}
@@ -19,55 +21,22 @@ end
1921
evaluate(m, s1.center, s2.center) + s1.r <= s2.r
2022
end
2123

22-
@inline function interpolate(::M,
23-
c1::V,
24-
c2::V,
25-
x,
26-
d,
27-
ab) where {V <: AbstractVector, M <: NormMetric}
28-
alpha = x / d
29-
@assert length(c1) == length(c2)
30-
@inbounds for i in eachindex(ab.center)
31-
ab.center[i] = (1 - alpha) .* c1[i] + alpha .* c2[i]
32-
end
33-
return ab.center, true
34-
end
35-
36-
@inline function interpolate(::M,
37-
c1::V,
38-
::V,
39-
::Any,
40-
::Any,
41-
::Any) where {V <: AbstractVector, M <: Metric}
42-
return c1, false
43-
end
44-
45-
function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high, ab) where {V}
46-
n_dim = size(data, 1)
47-
n_points = high - low + 1
48-
# First find center of all points
49-
fill!(ab.center, 0.0)
50-
for i in low:high
51-
for j in 1:length(ab.center)
52-
ab.center[j] += data[indices[i]][j]
53-
end
54-
end
55-
ab.center .*= 1 / n_points
56-
24+
# Versions without an array buffer; these still do not allocate in sequential BallTree construction
25+
using Statistics: mean
26+
function create_bsphere(data::AbstractVector{V}, metric::Metric, indices::Vector{Int}, low, high) where {V}
27+
# find center
28+
center = mean(@views(data[indices[low:high]]))
5729
# Then find r
5830
r = zero(get_T(eltype(V)))
5931
for i in low:high
60-
r = max(r, evaluate(metric, data[indices[i]], ab.center))
32+
r = max(r, evaluate(metric, data[indices[i]], center))
6133
end
6234
r += eps(get_T(eltype(V)))
63-
return HyperSphere(SVector{length(V),eltype(V)}(ab.center), r)
35+
return HyperSphere(SVector{length(V),eltype(V)}(center), r)
6436
end
6537

6638
# Creates a bounding sphere from two other spheres
67-
function create_bsphere(m::Metric,
68-
s1::HyperSphere{N,T},
69-
s2::HyperSphere{N,T},
70-
ab) where {N, T <: AbstractFloat}
39+
function create_bsphere(m::Metric, s1::HyperSphere{N,T}, s2::HyperSphere{N,T}) where {N, T <: AbstractFloat}
7140
if encloses(m, s1, s2)
7241
return HyperSphere(s2.center, s2.r)
7342
elseif encloses(m, s2, s1)
@@ -79,7 +48,7 @@ function create_bsphere(m::Metric,
7948
# neither s1 nor s2 contains the other)
8049
dist = evaluate(m, s1.center, s2.center)
8150
x = 0.5 * (s2.r - s1.r + dist)
82-
center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist, ab)
51+
center, is_exact_center = interpolate(m, s1.center, s2.center, x, dist)
8352
if is_exact_center
8453
rad = 0.5 * (s2.r + s1.r + dist)
8554
else
@@ -88,3 +57,14 @@ function create_bsphere(m::Metric,
8857

8958
return HyperSphere(SVector{N,T}(center), rad)
9059
end
60+
61+
@inline function interpolate(::M, c1::V, c2::V, x, d) where {V <: AbstractVector, M <: NormMetric}
62+
length(c1) == length(c2) || throw(DimensionMismatch("interpolate arguments have length $(length(c1)) and $(length(c2))"))
63+
alpha = x / d
64+
center = (1 - alpha) * c1 + alpha * c2
65+
return center, true
66+
end
67+
68+
@inline function interpolate(::M, c1::V, ::V, ::Any, ::Any) where {V <: AbstractVector, M <: Metric}
69+
return c1, false
70+
end

test/runtests.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@ using LinearAlgebra
66

77
using Distances: Distances, Metric, evaluate, PeriodicEuclidean
88
struct CustomMetric1 <: Metric end
9-
Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs.(a .- b))
9+
Distances.evaluate(::CustomMetric1, a::AbstractVector, b::AbstractVector) = maximum(abs, (a .- b))
1010
function NearestNeighbors.interpolate(::CustomMetric1,
1111
a::V,
1212
b::V,
1313
x,
14-
d,
15-
ab) where {V <: AbstractVector}
14+
d) where {V <: AbstractVector}
1615
idx = (abs.(b .- a) .>= d - x)
1716
c = copy(Array(a))
1817
c[idx] = (1 - x / d) * a[idx] + (x / d) * b[idx]

test/test_monkey.jl

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import NearestNeighbors.MinkowskiMetric
22
# This contains a bunch of random tests that should hopefully detect if
33
# some edge case has been missed in the real tests
4-
5-
64
@testset "metric $metric" for metric in fullmetrics
5+
nrep = 30
76
@testset "tree type $TreeType" for TreeType in trees_with_brute
87
@testset "element type $T" for T in (Float32, Float64)
98
@testset "knn monkey" begin
@@ -14,7 +13,7 @@ import NearestNeighbors.MinkowskiMetric
1413
elseif TreeType == BallTree && isa(metric, Hamming)
1514
continue
1615
end
17-
for i in 1:30
16+
for i in 1:nrep
1817
dim_data = rand(1:4)
1918
size_data = rand(1000:1300)
2019
data = rand(T, dim_data, size_data)
@@ -28,7 +27,7 @@ import NearestNeighbors.MinkowskiMetric
2827
end
2928

3029
# Compares vs Brute Force
31-
for i in 1:30
30+
for i in 1:nrep
3231
dim_data = rand(1:5)
3332
size_data = rand(100:151)
3433
data = rand(T, dim_data, size_data)
@@ -45,7 +44,7 @@ import NearestNeighbors.MinkowskiMetric
4544

4645
@testset "inrange monkey" begin
4746
# Test against brute force
48-
for i in 1:30
47+
for i in 1:nrep
4948
dim_data = rand(1:6)
5049
size_data = rand(20:250)
5150
data = rand(T, dim_data, size_data)
@@ -62,17 +61,30 @@ import NearestNeighbors.MinkowskiMetric
6261
end
6362

6463
@testset "coupled monkey" begin
65-
for i in 1:50
64+
for i in 1:nrep
6665
dim_data = rand(1:5)
6766
size_data = rand(100:1000)
6867
data = randn(T, dim_data, size_data)
69-
tree = TreeType(data, metric; leafsize = rand(1:8))
68+
69+
lf = rand(1:8)
70+
tree = TreeType(data, metric; leafsize = lf)
71+
72+
if TreeType == BallTree # this caught a race-condition in an early version of the parallel BallTree code
73+
tree2 = TreeType(data, metric; leafsize = lf, parallel = false)
74+
@test tree.data == tree2.data
75+
@test tree.hyper_spheres[1] == tree2.hyper_spheres[1]
76+
@test tree.indices == tree2.indices
77+
@test tree.metric == tree2.metric
78+
@test tree.tree_data == tree2.tree_data
79+
@test tree.reordered == tree2.reordered
80+
end
81+
7082
point = randn(dim_data)
7183
idxs_ball = Int[]
7284
r = 0.1
7385
while length(idxs_ball) < 10
7486
r *= 2.0
75-
idxs_ball = inrange(tree, point, r, true)
87+
idxs_ball = inrange(tree, point, r, true)
7688
end
7789
idxs_knn, dists = knn(tree, point, length(idxs_ball))
7890

0 commit comments

Comments
 (0)