Skip to content

Commit 49bdd7a

Browse files
Merge pull request #549 from JuliaDiff/kc/arit
define ArithmeticStyle for Dual
2 parents bb85ea2 + ff251a0 commit 49bdd7a

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

src/dual.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ struct Dual{T,V,N} <: Real
2020
end
2121
end
2222

23+
##########
24+
# Traits #
25+
##########
26+
Base.ArithmeticStyle(::Type{<:Dual{T,V}}) where {T,V} = Base.ArithmeticStyle(V)
27+
2328
##############
2429
# Exceptions #
2530
##############

test/GradientTest.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,4 +162,11 @@ end
162162
@test_throws DimensionMismatch ForwardDiff.gradient(identity, fill(2pi, 10^6)) # chunk_mode_gradient
163163
end
164164

165+
@testset "ArithmeticStyle" begin
166+
function f(p)
167+
sum(collect(0.0:p[1]:p[2]))
168+
end
169+
@test ForwardDiff.gradient(f, [0.2,25.0]) == [7875.0, 0.0]
170+
end
171+
165172
end # module

0 commit comments

Comments
 (0)