-
Notifications
You must be signed in to change notification settings - Fork 12
New libtask interface #114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 16 commits
50d493c
af4cf44
bf11785
8518a58
3eb7276
89ac7d3
a54d271
dc5e594
41be1fe
da248b6
5f2765b
2d29914
d7dff1c
41125c4
6dff5f8
4bf2ac5
f863a7d
97e69ed
423d731
9f3f47c
0f6435f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,17 @@ else | |
using ..Libtask: Libtask | ||
end | ||
|
||
# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG, | ||
# and sometimes both an RNG and other information. In Turing.jl the other information | ||
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of | ||
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway. | ||
struct TapedGlobals{RngType} | ||
rng::RngType | ||
other::Any | ||
end | ||
|
||
TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing) | ||
|
||
""" | ||
LibtaskModel{F} | ||
|
||
|
@@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`. | |
function AdvancedPS.LibtaskModel( | ||
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... | ||
) # Changed the API, need to take care of the RNG properly | ||
return AdvancedPS.LibtaskModel( | ||
f, | ||
Libtask.TapedTask( | ||
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)} | ||
), | ||
) | ||
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...)) | ||
end | ||
|
||
""" | ||
|
@@ -43,6 +49,29 @@ end | |
|
||
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} | ||
|
||
"""Get the RNG from a `LibtaskTrace`.""" | ||
function get_rng(trace::LibtaskTrace) | ||
return trace.model.ctask.taped_globals.rng | ||
end | ||
|
||
"""Set the RNG for a `LibtaskTrace`.""" | ||
function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG) | ||
taped_globals = trace.model.ctask.taped_globals | ||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other)) | ||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
trace.rng = rng | ||
return trace | ||
end | ||
|
||
"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG.""" | ||
function set_other_global!(trace::LibtaskTrace, other) | ||
rng = get_rng(trace) | ||
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other)) | ||
return trace | ||
end | ||
|
||
"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG.""" | ||
get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other | ||
|
||
function AdvancedPS.Trace( | ||
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... | ||
) | ||
|
@@ -51,27 +80,26 @@ end | |
|
||
# step to the next observe statement and | ||
# return the log probability of the transition (or nothing if done) | ||
function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false) | ||
isref ? AdvancedPS.load_state!(t.rng) : AdvancedPS.save_state!(t.rng) | ||
AdvancedPS.inc_counter!(t.rng) | ||
|
||
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false) | ||
rng = get_rng(trace) | ||
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng) | ||
AdvancedPS.inc_counter!(rng) | ||
set_rng!(trace, rng) | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Move to next step | ||
return Libtask.consume(t.model.ctask) | ||
return Libtask.consume(trace.model.ctask) | ||
end | ||
|
||
# create a backward reference in task_local_storage | ||
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace) | ||
if task.storage === nothing | ||
task.storage = IdDict() | ||
end | ||
task.storage[:__trace] = trace | ||
|
||
return task | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The last bug that I had here was in fact a missing call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I suggest removing it from this PR since More context: I think we are safe if we always store and retrieve |
||
Set a backreference so that the TapedTask in `trace` stores the `trace` itself in the | ||
taped globals. | ||
""" | ||
function AdvancedPS.addreference!(trace::LibtaskTrace) | ||
set_other_global!(trace, trace) | ||
return trace | ||
end | ||
|
||
function AdvancedPS.update_rng!(trace::LibtaskTrace) | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rng, = trace.model.ctask.args | ||
trace.rng = rng | ||
set_rng!(trace, deepcopy(get_rng(trace))) | ||
return trace | ||
end | ||
|
||
|
@@ -81,26 +109,23 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false) | |
AdvancedPS.update_rng!(newtrace) | ||
isref && AdvancedPS.delete_retained!(newtrace.model.f) | ||
isref && delete_seeds!(newtrace) | ||
|
||
# add backward reference | ||
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) | ||
AdvancedPS.addreference!(newtrace) | ||
return newtrace | ||
end | ||
|
||
# PG requires keeping all randomness for the reference particle | ||
# Create new task and copy randomness | ||
function AdvancedPS.forkr(trace::LibtaskTrace) | ||
rng = get_rng(trace) | ||
newf = AdvancedPS.reset_model(trace.model.f) | ||
Random123.set_counter!(trace.rng, 1) | ||
Random123.set_counter!(rng, 1) | ||
trace.rng = rng | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
ctask = Libtask.TapedTask( | ||
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)} | ||
) | ||
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) | ||
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) | ||
|
||
# add backward reference | ||
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng) | ||
AdvancedPS.addreference!(ctask.task, newtrace) | ||
newtrace = AdvancedPS.Trace(new_tapedmodel, rng) | ||
AdvancedPS.gen_refseed!(newtrace) | ||
return newtrace | ||
end | ||
|
@@ -113,11 +138,12 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing | |
Observe sample `x` from distribution `dist` and yield its log-likelihood value. | ||
""" | ||
function AdvancedPS.observe(dist::Distributions.Distribution, x) | ||
return Libtask.produce(Distributions.loglikelihood(dist, x)) | ||
Libtask.produce(Distributions.loglikelihood(dist, x)) | ||
return nothing | ||
end | ||
|
||
""" | ||
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel | ||
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModelext | ||
mhauru marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
|
||
function AbstractMCMC.step( | ||
|
@@ -138,7 +164,6 @@ function AbstractMCMC.step( | |
else | ||
trng = AdvancedPS.TracedRNG() | ||
trace = AdvancedPS.Trace(deepcopy(model), trng) | ||
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ? | ||
trace | ||
end | ||
end | ||
|
@@ -153,8 +178,7 @@ function AbstractMCMC.step( | |
newtrajectory = rand(rng, particles) | ||
|
||
replayed = AdvancedPS.replay(newtrajectory) | ||
return AdvancedPS.PGSample(replayed.model.f, logevidence), | ||
AdvancedPS.PGState(newtrajectory) | ||
return AdvancedPS.PGSample(replayed.model.f, logevidence), AdvancedPS.PGState(replayed) | ||
end | ||
|
||
function AbstractMCMC.sample( | ||
|
@@ -176,7 +200,6 @@ function AbstractMCMC.sample( | |
traces = map(1:(sampler.nparticles)) do i | ||
trng = AdvancedPS.TracedRNG() | ||
trace = AdvancedPS.Trace(deepcopy(model), trng) | ||
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ? | ||
trace | ||
end | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we remove
addreference!
, the fieldother
will only storevarinfo
instances.