Skip to content

Commit 102ee4d

Browse files
authored
fix float(::Dual) and add float(::Type{<:Dual}) (#535)
1 parent 54fe5d0 commit 102ee4d

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

src/dual.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,8 @@ Base.convert(::Type{Dual{T,V,N}}, x) where {T,V,N} = Dual{T}(convert(V, x), zero
380380
Base.convert(::Type{Dual{T,V,N}}, x::Number) where {T,V,N} = Dual{T}(convert(V, x), zero(Partials{N,V}))
381381
Base.convert(::Type{D}, d::D) where {D<:Dual} = d
382382

383-
Base.float(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d)
384-
Base.AbstractFloat(d::Dual{T,V,N}) where {T,V,N} = convert(Dual{T,promote_type(V, Float16),N}, d)
383+
Base.float(::Type{Dual{T,V,N}}) where {T,V,N} = Dual{T,float(V),N}
384+
Base.float(d::Dual) = convert(float(typeof(d)), d)
385385

386386
###################################
387387
# General Mathematical Operations #

test/DualTest.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ end
517517
@test length(UnitRange(Dual(1.5), Dual(3.5))) == 3
518518
@test length(UnitRange(Dual(1.5,1), Dual(3.5,3))) == 3
519519
end
520-
520+
521521
if VERSION >= v"1.6.0-rc1"
522522
@testset "@printf" begin
523523
for T in (Float16, Float32, Float64, BigFloat)
@@ -528,4 +528,11 @@ if VERSION >= v"1.6.0-rc1"
528528
end
529529
end
530530

531+
@testset "float" begin # issue #492
532+
@test float(Dual{Nothing, Int, 2}) === Dual{Nothing, Float64, 2}
533+
@test float(Dual(1)) isa Dual{Nothing, Float64, 0}
534+
@test value.(float.(Dual.(1:4, 2:5, 3:6))) isa Vector{Float64}
535+
@test ForwardDiff.derivative(float, 1)::Float64 === 1.0
536+
end
537+
531538
end # module

0 commit comments

Comments
 (0)