Skip to content

Commit c4b9a4c

Browse files
Add AutoGTPSA backend (#78)
* Add AutoGTPSA taylor mode * Fix printing --------- Co-authored-by: Guillaume Dalle <[email protected]>
1 parent a44fd2d commit c4b9a4c

File tree

6 files changed

+58
-0
lines changed

6 files changed

+58
-0
lines changed

docs/src/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ AutoFiniteDiff
3030
AutoFiniteDifferences
3131
```
3232

33+
Taylor mode:
34+
35+
```@docs
36+
AutoGTPSA
37+
```
38+
3339
### Reverse mode
3440

3541
```@docs

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ export AutoChainRules,
4040
AutoFiniteDiff,
4141
AutoFiniteDifferences,
4242
AutoForwardDiff,
43+
AutoGTPSA,
4344
AutoModelingToolkit,
4445
AutoPolyesterForwardDiff,
4546
AutoReverseDiff,

src/dense.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,38 @@ function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize
195195
print(io, ")")
196196
end
197197

198+
"""
199+
AutoGTPSA{D}
200+
201+
Struct used to select the [GTPSA.jl](https://github.com/bmad-sim/GTPSA.jl) backend for automatic differentiation.
202+
203+
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
204+
205+
# Constructors
206+
207+
AutoGTPSA(; descriptor=nothing)
208+
209+
# Fields
210+
211+
- `descriptor::D`: can be either
212+
213+
+ a GTPSA `Descriptor` specifying the number of variables/parameters, parameter
214+
order, individual variable/parameter truncation orders, and maximum order. See
215+
the [GTPSA.jl documentation](https://bmad-sim.github.io/GTPSA.jl/stable/man/c_descriptor/) for more details.
216+
+ `nothing` to automatically use a `Descriptor` given the context.
217+
"""
218+
Base.@kwdef struct AutoGTPSA{D} <: AbstractADType
219+
descriptor::D = nothing
220+
end
221+
222+
mode(::AutoGTPSA) = ForwardMode()
223+
224+
function Base.show(io::IO, backend::AutoGTPSA{D}) where {D}
225+
print(io, AutoGTPSA, "(")
226+
D != Nothing && print(io, "descriptor=", repr(backend.descriptor; context = io))
227+
print(io, ")")
228+
end
229+
198230
"""
199231
AutoPolyesterForwardDiff{chunksize,T}
200232

test/dense.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,20 @@ end
105105
@test ad.tag == CustomTag()
106106
end
107107

108+
@testset "AutoGTPSA" begin
109+
ad = AutoGTPSA(; descriptor = nothing)
110+
@test ad isa AbstractADType
111+
@test ad isa AutoGTPSA{Nothing}
112+
@test mode(ad) isa ForwardMode
113+
@test ad.descriptor === nothing
114+
115+
ad = AutoGTPSA(; descriptor = Val(:descriptor))
116+
@test ad isa AbstractADType
117+
@test ad isa AutoGTPSA{Val{:descriptor}}
118+
@test mode(ad) isa ForwardMode
119+
@test ad.descriptor == Val(:descriptor)
120+
end
121+
108122
@testset "AutoPolyesterForwardDiff" begin
109123
ad = AutoPolyesterForwardDiff()
110124
@test ad isa AbstractADType

test/misc.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ for backend in [
4040
ADTypes.AutoFiniteDifferences(; fdm = :fdm),
4141
ADTypes.AutoForwardDiff(),
4242
ADTypes.AutoForwardDiff(chunksize = 3, tag = :tag),
43+
ADTypes.AutoGTPSA(),
44+
ADTypes.AutoGTPSA(; descriptor = Val(:descriptor)),
4345
ADTypes.AutoPolyesterForwardDiff(),
4446
ADTypes.AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
4547
ADTypes.AutoReverseDiff(),

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ function every_ad()
4545
AutoFiniteDiff(),
4646
AutoFiniteDifferences(; fdm = :fdm),
4747
AutoForwardDiff(),
48+
AutoGTPSA(),
4849
AutoPolyesterForwardDiff(),
4950
AutoReverseDiff(),
5051
AutoSymbolics(),
@@ -66,6 +67,8 @@ function every_ad_with_options()
6667
AutoFiniteDifferences(; fdm = :fdm),
6768
AutoForwardDiff(),
6869
AutoForwardDiff(chunksize = 3, tag = :tag),
70+
AutoGTPSA(),
71+
AutoGTPSA(descriptor = Val(:descriptor)),
6972
AutoPolyesterForwardDiff(),
7073
AutoPolyesterForwardDiff(chunksize = 3, tag = :tag),
7174
AutoReverseDiff(),

0 commit comments

Comments
 (0)