Skip to content

Commit 1b5cad0

Browse files
Merge pull request #74 from SciML/gd/revert_constant_function
Remove `constant_function` for `AutoEnzyme`
2 parents 024ac94 + 298605d commit 1b5cad0

File tree

3 files changed

+12
-56
lines changed

3 files changed

+12
-56
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
authors = [
44
"Vaibhav Dixit <[email protected]>, Guillaume Dalle and contributors",
55
]
6-
version = "1.6.1"
6+
version = "1.6.2"
77

88
[deps]
99
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/dense.jl

Lines changed: 5 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -39,73 +39,29 @@ struct AutoDiffractor <: AbstractADType end
3939
mode(::AutoDiffractor) = ForwardOrReverseMode()
4040

4141
"""
42-
AutoEnzyme{M,constant_function}
42+
AutoEnzyme{M}
4343
4444
Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation.
4545
4646
Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
4747
4848
# Constructors
4949
50-
AutoEnzyme(; mode=nothing, constant_function::Bool=false)
51-
52-
The `constant_function` keyword argument (and type parameter) determines whether the function object itself should be considered constant or not during differentiation with Enzyme.jl.
53-
For simple functions, `constant_function` should usually be set to `true`, which leads to increased performance.
54-
However, in the case of closures or callable structs which contain differentiated data, `constant_function` should be set to `false` to ensure correctness (more details below).
50+
AutoEnzyme(; mode=nothing)
5551
5652
# Fields
5753
5854
- `mode::M`: can be either
5955
6056
+ an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required
6157
+ `nothing` to choose the best mode automatically
62-
63-
# Notes
64-
65-
We now give several examples of functions.
66-
For each one, we explain how `constant_function` should be set in order to compute the correct derivative with respect to the input `x`.
67-
68-
```julia
69-
function f1(x)
70-
return x[1]
71-
end
72-
```
73-
74-
The function `f1` is not a closure, it does not contain any data.
75-
Thus `f1` can be differentiated with `AutoEnzyme(constant_function=true)` (although here setting `constant_function=false` would change neither correctness nor performance).
76-
77-
```julia
78-
parameter = [0.0]
79-
function f2(x)
80-
return parameter[1] + x[1]
81-
end
82-
```
83-
84-
The function `f2` is a closure over `parameter`, but `parameter` is never modified based on the input `x`.
85-
Thus, `f2` can be differentiated with `AutoEnzyme(constant_function=true)` (setting `constant_function=false` would not change correctness but would hinder performance).
86-
87-
```julia
88-
cache = [0.0]
89-
function f3(x)
90-
cache[1] = x[1]
91-
return cache[1] + x[1]
92-
end
93-
```
94-
95-
The function `f3` is a closure over `cache`, and `cache` is modified based on the input `x`.
96-
That means `cache` cannot be treated as constant, since derivative values must be propagated through it.
97-
Thus `f3` must be differentiated with `AutoEnzyme(constant_function=false)` (setting `constant_function=true` would make the result incorrect).
9858
"""
99-
struct AutoEnzyme{M, constant_function} <: AbstractADType
59+
struct AutoEnzyme{M} <: AbstractADType
10060
mode::M
10161
end
10262

103-
function AutoEnzyme(mode::M; constant_function::Bool = false) where {M}
104-
return AutoEnzyme{M, constant_function}(mode)
105-
end
106-
107-
function AutoEnzyme(; mode::M = nothing, constant_function::Bool = false) where {M}
108-
return AutoEnzyme{M, constant_function}(mode)
63+
function AutoEnzyme(; mode::M = nothing) where {M}
64+
return AutoEnzyme{M}(mode)
10965
end
11066

11167
mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension

test/dense.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,25 @@ end
2828
@testset "AutoEnzyme" begin
2929
ad = AutoEnzyme()
3030
@test ad isa AbstractADType
31-
@test ad isa AutoEnzyme{Nothing, false}
31+
@test ad isa AutoEnzyme{Nothing}
3232
@test mode(ad) isa ForwardOrReverseMode
3333
@test ad.mode === nothing
3434

35-
ad = AutoEnzyme(EnzymeCore.Forward; constant_function = true)
35+
ad = AutoEnzyme(EnzymeCore.Forward)
3636
@test ad isa AbstractADType
37-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), true}
37+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
3838
@test mode(ad) isa ForwardMode
3939
@test ad.mode == EnzymeCore.Forward
4040

4141
ad = AutoEnzyme(; mode = EnzymeCore.Forward)
4242
@test ad isa AbstractADType
43-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward), false}
43+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Forward)}
4444
@test mode(ad) isa ForwardMode
4545
@test ad.mode == EnzymeCore.Forward
4646

47-
ad = AutoEnzyme(; mode = EnzymeCore.Reverse, constant_function = true)
47+
ad = AutoEnzyme(; mode = EnzymeCore.Reverse)
4848
@test ad isa AbstractADType
49-
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), true}
49+
@test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse)}
5050
@test mode(ad) isa ReverseMode
5151
@test ad.mode == EnzymeCore.Reverse
5252
end

0 commit comments

Comments
 (0)