Skip to content

Commit 807b0ec

Browse files
authored
allow printing using at-printf (#511)
* allow printing using at-printf I implemented specializations of `Printf.tofloat` to allow printing (the `value` of) dual numbers. This is useful when differentiating functions that print some information during the computation, cf. https://discourse.julialang.org/t/forwarddiff-jl-and-printf/57694.
1 parent 6caf1cd commit 807b0ec

File tree

4 files changed

+17
-0
lines changed

4 files changed

+17
-0
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
88
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
99
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
11+
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1112
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1213
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1314
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

src/ForwardDiff.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ using StaticArrays
66
using Random
77
using LinearAlgebra
88

9+
import Printf
910
import NaNMath
1011
import SpecialFunctions
1112
import CommonSubexpressions

src/dual.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,3 +690,7 @@ end
690690
function Base.typemax(::Type{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
691691
ForwardDiff.Dual{T,V,N}(typemax(V))
692692
end
693+
694+
if VERSION >= v"1.6.0-rc1"
695+
Printf.tofloat(d::Dual) = Printf.tofloat(value(d))
696+
end

test/DualTest.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module DualTest
22

33
using Test
4+
using Printf
45
using Random
56
using ForwardDiff
67
using ForwardDiff: Partials, Dual, value, partials
@@ -504,4 +505,14 @@ end
504505
@test !isfinite(dinf)
505506
end
506507

508+
if VERSION >= v"1.6.0-rc1"
509+
@testset "@printf" begin
510+
for T in (Float16, Float32, Float64, BigFloat)
511+
d1 = Dual(one(T))
512+
@test_nowarn @printf("Testing @printf: %.2e\n", d1)
513+
@test @sprintf("Testing @sprintf: %.2e\n", d1) == "Testing @sprintf: 1.00e+00\n"
514+
end
515+
end
516+
end
517+
507518
end # module

0 commit comments

Comments
 (0)