-
Notifications
You must be signed in to change notification settings - Fork 227
AdvancedPS v0.7 (and thus Libtask v0.9) support #2585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f687db0
2366bfa
b4823d9
d34dd3d
276e56e
7cf8ee0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -25,9 +25,8 @@ | |
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.", | ||
) | ||
end | ||
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}( | ||
model, sampler, varinfo, (model.f, args...) | ||
) | ||
evaluator = (model.f, args...) | ||
return TracedModel(model, sampler, varinfo, evaluator) | ||
end | ||
|
||
function AdvancedPS.advance!( | ||
|
@@ -59,20 +58,10 @@ | |
return trace | ||
end | ||
|
||
function AdvancedPS.update_rng!( | ||
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}} | ||
) | ||
# Extract the `args`. | ||
args = trace.model.ctask.args | ||
# From `args`, extract the `SamplingContext`, which contains the RNG. | ||
sampling_context = args[3] | ||
rng = sampling_context.rng | ||
trace.rng = rng | ||
return trace | ||
end | ||
|
||
function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ? | ||
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...) | ||
function Libtask.TapedTask(taped_globals::Any, model::TracedModel, args...; kwargs...) # RNG ? | ||
return Libtask.TapedTask( | ||
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs... | ||
) | ||
end | ||
|
||
abstract type ParticleInference <: InferenceAlgorithm end | ||
|
@@ -402,11 +391,11 @@ | |
|
||
function trace_local_varinfo_maybe(varinfo) | ||
try | ||
trace = AdvancedPS.current_trace() | ||
return trace.model.f.varinfo | ||
trace = Libtask.get_taped_globals(Any).other | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we change |
||
return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo | ||
catch e | ||
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. | ||
if e == KeyError(:__trace) || current_task().storage isa Nothing | ||
if e == KeyError(:task_variable) | ||
return varinfo | ||
else | ||
rethrow(e) | ||
|
@@ -416,11 +405,10 @@ | |
|
||
function trace_local_rng_maybe(rng::Random.AbstractRNG) | ||
try | ||
trace = AdvancedPS.current_trace() | ||
return trace.rng | ||
return Libtask.get_taped_globals(Any).rng | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same with above. |
||
catch e | ||
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`. | ||
if e == KeyError(:__trace) || current_task().storage isa Nothing | ||
if e == KeyError(:task_variable) | ||
return rng | ||
else | ||
rethrow(e) | ||
|
@@ -481,6 +469,25 @@ | |
|
||
tmodel = TracedModel(model, sampler, newvarinfo, rng) | ||
newtrace = AdvancedPS.Trace(tmodel, rng) | ||
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) | ||
return newtrace | ||
end | ||
|
||
# We need to tell Libtask which calls may have `produce` calls within them. In practice most | ||
# of these won't be needed, because of inlining and the fact that `might_produce` is only | ||
# called on `:invoke` expressions rather than `:call`s, but since those are implementation | ||
# details of the compiler, we set a bunch of methods as might_produce = true. We start with | ||
# `acclogp_observe!!` which is what calls `produce` and go up the call stack. | ||
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true | ||
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true | ||
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true | ||
function Libtask.might_produce( | ||
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}} | ||
) | ||
return true | ||
end | ||
function Libtask.might_produce( | ||
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}} | ||
) | ||
return true | ||
end | ||
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Libtask requires 1.10.8 at a minimum.