Skip to content

Commit d551bbe

Browse files
authored
define sin/cos on Dual to use sincos (#561)
1 parent ae3cb15 commit d551bbe

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/dual.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ Base.float(d::Dual) = convert(float(typeof(d)), d)
393393
###################################
394394

395395
for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
396-
if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-))
396+
if (M, f) in ((:Base, :^), (:NaNMath, :pow), (:Base, :/), (:Base, :+), (:Base, :-), (:Base, :sin), (:Base, :cos))
397397
continue # Skip methods which we define elsewhere.
398398
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
399399
continue # Skip rules for methods not defined in the current scope
@@ -622,12 +622,19 @@ end
622622
Dual{Tz}(muladd(x, y, value(z)), partials(z)) # z_body
623623
)
624624

625-
# sincos #
625+
# sin/cos #
626626
#--------#
627+
function Base.sin(d::Dual{T}) where T
628+
s, c = sincos(value(d))
629+
return Dual{T}(s, c * partials(d))
630+
end
627631

628-
@inline sincos(x) = (sin(x), cos(x))
632+
function Base.cos(d::Dual{T}) where T
633+
s, c = sincos(value(d))
634+
return Dual{T}(c, -s * partials(d))
635+
end
629636

630-
@inline function sincos(d::Dual{T}) where T
637+
@inline function Base.sincos(d::Dual{T}) where T
631638
sd, cd = sincos(value(d))
632639
return (Dual{T}(sd, cd * partials(d)), Dual{T}(cd, -sd * partials(d)))
633640
end

0 commit comments

Comments
 (0)