Skip to content

Commit 72f806d

Browse files
authored
Add Symbol -> AbstractADType mapping (#62)
1 parent 602ce68 commit 72f806d

File tree

6 files changed

+60
-1
lines changed

6 files changed

+60
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.3.0"
6+
version = "1.4.0"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

docs/src/index.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,9 @@ ADTypes.ForwardOrReverseMode
9292
ADTypes.ReverseMode
9393
ADTypes.SymbolicMode
9494
```
95+
96+
## Miscellaneous
97+
98+
```@docs
99+
ADTypes.Auto
100+
```

src/ADTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ include("mode.jl")
2020
include("dense.jl")
2121
include("sparse.jl")
2222
include("legacy.jl")
23+
include("symbols.jl")
2324

2425
if !isdefined(Base, :get_extension)
2526
include("../ext/ADTypesChainRulesCoreExt.jl")

src/symbols.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
ADTypes.Auto(package::Symbol)
3+
4+
A shortcut that converts an AD package name into an instance of [`AbstractADType`](@ref), with all parameters set to their default values.
5+
6+
!!! warning
7+
8+
This function is type-unstable by design and might lead to suboptimal performance.
9+
In most cases, you should never need it: use the individual backend types directly.
10+
11+
# Example
12+
13+
```jldoctest
14+
import ADTypes
15+
backend = ADTypes.Auto(:Zygote)
16+
17+
# output
18+
19+
ADTypes.AutoZygote()
20+
```
21+
"""
22+
Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...)
23+
24+
for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation,
25+
:FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff,
26+
:ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote)
27+
@eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...)
28+
end
29+

test/runtests.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,9 @@ end
6666
@testset "Sparse" begin
6767
include("sparse.jl")
6868
end
69+
@testset "Symbols" begin
70+
include("symbols.jl")
71+
end
6972
@testset "Legacy" begin
7073
include("legacy.jl")
7174
end

test/symbols.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using ADTypes
2+
using Test
3+
4+
@test ADTypes.Auto(:ChainRules, 1) isa AutoChainRules{Int64}
5+
@test ADTypes.Auto(:Diffractor) isa AutoDiffractor
6+
@test ADTypes.Auto(:Enzyme) isa AutoEnzyme
7+
@test ADTypes.Auto(:FastDifferentiation) isa AutoFastDifferentiation
8+
@test ADTypes.Auto(:FiniteDiff) isa AutoFiniteDiff
9+
@test ADTypes.Auto(:FiniteDifferences, 1.0) isa AutoFiniteDifferences{Float64}
10+
@test ADTypes.Auto(:ForwardDiff) isa AutoForwardDiff
11+
@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff
12+
@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff
13+
@test ADTypes.Auto(:Symbolics) isa AutoSymbolics
14+
@test ADTypes.Auto(:Tapir) isa AutoTapir
15+
@test ADTypes.Auto(:Tracker) isa AutoTracker
16+
@test ADTypes.Auto(:Zygote) isa AutoZygote
17+
18+
@test_throws MethodError ADTypes.Auto(:ThisPackageDoesNotExist)
19+
@test_throws UndefKeywordError ADTypes.Auto(:ChainRules)
20+
@test_throws UndefKeywordError ADTypes.Auto(:FiniteDifferences)

0 commit comments

Comments
 (0)