Skip to content

Commit ab0e239

Browse files
authored
explicitly SIMD muladd with duals (#562)
1 parent 5bb4546 commit ab0e239

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

src/ForwardDiff.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ import SpecialFunctions
1616
import LogExpFunctions
1717
import CommonSubexpressions
1818

19+
# Type unions used as the `V<:SIMDFloat` / `V<:SIMDType` bounds on the
# `calc_fma_xyz` and `calc_muladd_xyz` methods in dual.jl.
# NOTE(review): the names suggest these are the element types for which the
# explicit SIMD code paths are valid — confirm against the SIMD implementation.
# Floating-point members.
const SIMDFloat = Union{Float64, Float32}
20+
# Signed and unsigned integer members, 8 through 128 bits wide.
const SIMDInt = Union{
21+
Int128, Int64, Int32, Int16, Int8,
22+
UInt128, UInt64, UInt32, UInt16, UInt8,
23+
}
24+
# Every type that belongs to either of the two unions above.
const SIMDType = Union{SIMDFloat, SIMDInt}
25+
1926
include("prelude.jl")
2027
include("partials.jl")
2128
include("dual.jl")

src/dual.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,16 @@ end
541541
# fma #
542542
#-----#
543543

544+
# calc_fma_xyz(x, y, z) for three Duals that share the same tag `T`, the same
# value type `V` (restricted to `SIMDFloat`) and the same number of partials `N`.
# Returns a `Dual{T}` whose value is `fma(value(x), value(y), value(z))` and
# whose partials are `xv*yp + yv*xp + zp`, evaluated with nested `fma` calls.
# This method carries a `V<:SIMDFloat` bound that the `@generated` method
# defined right after it (`Dual{T,<:Any,N}`) does not.
@inline function calc_fma_xyz(x::Dual{T,V,N},
545+
y::Dual{T,V,N},
546+
z::Dual{T,V,N}) where {T, V<:SIMDFloat,N}
547+
# Primal (value) parts of the three inputs.
xv, yv, zv = value(x), value(y), value(z)
548+
# Primal result x*y + z computed through `fma`.
rv = fma(xv, yv, zv)
549+
# With zero partials, return early with a value-only Dual.
N == 0 && return Dual{T}(rv)
550+
# Wrap each partials tuple in `Vec` so the arithmetic below acts on all N
# partials at once.
# NOTE(review): `Vec` is defined outside this view — presumably a SIMD
# vector type; confirm where it comes from.
xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values)
551+
# Product rule for x*y + z: xv*yp + (yv*xp + zp), as two nested fma calls;
# the result is converted back to a plain Tuple.
parts = Tuple(fma(xv, yp, fma(yv, xp, zp)))
552+
Dual{T}(rv, parts)
553+
end
544554
@generated function calc_fma_xyz(x::Dual{T,<:Any,N},
545555
y::Dual{T,<:Any,N},
546556
z::Dual{T,<:Any,N}) where {T,N}
@@ -583,6 +593,16 @@ end
583593
# muladd #
584594
#--------#
585595

596+
# calc_muladd_xyz(x, y, z) for three Duals that share the same tag `T`, the
# same value type `V` (restricted to `SIMDType`, i.e. `SIMDFloat` or `SIMDInt`)
# and the same number of partials `N`.
# Returns a `Dual{T}` whose value is `muladd(value(x), value(y), value(z))` and
# whose partials are `xv*yp + yv*xp + zp`, evaluated with nested `muladd` calls.
# This method carries a `V<:SIMDType` bound that the `@generated` method
# defined right after it (`Dual{T,<:Any,N}`) does not.
@inline function calc_muladd_xyz(x::Dual{T,V,N},
597+
y::Dual{T,V,N},
598+
z::Dual{T,V,N}) where {T, V<:SIMDType,N}
599+
# Primal (value) parts of the three inputs.
xv, yv, zv = value(x), value(y), value(z)
600+
# Primal result x*y + z computed through `muladd`.
rv = muladd(xv, yv, zv)
601+
# With zero partials, return early with a value-only Dual.
N == 0 && return Dual{T}(rv)
602+
# Wrap each partials tuple in `Vec` so the arithmetic below acts on all N
# partials at once.
# NOTE(review): `Vec` is defined outside this view — presumably a SIMD
# vector type; confirm where it comes from.
xp, yp, zp = Vec(partials(x).values), Vec(partials(y).values), Vec(partials(z).values)
603+
# Product rule for x*y + z: xv*yp + (yv*xp + zp), as two nested muladd
# calls; the result is converted back to a plain Tuple.
parts = Tuple(muladd(xv, yp, muladd(yv, xp, zp)))
604+
Dual{T}(rv, parts)
605+
end
586606
@generated function calc_muladd_xyz(x::Dual{T,<:Any,N},
587607
y::Dual{T,<:Any,N},
588608
z::Dual{T,<:Any,N}) where {T,N}

src/partials.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,13 +205,6 @@ end
205205
return tupexpr(i -> :(rand(V)), N)
206206
end
207207

208-
209-
const SIMDFloat = Union{Float64, Float32}
210-
const SIMDInt = Union{
211-
Int128, Int64, Int32, Int16, Int8,
212-
UInt128, UInt64, UInt32, UInt16, UInt8,
213-
}
214-
const SIMDType = Union{SIMDFloat, SIMDInt}
215208
const NT{N,T} = NTuple{N,T}
216209

217210
# SIMD implementation

0 commit comments

Comments
 (0)