Skip to content

Commit ebcdf17

Browse files
committed
faster * and / of dual arr by constant mtx
1 parent 9a91982 commit ebcdf17

File tree

1 file changed

+82
-0
lines changed

1 file changed

+82
-0
lines changed

src/dual.jl

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,88 @@ function SpecialFunctions.gamma_inc(a::Real, d::Dual{T,<:Real}, ind::Integer) wh
779779
return (Dual{T}(p, ∂p), Dual{T}(q, -∂p))
780780
end
781781

782+
# Efficient left multiplication/division of #
783+
# Dual array by a constant matrix #
784+
#-------------------------------------------#
785+
# creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
786+
# and fpartial!(partial(y, i), partial(y, i), i) to its partials
787+
function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::AbstractArray{DT}) where DT <: Dual{<:Any, T} where T
788+
N = npartials(DT)
789+
tx = similar(x, T)
790+
ty = similar(y, T) # temporary Array{T} for fvalue!/fpartial! application
791+
# y allows res to be accessed as Array{T}
792+
yarr = reinterpret(reshape, T, y)
793+
@assert size(yarr) == (N + 1, size(y)...)
794+
ystride = size(yarr, 1)
795+
796+
# calculate res values
797+
@inbounds for (j, v) in enumerate(x)
798+
tx[j] = value(v)
799+
end
800+
fvalue!(ty, tx)
801+
k = 1
802+
@inbounds for tt in ty
803+
yarr[k] = tt
804+
k += ystride
805+
end
806+
807+
# calculate each res partial
808+
for i in 1:N
809+
@inbounds for (j, v) in enumerate(x)
810+
tx[j] = partials(v, i)
811+
end
812+
fpartial!(ty, tx, i)
813+
k = i + 1
814+
@inbounds for tt in ty
815+
yarr[k] = tt
816+
k += ystride
817+
end
818+
end
819+
820+
return y
821+
end
822+
823+
# use ldiv!() for matrices of normal numbers to
824+
# implement ldiv!() of dual vector by a matrix
825+
LinearAlgebra.ldiv!(y::StridedVector{T},
826+
m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
827+
UpperTriangular{<:LinearAlgebra.BlasFloat}},
828+
x::StridedVector{T}) where T <: Dual =
829+
(ldiv!(reinterpret(reshape, valtype(T), y)', m, reinterpret(reshape, valtype(T), x)'); y)
830+
831+
Base.:\(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
832+
UpperTriangular{<:LinearAlgebra.BlasFloat}},
833+
x::StridedVector{<:Dual}) = ldiv!(similar(x), m, x)
834+
835+
for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
836+
LowerTriangular{<:LinearAlgebra.BlasFloat},
837+
UpperTriangular{<:LinearAlgebra.BlasFloat},
838+
)
839+
@eval begin
840+
841+
LinearAlgebra.ldiv!(y::StridedMatrix{T}, m::$MT, x::StridedMatrix{T}) where T <: Dual =
842+
_map_dual_components!((y, x) -> ldiv!(y, m, x), (y, x, _) -> ldiv!(y, m, x), y, x)
843+
844+
Base.:\(m::$MT, x::StridedMatrix{<:Dual}) = ldiv!(similar(x), m, x)
845+
846+
LinearAlgebra.mul!(y::StridedVector{T}, m::$MT, x::StridedVector{T}) where T <: Dual =
847+
(mul!(reinterpret(reshape, valtype(T), y), reinterpret(reshape, valtype(T), x), m'); y)
848+
849+
LinearAlgebra.mul!(y::StridedVector{T}, m::$MT, x::StridedVector{T},
850+
α::Union{LinearAlgebra.BlasFloat, Integer},
851+
β::Union{LinearAlgebra.BlasFloat, Integer}) where T <: Dual =
852+
(mul!(reinterpret(reshape, valtype(T), y), reinterpret(reshape, valtype(T), x), m', α, β); y)
853+
854+
Base.:*(m::$MT, x::StridedVector{<:Dual}) = mul!(similar(x, (size(m, 1),)), m, x)
855+
856+
LinearAlgebra.mul!(y::StridedMatrix{T}, m::$MT, x::StridedMatrix{T}) where T <: Dual =
857+
_map_dual_components!((y, x) -> mul!(y, m, x), (y, x, _) -> mul!(y, m, x), y, x)
858+
859+
Base.:*(m::$MT, x::StridedMatrix{<:Dual}) = mul!(similar(x, (size(m, 1), size(x, 2))), m, x)
860+
861+
end
862+
end
863+
782864
###################
783865
# Pretty Printing #
784866
###################

0 commit comments

Comments
 (0)