Skip to content

New libtask interface #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.6.2"
version = "0.7"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -21,13 +21,13 @@ AdvancedPSLibtaskExt = "Libtask"
[compat]
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Libtask = "0.9.2"
Random = "<0.0.1, 1"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
SSMProblems = "0.5"
julia = "1.7"
julia = "1.10.8"

[extras]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Expand Down
95 changes: 59 additions & 36 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ else
using ..Libtask: Libtask
end

# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
# and sometimes both an RNG and other information. In Turing.jl the other information
is a VarInfo. This struct bundles them into a single value. Note the abstract type of
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
rng::RngType
other::Any
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove addreference!, the field other will only store varinfo instances.

end

"""
    TapedGlobals(rng)

Construct `TapedGlobals` holding only an RNG; the `other` field defaults to `nothing`.
"""
function TapedGlobals(rng::Random.AbstractRNG)
    return TapedGlobals(rng, nothing)
end

"""
LibtaskModel{F}

Expand All @@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
function AdvancedPS.LibtaskModel(
    f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
    # New Libtask API: the RNG travels in the task's taped globals rather than as a
    # positional argument, so the model callable `f` no longer receives the RNG
    # directly (it fetches it via `Libtask.get_taped_globals` at run time).
    # NOTE(review): the scraped diff interleaved the pre-0.9 constructor call
    # (`deepcopy_types=...`) with the new one; only the new form is kept here.
    return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
end

"""
Expand All @@ -43,6 +49,29 @@ end

const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}

"""Get the RNG from a `LibtaskTrace`."""
function get_rng(trace::LibtaskTrace)
return trace.model.ctask.taped_globals.rng
end

"""Set the RNG for a `LibtaskTrace`."""
function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG)
taped_globals = trace.model.ctask.taped_globals
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
trace.rng = rng
return trace
end

"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
function set_other_global!(trace::LibtaskTrace, other)
rng = get_rng(trace)
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other))
return trace
end

"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other

function AdvancedPS.Trace(
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
Expand All @@ -51,27 +80,26 @@ end

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
isref ? AdvancedPS.load_state!(t.rng) : AdvancedPS.save_state!(t.rng)
AdvancedPS.inc_counter!(t.rng)

function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
    rng = get_rng(trace)
    # The reference particle replays its recorded randomness; other particles
    # record theirs for possible later replay.
    isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
    AdvancedPS.inc_counter!(rng)
    set_rng!(trace, rng)
    # Move to the next observe statement. `consume` returns the value produced by
    # the task (a log-probability), or `nothing` once the task has finished.
    # NOTE(review): the scraped diff had left a stale duplicate return line
    # referencing the undefined variable `t`; it is removed here.
    return Libtask.consume(trace.model.ctask)
end

# create a backward reference in task_local_storage
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
if task.storage === nothing
task.storage = IdDict()
end
task.storage[:__trace] = trace

return task
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that addreference! is no longer needed, given that it stores a self-reference? If so, can we remove it?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last bug that I had here was in fact a missing call to addreference!, I think in fork. We also need to call it in Turing's particle_mcmc, to keep the reference in sync so that we can access the varinfo of the trace from within the TapedTask. There's probably a way to get rid of it, but that would require some refactoring, which would require me learning better what is going on. If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

I suggest removing it from this PR since addreference! adds another layer of indirection. Given that this PR is already a breaking release and requires updates to Turing, I think it is worth it.

More context:

I think we are safe if we always store and retrieve (rng, varinfo) via set_taped_globals! / get_taped_globals. The entire motivation for addreference! is to be able to copy varinfo external to a live particle (e.g. during the resampling step, multiple child particles share the same varinfo). It should produce correct results if particles always retrieve varinfo via get_taped_globals. Otherwise, something weird is happening.

Set a backreference so that the TapedTask in `trace` stores the `trace` itself in the
taped globals.
"""
function AdvancedPS.addreference!(trace::LibtaskTrace)
    # Store the trace itself in the task's taped globals so that code running
    # inside the TapedTask can reach back to its owning trace.
    # `set_other_global!` already returns the trace, so a single call suffices.
    return set_other_global!(trace, trace)
end

function AdvancedPS.update_rng!(trace::LibtaskTrace)
rng, = trace.model.ctask.args
trace.rng = rng
set_rng!(trace, deepcopy(get_rng(trace)))
return trace
end

Expand All @@ -81,26 +109,23 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
AdvancedPS.update_rng!(newtrace)
isref && AdvancedPS.delete_retained!(newtrace.model.f)
isref && delete_seeds!(newtrace)

# add backward reference
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
AdvancedPS.addreference!(newtrace)
return newtrace
end

# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function AdvancedPS.forkr(trace::LibtaskTrace)
    # NOTE(review): the scraped diff interleaved the pre-0.9 Libtask code
    # (`trace.rng`-based, `deepcopy_types=...`, `addreference!(ctask.task, ...)`)
    # with the new taped-globals version; only the new version is kept here.
    rng = get_rng(trace)
    newf = AdvancedPS.reset_model(trace.model.f)
    # Rewind the counter-based RNG so the reference particle replays its
    # recorded randomness from the beginning.
    Random123.set_counter!(rng, 1)
    trace.rng = rng

    # Fresh task carrying over both the RNG and the other taped global.
    ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
    new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
    newtrace = AdvancedPS.Trace(new_tapedmodel, rng)
    AdvancedPS.gen_refseed!(newtrace)
    return newtrace
end
Expand All @@ -113,11 +138,12 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
"""
function AdvancedPS.observe(dist::Distributions.Distribution, x)
    # Yield the log-likelihood to the consumer via Libtask, then return `nothing`
    # explicitly so the model's return value is not confused with the produced
    # value. NOTE(review): the scraped diff fused the old one-line `return
    # Libtask.produce(...)` with the new two-line body, which would have produced
    # twice / returned early; only the new body is kept here.
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end

"""
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel
"""

function AbstractMCMC.step(
Expand All @@ -138,7 +164,6 @@ function AbstractMCMC.step(
else
trng = AdvancedPS.TracedRNG()
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
trace
end
end
Expand All @@ -153,8 +178,7 @@ function AbstractMCMC.step(
newtrajectory = rand(rng, particles)

replayed = AdvancedPS.replay(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence),
AdvancedPS.PGState(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence), AdvancedPS.PGState(replayed)
end

function AbstractMCMC.sample(
Expand All @@ -176,7 +200,6 @@ function AbstractMCMC.sample(
traces = map(1:(sampler.nparticles)) do i
trng = AdvancedPS.TracedRNG()
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
trace
end

Expand Down
2 changes: 0 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ function observe end
function replay end
function addreference! end

current_trace() = current_task().storage[:__trace]

# We need this one to be visible outside of the extension for dispatching (Turing.jl).
struct LibtaskModel{F,T}
f::F
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ GaussianDistributions = "0.5"
Kalman = "0.1"
HypothesisTests = "0.11"
DynamicIterators = "0.4"
Libtask = "0.8"
Libtask = "0.9"
Random123 = "1.3"
StableRNGs = "1"
julia = "1.3"
julia = "1.10.8"
11 changes: 6 additions & 5 deletions test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
val::Ref{Int}
end

function (model::Model)(rng::Random.AbstractRNG)
function (model::Model)()
t = [0]
while true
model.val[] += 1
Expand All @@ -148,17 +148,18 @@
@test consume(a.model.ctask) == 4
end

@testset "current trace" begin
@testset "Back-reference" begin
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end

function (model::TaskIdModel)(rng::Random.AbstractRNG)
function (model::TaskIdModel)()
# Just print the task it's running in
id = objectid(AdvancedPS.current_trace())
trace = Libtask.get_taped_globals(Any).other
id = objectid(trace)
return Libtask.produce(id)
end

trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
AdvancedPS.addreference!(trace.model.ctask.task, trace)
AdvancedPS.addreference!(trace)

@test AdvancedPS.advance!(trace, false) === objectid(trace)
end
Expand Down
23 changes: 16 additions & 7 deletions test/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
NormalModel() = new()
end

function (m::NormalModel)(rng::Random.AbstractRNG)
function (m::NormalModel)()
# First latent variable.
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

# First observation.
AdvancedPS.observe(Normal(a, 2), 3)

# Second latent variable.
rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))

# Second observation.
Expand All @@ -52,8 +54,11 @@
FailSMCModel() = new()
end

function (m::FailSMCModel)(rng::Random.AbstractRNG)
function (m::FailSMCModel)()
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))
if a >= 4
AdvancedPS.observe(Normal(b, 2), 1.5)
Expand All @@ -75,8 +80,9 @@
TestModel() = new()
end

function (m::TestModel)(rng::Random.AbstractRNG)
function (m::TestModel)()
# First hidden variables.
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal(0, 1))
m.x = x = rand(rng, Bernoulli(1))
m.b = rand(rng, Gamma(2, 3))
Expand All @@ -85,13 +91,14 @@
AdvancedPS.observe(Bernoulli(x / 2), 1)

# Second hidden variable.
rng = Libtask.get_taped_globals(Any).rng
m.c = rand(rng, Beta())

# Second observation.
return AdvancedPS.observe(Bernoulli(x / 2), 0)
end

chains_smc = sample(TestModel(), AdvancedPS.SMC(100))
chains_smc = sample(TestModel(), AdvancedPS.SMC(100); progress=false)

@test all(isone(particle.x) for particle in chains_smc.trajectories)
@test chains_smc.logevidence ≈ -2 * log(2)
Expand Down Expand Up @@ -145,7 +152,7 @@
return AdvancedPS.observe(Bernoulli(x / 2), 0)
end

chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100)
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100; progress=false)

@test all(isone(p.trajectory.x) for p in chains_pg)
@test mean(x.logevidence for x in chains_pg) ≈ -2 * log(2) atol = 0.01
Expand All @@ -159,16 +166,18 @@
DummyModel() = new()
end

function (m::DummyModel)(rng)
function (m::DummyModel)()
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal())
AdvancedPS.observe(Normal(), m.a)

rng = Libtask.get_taped_globals(Any).rng
m.b = rand(rng, Normal())
return AdvancedPS.observe(Normal(), m.b)
end

pg = AdvancedPS.PG(1)
first, second = sample(DummyModel(), pg, 2)
first, second = sample(DummyModel(), pg, 2; progress=false)

first_model = first.trajectory
second_model = second.trajectory
Expand Down
Loading