Skip to content

Commit f079b3e

Browse files
committed
fix A/x and A*x differentiation
1 parent 03f3565 commit f079b3e

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

src/dual.jl

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -785,7 +785,7 @@ function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::Abst
785785
# y allows res to be accessed as Array{T}
786786
yarr = reinterpret(reshape, T, y)
787787
@assert size(yarr) == (N + 1, size(y)...)
788-
ystride = size(y, 1)
788+
ystride = size(yarr, 1)
789789

790790
# calculate res values
791791
@inbounds for (j, v) in enumerate(x)
@@ -814,18 +814,21 @@ function _map_dual_components!(fvalue!, fpartial!, y::AbstractArray{DT}, x::Abst
814814
return y
815815
end
816816

817+
function Base.:\(m::Union{LowerTriangular{<:LinearAlgebra.BlasFloat},
818+
UpperTriangular{<:LinearAlgebra.BlasFloat}},
819+
x::StridedVector{<:Dual})
820+
T = valtype(eltype(x))
821+
res = copy(x)
822+
ldiv!(m, reinterpret(reshape, T, res)')
823+
return res
824+
end
825+
817826
for MT in (StridedMatrix{<:LinearAlgebra.BlasFloat},
818827
LowerTriangular{<:LinearAlgebra.BlasFloat},
819828
UpperTriangular{<:LinearAlgebra.BlasFloat})
820829

821-
@eval function Base.:\(m::$MT, x::StridedVector{<:Dual})
822-
T = valtype(eltype(x))
823-
ldiv!(m', reinterpret(reshape, T, res))
824-
return res
825-
end
826-
827830
@eval Base.:\(m::$MT, x::StridedMatrix{<:Dual}) =
828-
_map_dual_components!((x, _) -> ldiv!(m, x), (x, _, _) -> ldiv!(m, x), similar(x), x)
831+
_map_dual_components!((y, x) -> ldiv!(y, m, x), (y, x, _) -> ldiv!(y, m, x), similar(x), x)
829832

830833
@eval function Base.:*(m::$MT, x::StridedVector{<:Dual})
831834
T = valtype(eltype(x))

0 commit comments

Comments
 (0)