Skip to content

Commit 59ceff3

Browse files
authored
Fix pretty printing and ReverseDiff constructor (#67)
1 parent 16f421d commit 59ceff3

File tree

6 files changed

+80
-62
lines changed

6 files changed

+80
-62
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.5.1"
6+
version = "1.5.2"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

Lines changed: 29 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ end
2020
mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension
2121

2222
function Base.show(io::IO, backend::AutoChainRules)
23-
print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))")
23+
print(io, AutoChainRules, "(ruleconfig=", repr(backend.ruleconfig; context = io), ")")
2424
end
2525

2626
"""
@@ -63,11 +63,9 @@ end
6363
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
6464

6565
function Base.show(io::IO, backend::AutoEnzyme)
66-
if isnothing(backend.mode)
67-
print(io, "AutoEnzyme()")
68-
else
69-
print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))")
70-
end
66+
print(io, AutoEnzyme, "(")
67+
!isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io))
68+
print(io, ")")
7169
end
7270

7371
"""
@@ -111,21 +109,14 @@ end
111109
mode(::AutoFiniteDiff) = ForwardMode()
112110

113111
function Base.show(io::IO, backend::AutoFiniteDiff)
114-
s = "AutoFiniteDiff("
115-
if backend.fdtype != Val(:forward)
116-
s *= "fdtype=$(repr(backend.fdtype, context=io)), "
117-
end
118-
if backend.fdjtype != backend.fdtype
119-
s *= "fdjtype=$(repr(backend.fdjtype, context=io)), "
120-
end
121-
if backend.fdhtype != Val(:hcentral)
122-
s *= "fdhtype=$(repr(backend.fdhtype, context=io)), "
123-
end
124-
if endswith(s, ", ")
125-
s = s[1:(end - 2)]
126-
end
127-
s *= ")"
128-
print(io, s)
112+
print(io, AutoFiniteDiff, "(")
113+
backend.fdtype != Val(:forward) &&
114+
print(io, "fdtype=", repr(backend.fdtype; context = io), ", ")
115+
backend.fdjtype != backend.fdtype &&
116+
print(io, "fdjtype=", repr(backend.fdjtype; context = io), ", ")
117+
backend.fdhtype != Val(:hcentral) &&
118+
print(io, "fdhtype=", repr(backend.fdhtype; context = io))
119+
print(io, ")")
129120
end
130121

131122
"""
@@ -150,7 +141,7 @@ end
150141
mode(::AutoFiniteDifferences) = ForwardMode()
151142

152143
function Base.show(io::IO, backend::AutoFiniteDifferences)
153-
print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))")
144+
print(io, AutoFiniteDifferences, "(fdm=", repr(backend.fdm; context = io), ")")
154145
end
155146

156147
"""
@@ -183,18 +174,11 @@ end
183174
mode(::AutoForwardDiff) = ForwardMode()
184175

185176
function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize}
186-
s = "AutoForwardDiff("
187-
if chunksize !== nothing
188-
s *= "chunksize=$chunksize, "
189-
end
190-
if backend.tag !== nothing
191-
s *= "tag=$(repr(backend.tag, context=io)), "
192-
end
193-
if endswith(s, ", ")
194-
s = s[1:(end - 2)]
195-
end
196-
s *= ")"
197-
print(io, s)
177+
print(io, AutoForwardDiff, "(")
178+
chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io),
179+
(backend.tag !== nothing ? ", " : ""))
180+
backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io))
181+
print(io, ")")
198182
end
199183

200184
"""
@@ -227,18 +211,11 @@ end
227211
mode(::AutoPolyesterForwardDiff) = ForwardMode()
228212

229213
function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize}
230-
s = "AutoPolyesterForwardDiff("
231-
if chunksize !== nothing
232-
s *= "chunksize=$chunksize, "
233-
end
234-
if backend.tag !== nothing
235-
s *= "tag=$(repr(backend.tag, context=io)), "
236-
end
237-
if endswith(s, ", ")
238-
s = s[1:(end - 2)]
239-
end
240-
s *= ")"
241-
print(io, s)
214+
print(io, AutoPolyesterForwardDiff, "(")
215+
chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io),
216+
(backend.tag !== nothing ? ", " : ""))
217+
backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io))
218+
print(io, ")")
242219
end
243220

244221
"""
@@ -277,11 +254,9 @@ end
277254
mode(::AutoReverseDiff) = ReverseMode()
278255

279256
function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile}
280-
if !compile
281-
print(io, "AutoReverseDiff()")
282-
else
283-
print(io, "AutoReverseDiff(compile=true)")
284-
end
257+
print(io, AutoReverseDiff, "(")
258+
compile && print(io, "compile=true")
259+
print(io, ")")
285260
end
286261

287262
"""
@@ -321,11 +296,9 @@ end
321296
mode(::AutoTapir) = ReverseMode()
322297

323298
function Base.show(io::IO, backend::AutoTapir)
324-
if backend.safe_mode
325-
print(io, "AutoTapir()")
326-
else
327-
print(io, "AutoTapir(safe_mode=false)")
328-
end
299+
print(io, AutoTapir, "(")
300+
!(backend.safe_mode) && print(io, "safe_mode=false")
301+
print(io, ")")
329302
end
330303

331304
"""

src/legacy.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
@deprecate AutoSparseZygote() AutoSparse(AutoZygote())
1313

14+
@deprecate AutoReverseDiff(compile) AutoReverseDiff(; compile)
15+
1416
function mtk_to_symbolics(obj_sparse::Bool, cons_sparse::Bool)
1517
if obj_sparse || cons_sparse
1618
return AutoSparse(AutoSymbolics())

src/sparse.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,15 +155,15 @@ function AutoSparse(
155155
end
156156

157157
function Base.show(io::IO, backend::AutoSparse)
158-
s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), "
158+
print(io, AutoSparse, "(dense_ad=", repr(backend.dense_ad, context = io))
159159
if backend.sparsity_detector != NoSparsityDetector()
160-
s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), "
160+
print(io, ", sparsity_detector=", repr(backend.sparsity_detector, context = io))
161161
end
162162
if backend.coloring_algorithm != NoColoringAlgorithm()
163-
s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), "
163+
print(
164+
io, ", coloring_algorithm=", repr(backend.coloring_algorithm, context = io))
164165
end
165-
s = s[1:(end - 2)] * ")"
166-
print(io, s)
166+
print(io, ")")
167167
end
168168

169169
"""

test/legacy.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,8 @@ end
5858
@test ad isa AbstractADType
5959
@test dense_ad(ad) isa AutoZygote
6060
end
61+
62+
@testset "AutoReverseDiff without kwarg" begin
63+
ad = @test_deprecated AutoReverseDiff(true)
64+
@test ad.compile
65+
end

test/misc.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ end
77
@testset "Printing" begin
88
for ad in every_ad_with_options()
99
@test startswith(string(ad), "Auto")
10+
@test contains(string(ad), "(")
1011
@test endswith(string(ad), ")")
1112
end
1213

@@ -19,3 +20,40 @@ end
1920
@test contains(string(sparse_backend1), string(AutoForwardDiff()))
2021
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
2122
end
23+
24+
import ADTypes
25+
26+
struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end
27+
struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end
28+
29+
for backend in [
30+
# dense
31+
ADTypes.AutoChainRules(; ruleconfig = :rc),
32+
ADTypes.AutoDiffractor(),
33+
ADTypes.AutoEnzyme(),
34+
ADTypes.AutoEnzyme(mode = :forward),
35+
ADTypes.AutoFastDifferentiation(),
36+
ADTypes.AutoFiniteDiff(),
37+
ADTypes.AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh),
38+
ADTypes.AutoFiniteDifferences(; fdm = :fdm),
39+
ADTypes.AutoForwardDiff(),
40+
ADTypes.AutoForwardDiff(chunksize = 3, tag = :tag),
41+
ADTypes.AutoPolyesterForwardDiff(),
42+
ADTypes.AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
43+
ADTypes.AutoReverseDiff(),
44+
ADTypes.AutoReverseDiff(compile = true),
45+
ADTypes.AutoSymbolics(),
46+
ADTypes.AutoTapir(),
47+
ADTypes.AutoTapir(safe_mode = false),
48+
ADTypes.AutoTracker(),
49+
ADTypes.AutoZygote(),
50+
# sparse
51+
ADTypes.AutoSparse(ADTypes.AutoForwardDiff()),
52+
ADTypes.AutoSparse(
53+
ADTypes.AutoForwardDiff();
54+
sparsity_detector = FakeSparsityDetector(),
55+
coloring_algorithm = FakeColoringAlgorithm()
56+
)
57+
]
58+
println(backend)
59+
end

0 commit comments

Comments
 (0)