Skip to content

Commit 4eb557c

Browse files
authored
Propagate inbounds in dnPl (#6)
* remove OffsetArray wrapper in collectPl!
* propagate inbounds if possible
1 parent ccef9c0 commit 4eb557c

File tree

2 files changed: 42 additions (+42) and 41 deletions (-41)

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
name = "LegendrePolynomials"
22
uuid = "3db4a2ba-fc88-11e8-3e01-49c72059a882"
3-
version = "0.3.2"
3+
version = "0.3.3"
44

55
[deps]
66
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"

src/LegendrePolynomials.jl

Lines changed: 41 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -16,9 +16,6 @@ end
1616

1717
assertnonnegative(l) = (l >= 0 || throw(ArgumentError("l must be >= 0, received " * string(l))))
1818

19-
function checksize(arr, lmax)
20-
maximum(axes(arr,1)) >= lmax || throw(ArgumentError("array is not large enough to store all values"))
21-
end
2219
function checklength(arr, minlength)
2320
length(arr) >= minlength || throw(ArgumentError(
2421
"array is not large enough to store all values, require a minimum length of " * string(minlength)))
@@ -168,6 +165,39 @@ function _checkvalues(x, l, n)
168165
n >= 0 || throw(ArgumentError("order of derivative n must be >= 0"))
169166
end
170167

168+
Base.@propagate_inbounds function _unsafednPl!(cache, x, l, n)
169+
# unsafe, assumes 1-based indexing
170+
checklength(cache, l - n + 1)
171+
if n == l # may short-circuit this
172+
cache[1] = doublefactorial(eltype(cache), 2l-1)
173+
else
174+
collectPl!(cache, x, lmax = l - n)
175+
176+
for ni in 1:n
177+
# We denote the terms as P_ni_li
178+
179+
# li == ni
180+
P_nim1_nim1 = cache[1]
181+
P_ni_ni = dPl_recursion(eltype(cache), ni, ni, nothing, P_nim1_nim1, nothing, x)
182+
cache[1] = P_ni_ni
183+
184+
# li == ni + 1
185+
P_nim1_ni = cache[2]
186+
P_ni_nip1 = dPl_recursion(eltype(cache), ni + 1, ni, P_ni_ni, P_nim1_ni, nothing, x)
187+
cache[2] = P_ni_nip1
188+
189+
for li in ni+2:min(l, l - n + ni)
190+
P_ni_lim2 = cache[li - ni - 1]
191+
P_ni_lim1 = cache[li - ni]
192+
P_nim1_lim1 = cache[li - ni + 1]
193+
P_ni_li = dPl_recursion(eltype(cache), li, ni, P_ni_lim1, P_nim1_lim1, P_ni_lim2, x)
194+
cache[li - ni + 1] = P_ni_li
195+
end
196+
end
197+
end
198+
nothing
199+
end
200+
171201
"""
172202
dnPl(x, l::Integer, n::Integer, [cache::AbstractVector])
173203
@@ -188,7 +218,7 @@ julia> dnPl(0.5, 4, 0) == Pl(0.5, 4) # zero-th order derivative == Pl(x)
188218
true
189219
```
190220
"""
191-
function dnPl(x, l::Integer, n::Integer,
221+
Base.@propagate_inbounds function dnPl(x, l::Integer, n::Integer,
192222
A = begin
193223
_checkvalues(x, l, n)
194224
# do not allocate A if the value is trivially zero
@@ -200,42 +230,14 @@ function dnPl(x, l::Integer, n::Integer,
200230
)
201231

202232
_checkvalues(x, l, n)
203-
checklength(A, l - n + 1)
204-
205-
cache = OffsetArrays.no_offset_view(A)
206-
207233
# check if the value is trivially zero in case A is provided in the function call
208234
if l < n
209-
return zero(eltype(cache))
235+
return zero(eltype(A))
210236
end
211-
212-
if n == l # may short-circuit this
213-
cache[1] = doublefactorial(eltype(cache), 2l-1)
214-
else
215-
collectPl!(cache, x, lmax = l - n)
216-
217-
for ni in 1:n
218-
# We denote the terms as P_ni_li
219-
220-
# li == ni
221-
P_nim1_nim1 = cache[1]
222-
P_ni_ni = dPl_recursion(eltype(cache), ni, ni, nothing, P_nim1_nim1, nothing, x)
223-
cache[1] = P_ni_ni
224237

225-
# li == ni + 1
226-
P_nim1_ni = cache[2]
227-
P_ni_nip1 = dPl_recursion(eltype(cache), ni + 1, ni, P_ni_ni, P_nim1_ni, nothing, x)
228-
cache[2] = P_ni_nip1
229-
230-
for li in ni+2:min(l, l - n + ni)
231-
P_ni_lim2 = cache[li - ni - 1]
232-
P_ni_lim1 = cache[li - ni]
233-
P_nim1_lim1 = cache[li - ni + 1]
234-
P_ni_li = dPl_recursion(eltype(cache), li, ni, P_ni_lim1, P_nim1_lim1, P_ni_lim2, x)
235-
cache[li - ni + 1] = P_ni_li
236-
end
237-
end
238-
end
238+
cache = OffsetArrays.no_offset_view(A)
239+
# function barrier, as no_offset_view may be type-unstable
240+
_unsafednPl!(cache, x, l, n)
239241

240242
return cache[l - n + 1]
241243
end
@@ -271,13 +273,12 @@ julia> collectPl!(v, 0.5, lmax = 3) # only l from 0 to 3 are populated
271273
```
272274
"""
273275
function collectPl!(v::AbstractVector, x; lmax::Integer = length(v) - 1)
274-
v_0based = OffsetArray(v, OffsetArrays.Origin(0))
275-
checksize(v_0based, lmax)
276+
checklength(v, lmax + 1)
276277

277278
iter = LegendrePolynomialIterator(x, lmax)
278279

279280
for (l, Pl) in pairs(iter)
280-
v_0based[l] = Pl
281+
v[l + firstindex(v)] = Pl
281282
end
282283

283284
v
@@ -363,7 +364,7 @@ function collectdnPl!(v, x; lmax::Integer, n::Integer)
363364
# trivially zero for l < n
364365
fill!((@view v[(0:n-1) .+ firstindex(v)]), zero(eltype(v)))
365366
# populate the other elements
366-
dnPl(x, lmax, n, @view v[(n:lmax) .+ firstindex(v)])
367+
@inbounds dnPl(x, lmax, n, @view v[(n:lmax) .+ firstindex(v)])
367368

368369
v
369370
end

0 commit comments

Comments
 (0)