Skip to content

Commit b18e3e2

Browse files
committed
Add AbstractOptimizationResult
This allows for more seamless conversion between the result objects of different methods. To convert between two result types, the first result is converted to a Dict of field names to values, and then that dict is converted to the target result type. This assumes that all result types have a common set of field names, and for any field in a result that is not in that common set, a custom convert method must be defined that sets default values for those fields in the target result type.
1 parent 71b6197 commit b18e3e2

File tree

7 files changed

+219
-11
lines changed

7 files changed

+219
-11
lines changed

docs/generate_api.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ open(outfile, "w") do out
279279
280280
```@docs
281281
QuantumControl.set_default_ad_framework
282+
QuantumControl.AbstractOptimizationResult
282283
```
283284
""")
284285
write(out, raw"""

src/QuantumControl.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ include("functionals.jl") # submodule Functionals
6565

6666
include("print_versions.jl")
6767
include("set_default_ad_framework.jl")
68+
include("result.jl")
6869

6970
include("deprecate.jl")
7071

src/optimize.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,8 @@ where `:Krotov` is the name of the module implementing the method. The above is
6666
also the method signature that a `Module` wishing to implement a control method
6767
must define.
6868
69-
The returned `result` object is specific to the optimization method.
69+
The returned `result` object is specific to the optimization method, but should
70+
be a subtype of [`QuantumControl.AbstractOptimizationResult`](@ref).
7071
"""
7172
function optimize(
7273
problem::ControlProblem;

src/propagate.jl

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ function init_prop_trajectory(
9898
_prefixes=["prop_"],
9999
_msg="Initializing propagator for trajectory",
100100
_filter_kwargs=false,
101+
_kwargs_dict::Dict{Symbol,Any}=Dict{Symbol,Any}(),
101102
initial_state=traj.initial_state,
102103
verbose=false,
103104
kwargs...
104105
)
105106
#
106-
# The private keyword arguments, `_prefixes`, `_msg`, `_filter_kwargs` are
107-
# for internal use when setting up optimal control workspace objects (see,
108-
# e.g., Krotov.jl and GRAPE.jl)
107+
# The private keyword arguments, `_prefixes`, `_msg`, `_filter_kwargs`,
108+
# `_kwargs_dict` are for internal use when setting up optimal control
109+
# workspace objects (see, e.g., Krotov.jl and GRAPE.jl)
109110
#
110111
# * `_prefixes`: which prefixes to translate into `init_prop` kwargs. For
111112
# example, in Krotov/GRAPE, we have propagators both for the forward and
@@ -117,12 +118,16 @@ function init_prop_trajectory(
117118
# allows to pass the keyword arguments from `optimize` directly to
118119
# `init_prop_trajectory`. By convention, these use the same
119120
# `prop`/`fw_prop`/`bw_prop` prefixes as the properties of `traj`.
121+
# * `_kwargs_dict`: A dictionary Symbol => Any that collects the arguments
122+
# for `init_prop`. This allows to keep a copy of those arguments,
123+
# especially for arguments that cannot be obtained from the resulting
124+
# propagator, like the propagation callback.
120125
#
121-
kwargs_dict = Dict{Symbol,Any}()
126+
empty!(_kwargs_dict)
122127
for prefix in _prefixes
123128
for key in propertynames(traj)
124129
if startswith(string(key), prefix)
125-
kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] =
130+
_kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] =
126131
getproperty(traj, key)
127132
end
128133
end
@@ -131,20 +136,20 @@ function init_prop_trajectory(
131136
for prefix in _prefixes
132137
for (key, val) in kwargs
133138
if startswith(string(key), prefix)
134-
kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] = val
139+
_kwargs_dict[Symbol(string(key)[length(prefix)+1:end])] = val
135140
end
136141
end
137142
end
138143
else
139-
merge!(kwargs_dict, kwargs)
144+
merge!(_kwargs_dict, kwargs)
140145
end
141146
level = verbose ? Logging.Info : Logging.Debug
142-
@logmsg level _msg kwargs = kwargs_dict
147+
@logmsg level _msg kwargs = _kwargs_dict
143148
try
144-
return init_prop(initial_state, traj.generator, tlist; verbose, kwargs_dict...)
149+
return init_prop(initial_state, traj.generator, tlist; verbose, _kwargs_dict...)
145150
catch exception
146151
msg = "Cannot initialize propagation for trajectory"
147-
@error msg exception kwargs = kwargs_dict
152+
@error msg exception kwargs = _kwargs_dict
148153
rethrow()
149154
end
150155
end

src/result.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""
2+
Abstract type for the result object returned by [`optimize`](@ref). Any
3+
optimization method implemented on top of `QuantumControl` should subtype
4+
from `AbstractOptimizationResult`. This enables conversion between the results
5+
of different methods, allowing one method to continue an optimization from
6+
another method.
7+
8+
In order for this to work seamlessly, result objects should use a common set of
9+
field names as much as a possible. When a result object requires fields that
10+
cannot be provided by all other result objects, it should have default values
11+
for these field, which can be defined in a custom `Base.convert` method, as,
12+
e.g.,
13+
14+
```julia
15+
function Base.convert(::Type{MyResult}, result::AbstractOptimizationResult)
16+
defaults = Dict{Symbol,Any}(
17+
:f_calls => 0,
18+
:fg_calls => 0,
19+
)
20+
return convert(MyResult, result, defaults)
21+
end
22+
```
23+
24+
Where `f_calls` and `fg_calls` are fields of `MyResult` that are not present in
25+
a given `result` of a different type. The three-argument `convert` is defined
26+
internally for any `AbstractOptimizationResult`.
27+
"""
28+
abstract type AbstractOptimizationResult end
29+
30+
function Base.convert(
31+
::Type{Dict{Symbol,Any}},
32+
result::R
33+
) where {R<:AbstractOptimizationResult}
34+
return Dict{Symbol,Any}(field => getfield(result, field) for field in fieldnames(R))
35+
end
36+
37+
38+
struct MissingResultDataException{R} <: Exception
39+
missing_fields::Vector{Symbol}
40+
end
41+
42+
43+
function Base.showerror(io::IO, err::MissingResultDataException{R}) where {R}
44+
msg = "Missing data for fields $(err.missing_fields) to instantiate $R."
45+
print(io, msg)
46+
end
47+
48+
49+
struct IncompatibleResultsException{R1,R2} <: Exception
50+
missing_fields::Vector{Symbol}
51+
end
52+
53+
54+
function Base.showerror(io::IO, err::IncompatibleResultsException{R1,R2}) where {R1,R2}
55+
msg = "$R2 cannot be converted to $R1: $R2 does not provide required fields $(err.missing_fields). $R1 may need a custom implementation of `Base.convert` that sets values for any field names not provided by all results."
56+
print(io, msg)
57+
end
58+
59+
60+
function Base.convert(
61+
::Type{R},
62+
data::Dict{Symbol,<:Any},
63+
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(),
64+
) where {R<:AbstractOptimizationResult}
65+
66+
function _get(data, field, defaults)
67+
# Can't use `get`, because that would try to evaluate the non-existing
68+
# `defaults[field]` for `fields` that actually exist in `data`.
69+
if haskey(data, field)
70+
return data[field]
71+
else
72+
return defaults[field]
73+
end
74+
end
75+
76+
args = try
77+
[_get(data, field, defaults) for field in fieldnames(R)]
78+
catch exc
79+
if exc isa KeyError
80+
missing_fields = [
81+
field for field in fieldnames(R) if
82+
!(haskey(data, field) || haskey(defaults, field))
83+
]
84+
throw(MissingResultDataException{R}(missing_fields))
85+
else
86+
rethrow()
87+
end
88+
end
89+
return R(args...)
90+
end
91+
92+
93+
function Base.convert(
94+
::Type{R1},
95+
result::R2,
96+
defaults::Dict{Symbol,<:Any}=Dict{Symbol,Any}(),
97+
) where {R1<:AbstractOptimizationResult,R2<:AbstractOptimizationResult}
98+
data = convert(Dict{Symbol,Any}, result)
99+
try
100+
return convert(R1, data, defaults)
101+
catch exc
102+
if exc isa MissingResultDataException{R1}
103+
throw(IncompatibleResultsException{R1,R2}(exc.missing_fields))
104+
else
105+
rethrow()
106+
end
107+
end
108+
end

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,11 @@ end
8686
include("test_pulse_parameterizations.jl")
8787
end
8888

89+
println("\n* Result Conversion (test_result_conversion.jl):")
90+
@time @safetestset "Result Conversion" begin
91+
include("test_result_conversion.jl")
92+
end
93+
8994
println("* Invalid interfaces (test_invalid_interfaces.jl):")
9095
@time @safetestset "Invalid interfaces" begin
9196
include("test_invalid_interfaces.jl")

test/test_result_conversion.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
using Test
2+
using IOCapture
3+
4+
using QuantumControl:
5+
AbstractOptimizationResult, MissingResultDataException, IncompatibleResultsException
6+
7+
struct _TestOptimizationResult1 <: AbstractOptimizationResult
8+
iter_start::Int64
9+
iter_stop::Int64
10+
end
11+
12+
struct _TestOptimizationResult2 <: AbstractOptimizationResult
13+
iter_start::Int64
14+
J_T::Float64
15+
J_T_prev::Float64
16+
end
17+
18+
struct _TestOptimizationResult3 <: AbstractOptimizationResult
19+
iter_start::Int64
20+
iter_stop::Int64
21+
end
22+
23+
@testset "Dict conversion" begin
24+
25+
R = _TestOptimizationResult1(0, 100)
26+
27+
data = convert(Dict{Symbol,Any}, R)
28+
@test data isa Dict{Symbol,Any}
29+
@test Set(keys(data)) == Set((:iter_stop, :iter_start))
30+
@test data[:iter_start] == 0
31+
@test data[:iter_stop] == 100
32+
33+
@test _TestOptimizationResult1(0, 100) _TestOptimizationResult1(0, 50)
34+
35+
_R = convert(_TestOptimizationResult1, data)
36+
@test _R == R
37+
38+
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
39+
convert(_TestOptimizationResult2, data)
40+
end
41+
@test captured.value isa MissingResultDataException
42+
msg = begin
43+
io = IOBuffer()
44+
showerror(io, captured.value)
45+
String(take!(io))
46+
end
47+
@test startswith(msg, "Missing data for fields [:J_T, :J_T_prev]")
48+
@test contains(msg, "_TestOptimizationResult2")
49+
50+
end
51+
52+
53+
@testset "Result conversion" begin
54+
55+
R = _TestOptimizationResult1(0, 100)
56+
57+
_R = convert(_TestOptimizationResult1, R)
58+
@test _R == R
59+
60+
_R = convert(_TestOptimizationResult3, R)
61+
@test _R isa _TestOptimizationResult3
62+
@test convert(Dict{Symbol,Any}, _R) == convert(Dict{Symbol,Any}, R)
63+
64+
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
65+
convert(_TestOptimizationResult2, R)
66+
end
67+
@test captured.value isa IncompatibleResultsException
68+
msg = begin
69+
io = IOBuffer()
70+
showerror(io, captured.value)
71+
String(take!(io))
72+
end
73+
@test contains(msg, "does not provide required fields [:J_T, :J_T_prev]")
74+
75+
R2 = _TestOptimizationResult2(0, 0.1, 0.4)
76+
captured = IOCapture.capture(; passthrough=false, rethrow=Union{}) do
77+
convert(_TestOptimizationResult1, R2)
78+
end
79+
@test captured.value isa IncompatibleResultsException
80+
msg = begin
81+
io = IOBuffer()
82+
showerror(io, captured.value)
83+
String(take!(io))
84+
end
85+
@test contains(msg, "does not provide required fields [:iter_stop]")
86+
87+
end

0 commit comments

Comments
 (0)