Skip to content

Commit 7984564

Browse files
committed
Add constant_function kwarg to AutoEnzyme
1 parent 97d5146 commit 7984564

File tree

4 files changed

+26
-14
lines changed

4 files changed

+26
-14
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.5.4"
6+
version = "1.6.0"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,18 @@ struct AutoDiffractor <: AbstractADType end
3939
mode(::AutoDiffractor) = ForwardOrReverseMode()
4040

4141
"""
42-
AutoEnzyme{M}
42+
AutoEnzyme{M,constant_function}
4343
4444
Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.
4545
4646
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4747
4848
# Constructors
4949
50-
AutoEnzyme(; mode=nothing)
50+
AutoEnzyme(; mode=nothing, constant_function::Bool=false)
51+
52+
The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
53+
For simple functions, this should usually be set to `false`, but in the case of closures or callable structs which contain differentiated data, it should be set to `true`.
5154
5255
# Fields
5356
@@ -56,8 +59,16 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
5659
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
5760
+ `nothing` to choose the best mode automatically
5861
"""
59-
Base.@kwdef struct AutoEnzyme{M} <: AbstractADType
60-
mode::M = nothing
62+
struct AutoEnzyme{M, constant_function} <: AbstractADType
63+
mode::M
64+
end
65+
66+
function AutoEnzyme(mode::M; constant_function::Bool = false) where {M}
67+
return AutoEnzyme{M, constant_function}(mode)
68+
end
69+
70+
function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M}
71+
return AutoEnzyme{M, constant_function}(mode)
6172
end
6273

6374
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension

test/dense.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,25 @@ end
2828
@testset "AutoEnzyme" begin
2929
ad = AutoEnzyme()
3030
@test ad isa AbstractADType
31-
@test ad isa AutoEnzyme{Nothing}
31+
@test ad isa AutoEnzyme{Nothing, false}
3232
@test mode(ad) isa ForwardOrReverseMode
3333
@test ad.mode === nothing
3434

35+
ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true)
36+
@test ad isa AbstractADType
37+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true}
38+
@test mode(ad) isa ForwardMode
39+
@test ad.mode == EnzymeCore.Forward
40+
3541
ad = AutoEnzyme(; mode = EnzymeCore.Forward)
3642
@test ad isa AbstractADType
37-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
43+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false}
3844
@test mode(ad) isa ForwardMode
3945
@test ad.mode == EnzymeCore.Forward
4046

41-
ad = AutoEnzyme(; mode = EnzymeCore.Reverse)
47+
ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true)
4248
@test ad isa AbstractADType
43-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)}
49+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true}
4450
@test mode(ad) isa ReverseMode
4551
@test ad.mode == EnzymeCore.Reverse
4652
end

test/misc.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,6 @@ end
2121
@test length(string(sparse_backend1)) < length(string(sparse_backend2))
2222
end
2323

24-
import ADTypes
25-
26-
struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end
27-
struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end
28-
2924
for backend in [
3025
# dense
3126
ADTypes.AutoChainRules(; ruleconfig = :rc),

0 commit comments

Comments
 (0)