Skip to content

Commit 5ce75c3

Browse files
authored
Merge pull request #78 from JuliaQuantumControl/abstract-optimization-result
Add `AbstractOptimizationResult`
2 parents 71b6197 + b18e3e2 commit 5ce75c3

File tree

7 files changed

+219
-11
lines changed

7 files changed

+219
-11
lines changed

docs/generate_api.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ open(outfile, "w") do out
279279
280280
```@docs
281281
QuantumControl.set_default_ad_framework
282+
QuantumControl.AbstractOptimizationResult
282283
```
283284
""")
284285
write(out, raw"""

src/QuantumControl.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ include("functionals.jl") # submodule Functionals
6565

6666
include("print_versions.jl")
6767
include("set_default_ad_framework.jl")
68+
include("result.jl")
6869

6970
include("deprecate.jl")
7071

src/optimize.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ where `:Krotov` is the name of the module implementing the method. The above is
6666
also the method signature that a `Module` wishing to implement a control method
6767
must define.
6868
69-
The returned `result` object is specific to the optimization method.
69+
The returned `result` object is specific to the optimization method, but should
70+
be a subtype of [`QuantumControl.AbstractOptimizationResult`](@ref).
7071
"""
7172
function optimize(
7273
problem::ControlProblem;

src/propagate.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ function init_prop_trajectory(
9898
_prefixes=["prop_"],
9999
_msg="Initializing propagator for trajectory",
100100
_filter_kwargs=false,
101+
_kwargs_dict::Dict{Symbol,Any}=Dict{Symbol,Any}(),
101102
initial_state=traj.initial_state,
102103
verbose=false,
103104
kwargs...
104105
)
105106
#
106-
# The private keyword arguments, `_prefixes`, `_msg`, `_filter_kwargs` are
107-
# for internal use when setting up optimal control workspace objects (see,
108-
# e.g., Krotov.jl and GRAPE.jl)
107+
# The private keyword arguments, `_prefixes`, `_msg`, `_filter_kwargs`,
108+
# `_kwargs_dict` are for internal use when setting up optimal control
109+
# workspace objects (see, e.g., Krotov.jl and GRAPE.jl)
109110
#
110111
# * `_prefixes`: which prefixes to translate into `init_prop` kwargs. For
111112
# example, in Krotov/GRAPE, we have propagators both for the forward and
@@ -117,12 +118,16 @@ function init_prop_trajectory(
117118
# allows to pass the keyword arguments from `optimize` directly to
118119
# `init_prop_trajectory`. By convention, these use the same
119120
# `prop`/`fw_prop`/`bw_prop` prefixes as the properties of `traj`.
121+
# * `_kwargs_dict`: A dictionary Symbol => Any that collects the arguments
122+
# for `init_prop`. This allows to keep a copy of those arguments,
123+
# especially for arguments that cannot be obtained from the resulting
124+
# propagator, like the propagation callback.
120125
#
121-
kwargs_dict = Dict{Symbol,Any}()
126+
empty!(_kwargs_dict)
122127
for prefix in _prefixes
123128
for key in propertynames(traj)
124129
if startswith(string(key), prefix)
125-
kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] =
130+
_kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] =
126131
getproperty(traj, key)
127132
end
128133
end
@@ -131,20 +136,20 @@ function init_prop_trajectory(
131136
for prefix in _prefixes
132137
for (key, val) in kwargs
133138
if startswith(string(key), prefix)
134-
kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] = val
139+
_kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] = val
135140
end
136141
end
137142
end
138143
else
139-
merge!(kwargs_dict, kwargs)
144+
merge!(_kwargs_dict, kwargs)
140145
end
141146
level = verbose ? Logging.Info : Logging.Debug
142-
@logmsg level _msg kwargs = kwargs_dict
147+
@logmsg level _msg kwargs = _kwargs_dict
143148
try
144-
return init_prop(initial_state, traj.generator, tlist; verbose, kwargs_dict...)
149+
return init_prop(initial_state, traj.generator, tlist; verbose, _kwargs_dict...)
145150
catch exception
146151
msg = "Cannot initialize propagation for trajectory"
147-
@error msg exception kwargs = kwargs_dict
152+
@error msg exception kwargs = _kwargs_dict
148153
rethrow()
149154
end
150155
end

src/result.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
Abstract type for the result object returned by [`optimize`](@ref). Any
3+
optimization method implemented on top of `QuantumControl` should subtype
4+
from `AbstractOptimizationResult`. This enables conversion between the results
5+
of different methods, allowing one method to continue an optimization from
6+
another method.
7+
8+
In order for this to work seamlessly, result objects should use a common set of
9+
field names as much as a possible. When a result object requires fields that
10+
cannot be provided by all other result objects, it should have default values
11+
for these field, which can be defined in a custom `Base.convert` method, as,
12+
e.g.,
13+
14+
```julia
15+
function Base.convert(::Type{MyResult}, result::AbstractOptimizationResult)
16+
defaults = Dict{Symbol,Any}(
17+
:f_calls => 0,
18+
:fg_calls => 0,
19+
)
20+
return convert(MyResult, result, defaults)
21+
end
22+
```
23+
24+
Where `f_calls` and `fg_calls` are fields of `MyResult` that are not present in
25+
a given `result` of a different type. The three-argument `convert` is defined
26+
internally for any `AbstractOptimizationResult`.
27+
"""
28+
abstract type AbstractOptimizationResult end
29+
30+
function Base.convert(
31+
::Type{Dict{Symbol,Any}},
32+
result::R
33+
) where {R<:AbstractOptimizationResult}
34+
return Dict{Symbol,Any}(field => getfield(result, field) for field in fieldnames(R))
35+
end
36+
37+
38+
struct MissingResultDataException{R} <: Exception
39+
missing_fields::Vector{Symbol}
40+
end
41+
42+
43+
function Base.showerror(io::IO, err::MissingResultDataException{R}) where {R}
44+
msg = "Missing data for fields $(err.missing_fields) to instantiate $R."
45+
print(io, msg)
46+
end
47+
48+
49+
struct IncompatibleResultsException{R1,R2} <: Exception
50+
missing_fields::Vector{Symbol}
51+
end
52+
53+
54+
function Base.showerror(io::IO, err::IncompatibleResultsException{R1,R2}) where {R1,R2}
55+
msg = "$R2 cannot be converted to $R1: $R2 does not provide required fields $(err.missing_fields). $R1 may need a custom implementation of `Base.convert` that sets values for any field names not provided by all results."
56+
print(io, msg)
57+
end
58+
59+
60+
function Base.convert(
61+
::Type{R},
62+
data::Dict{Symbol,<:Any},
63+
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(),
64+
) where {R<:AbstractOptimizationResult}
65+
66+
function _get(data, field, defaults)
67+
# Can't use `get`, because that would try to evaluate the non-existing
68+
# `defaults[field]` for `fields` that actually exist in `data`.
69+
if haskey(data, field)
70+
return data[field]
71+
else
72+
return defaults[field]
73+
end
74+
end
75+
76+
args = try
77+
[_get(data, field, defaults) for field in fieldnames(R)]
78+
catch exc
79+
if exc isa KeyError
80+
missing_fields = [
81+
field for field in fieldnames(R) if
82+
!(haskey(data, field) || haskey(defaults, field))
83+
]
84+
throw(MissingResultDataException{R}(missing_fields))
85+
else
86+
rethrow()
87+
end
88+
end
89+
return R(args...)
90+
end
91+
92+
93+
function Base.convert(
94+
::Type{R1},
95+
result::R2,
96+
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(),
97+
) where {R1<:AbstractOptimizationResult,R2<:AbstractOptimizationResult}
98+
data = convert(Dict{Symbol,Any}, result)
99+
try
100+
return convert(R1, data, defaults)
101+
catch exc
102+
if exc isa MissingResultDataException{R1}
103+
throw(IncompatibleResultsException{R1,R2}(exc.missing_fields))
104+
else
105+
rethrow()
106+
end
107+
end
108+
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ end
8686
include("test_pulse_parameterizations.jl")
8787
end
8888

89+
println("\n* Result Conversion (test_result_conversion.jl):")
90+
@time @safetestset "Result Conversion" begin
91+
include("test_result_conversion.jl")
92+
end
93+
8994
println("* Invalid interfaces (test_invalid_interfaces.jl):")
9095
@time @safetestset "Invalid interfaces" begin
9196
include("test_invalid_interfaces.jl")

test/test_result_conversion.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using Test
2+
using IOCapture
3+
4+
using QuantumControl:
5+
AbstractOptimizationResult, MissingResultDataException, IncompatibleResultsException
6+
7+
struct _TestOptimizationResult1 <: AbstractOptimizationResult
8+
iter_start::Int64
9+
iter_stop::Int64
10+
end
11+
12+
struct _TestOptimizationResult2 <: AbstractOptimizationResult
13+
iter_start::Int64
14+
J_T::Float64
15+
J_T_prev::Float64
16+
end
17+
18+
struct _TestOptimizationResult3 <: AbstractOptimizationResult
19+
iter_start::Int64
20+
iter_stop::Int64
21+
end
22+
23+
@testset "Dict conversion" begin
24+
25+
R = _TestOptimizationResult1(0, 100)
26+
27+
data = convert(Dict{Symbol,Any}, R)
28+
@test data isa Dict{Symbol,Any}
29+
@test Set(keys(data)) == Set((:iter_stop, :iter_start))
30+
@test data[:iter_start] == 0
31+
@test data[:iter_stop] == 100
32+
33+
@test _TestOptimizationResult1(0, 100) _TestOptimizationResult1(0, 50)
34+
35+
_R = convert(_TestOptimizationResult1, data)
36+
@test _R == R
37+
38+
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
39+
convert(_TestOptimizationResult2, data)
40+
end
41+
@test captured.value isa MissingResultDataException
42+
msg = begin
43+
io = IOBuffer()
44+
showerror(io, captured.value)
45+
String(take!(io))
46+
end
47+
@test startswith(msg, "Missing data for fields [:J_T, :J_T_prev]")
48+
@test contains(msg, "_TestOptimizationResult2")
49+
50+
end
51+
52+
53+
@testset "Result conversion" begin
54+
55+
R = _TestOptimizationResult1(0, 100)
56+
57+
_R = convert(_TestOptimizationResult1, R)
58+
@test _R == R
59+
60+
_R = convert(_TestOptimizationResult3, R)
61+
@test _R isa _TestOptimizationResult3
62+
@test convert(Dict{Symbol,Any}, _R) == convert(Dict{Symbol,Any}, R)
63+
64+
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
65+
convert(_TestOptimizationResult2, R)
66+
end
67+
@test captured.value isa IncompatibleResultsException
68+
msg = begin
69+
io = IOBuffer()
70+
showerror(io, captured.value)
71+
String(take!(io))
72+
end
73+
@test contains(msg, "does not provide required fields [:J_T, :J_T_prev]")
74+
75+
R2 = _TestOptimizationResult2(0, 0.1, 0.4)
76+
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
77+
convert(_TestOptimizationResult1, R2)
78+
end
79+
@test captured.value isa IncompatibleResultsException
80+
msg = begin
81+
io = IOBuffer()
82+
showerror(io, captured.value)
83+
String(take!(io))
84+
end
85+
@test contains(msg, "does not provide required fields [:iter_stop]")
86+
87+
end

0 commit comments

Comments
 (0)