@@ -779,6 +779,88 @@ function SpecialFunctions.gamma_inc(a::Real, d::Dual{T,<:Real}, ind::Integer) wh
779
779
return (Dual {T} (p, ∂p), Dual {T} (q, - ∂p))
780
780
end
781
781
782
+ # Efficient left multiplication/division of #
783
+ # Dual array by a constant matrix #
784
+ # -------------------------------------------#
785
+ # creates the copy of x and applies fvalue!(values(y), values(x)) to its values,
786
+ # and fpartial!(partial(y, i), partial(y, i), i) to its partials
787
+ function _map_dual_components! (fvalue!, fpartial!, y:: AbstractArray{DT} , x:: AbstractArray{DT} ) where DT <: Dual{<:Any, T} where T
788
+ N = npartials (DT)
789
+ tx = similar (x, T)
790
+ ty = similar (y, T) # temporary Array{T} for fvalue!/fpartial! application
791
+ # y allows res to be accessed as Array{T}
792
+ yarr = reinterpret (reshape, T, y)
793
+ @assert size (yarr) == (N + 1 , size (y)... )
794
+ ystride = size (yarr, 1 )
795
+
796
+ # calculate res values
797
+ @inbounds for (j, v) in enumerate (x)
798
+ tx[j] = value (v)
799
+ end
800
+ fvalue! (ty, tx)
801
+ k = 1
802
+ @inbounds for tt in ty
803
+ yarr[k] = tt
804
+ k += ystride
805
+ end
806
+
807
+ # calculate each res partial
808
+ for i in 1 : N
809
+ @inbounds for (j, v) in enumerate (x)
810
+ tx[j] = partials (v, i)
811
+ end
812
+ fpartial! (ty, tx, i)
813
+ k = i + 1
814
+ @inbounds for tt in ty
815
+ yarr[k] = tt
816
+ k += ystride
817
+ end
818
+ end
819
+
820
+ return y
821
+ end
822
+
823
+ # use ldiv!() for matrices of normal numbers to
824
+ # implement ldiv!() of dual vector by a matrix
825
+ LinearAlgebra. ldiv! (y:: StridedVector{T} ,
826
+ m:: Union {LowerTriangular{<: LinearAlgebra.BlasFloat },
827
+ UpperTriangular{<: LinearAlgebra.BlasFloat }},
828
+ x:: StridedVector{T} ) where T <: Dual =
829
+ (ldiv! (reinterpret (reshape, valtype (T), y)' , m, reinterpret (reshape, valtype (T), x)' ); y)
830
+
831
+ Base.:\ (m:: Union {LowerTriangular{<: LinearAlgebra.BlasFloat },
832
+ UpperTriangular{<: LinearAlgebra.BlasFloat }},
833
+ x:: StridedVector{<:Dual} ) = ldiv! (similar (x), m, x)
834
+
835
+ for MT in (StridedMatrix{<: LinearAlgebra.BlasFloat },
836
+ LowerTriangular{<: LinearAlgebra.BlasFloat },
837
+ UpperTriangular{<: LinearAlgebra.BlasFloat },
838
+ )
839
+ @eval begin
840
+
841
+ LinearAlgebra. ldiv! (y:: StridedMatrix{T} , m:: $MT , x:: StridedMatrix{T} ) where T <: Dual =
842
+ _map_dual_components! ((y, x) -> ldiv! (y, m, x), (y, x, _) -> ldiv! (y, m, x), y, x)
843
+
844
+ Base.:\ (m:: $MT , x:: StridedMatrix{<:Dual} ) = ldiv! (similar (x), m, x)
845
+
846
+ LinearAlgebra. mul! (y:: StridedVector{T} , m:: $MT , x:: StridedVector{T} ) where T <: Dual =
847
+ (mul! (reinterpret (reshape, valtype (T), y), reinterpret (reshape, valtype (T), x), m' ); y)
848
+
849
+ LinearAlgebra. mul! (y:: StridedVector{T} , m:: $MT , x:: StridedVector{T} ,
850
+ α:: Union{LinearAlgebra.BlasFloat, Integer} ,
851
+ β:: Union{LinearAlgebra.BlasFloat, Integer} ) where T <: Dual =
852
+ (mul! (reinterpret (reshape, valtype (T), y), reinterpret (reshape, valtype (T), x), m' , α, β); y)
853
+
854
+ Base.:* (m:: $MT , x:: StridedVector{<:Dual} ) = mul! (similar (x, (size (m, 1 ),)), m, x)
855
+
856
+ LinearAlgebra. mul! (y:: StridedMatrix{T} , m:: $MT , x:: StridedMatrix{T} ) where T <: Dual =
857
+ _map_dual_components! ((y, x) -> mul! (y, m, x), (y, x, _) -> mul! (y, m, x), y, x)
858
+
859
+ Base.:* (m:: $MT , x:: StridedMatrix{<:Dual} ) = mul! (similar (x, (size (m, 1 ), size (x, 2 ))), m, x)
860
+
861
+ end
862
+ end
863
+
782
864
# ##################
783
865
# Pretty Printing #
784
866
# ##################
0 commit comments