Skip to content

New libtask interface #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.6.2"
version = "0.7"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -21,13 +21,13 @@ AdvancedPSLibtaskExt = "Libtask"
[compat]
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Libtask = "0.9.2"
Random = "<0.0.1, 1"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
SSMProblems = "0.5"
julia = "1.7"
julia = "1.10.8"

[extras]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Expand Down
102 changes: 59 additions & 43 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ else
using ..Libtask: Libtask
end

# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
# and sometimes both an RNG and other information. In Turing.jl the other information
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
    rng::RngType
    # Deliberately abstract: callers retrieve this through
    # `Libtask.get_taped_globals`, which needs a type assertion anyway,
    # so a concrete parameter here would buy nothing.
    other::Any
end

# Convenience constructor for the common case where only an RNG is stored.
TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing)

"""
LibtaskModel{F}

Expand All @@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
function AdvancedPS.LibtaskModel(
    f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
    # The RNG travels via the task's taped globals rather than as a task
    # argument, so the model function `f` is called with `args...` only.
    return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
end

"""
Expand All @@ -43,64 +49,76 @@ end

const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}

function Base.copy(trace::LibtaskTrace)
    newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng))
    # Re-point the backreference: the copied TapedTask must store the new
    # trace in its taped globals, not the trace it was copied from.
    set_other_global!(newtrace, newtrace)
    return newtrace
end

"""Get the RNG from a `LibtaskTrace`."""
function get_rng(trace::LibtaskTrace)
    return trace.model.ctask.taped_globals.rng
end

"""Set the RNG for a `LibtaskTrace`."""
function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG)
    # NOTE(review): `get_other_global` dispatches on the trace itself
    # (its only visible method takes a `LibtaskTrace`); passing
    # `trace.model.ctask` would be a MethodError.
    other = get_other_global(trace)
    Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other))
    trace.rng = rng
    return trace
end

"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
function set_other_global!(trace::LibtaskTrace, other)
    # Taped globals are replaced wholesale, so carry the current RNG over.
    Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(get_rng(trace), other))
    return trace
end

"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other

function AdvancedPS.Trace(
    model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
    libtask_model = AdvancedPS.LibtaskModel(model, rng, args...)
    trace = AdvancedPS.Trace(libtask_model, rng)
    # Backreference: the TapedTask inside `trace` keeps `trace` itself in
    # its taped globals so model code can reach its own trace.
    set_other_global!(trace, trace)
    return trace
end

# Step to the next observe statement and return the log probability of the
# transition (or `nothing` once the task has finished).
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
    rng = get_rng(trace)
    if isref
        AdvancedPS.load_state!(rng)
    else
        AdvancedPS.save_state!(rng)
    end
    AdvancedPS.inc_counter!(rng)
    # Resume the taped task until its next `produce`.
    return Libtask.consume(trace.model.ctask)
end

# Task copying version of fork for Trace.
function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
    newtrace = copy(trace)
    # Give the fork its own RNG so the two traces evolve independently.
    set_rng!(newtrace, deepcopy(get_rng(newtrace)))
    isref && AdvancedPS.delete_retained!(newtrace.model.f)
    isref && delete_seeds!(newtrace)
    return newtrace
end

# PG requires keeping all randomness for the reference particle.
# Create a new task and copy the randomness over.
function AdvancedPS.forkr(trace::LibtaskTrace)
    rng = get_rng(trace)
    newf = AdvancedPS.reset_model(trace.model.f)
    # Rewind the counter-based RNG so the reference particle replays
    # the recorded randomness from the start.
    Random123.set_counter!(rng, 1)

    ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
    new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)

    newtrace = AdvancedPS.Trace(new_tapedmodel, rng)
    AdvancedPS.gen_refseed!(newtrace)
    return newtrace
end
Expand All @@ -113,7 +131,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
"""
function AdvancedPS.observe(dist::Distributions.Distribution, x)
    # Yield the log-likelihood to the consumer; the value `produce`
    # returns is not meaningful, so return `nothing` explicitly.
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end

"""
Expand All @@ -138,7 +157,6 @@ function AbstractMCMC.step(
else
trng = AdvancedPS.TracedRNG()
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
trace
end
end
Expand All @@ -153,8 +171,7 @@ function AbstractMCMC.step(
newtrajectory = rand(rng, particles)

replayed = AdvancedPS.replay(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence),
AdvancedPS.PGState(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence), AdvancedPS.PGState(replayed)
end

function AbstractMCMC.sample(
Expand All @@ -176,7 +193,6 @@ function AbstractMCMC.sample(
traces = map(1:(sampler.nparticles)) do i
trng = AdvancedPS.TracedRNG()
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
trace
end

Expand Down
3 changes: 0 additions & 3 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,6 @@ end
# in an extension, we just define dummy in the main module and implement them in the extension.
function observe end
function replay end
function addreference! end

current_trace() = current_task().storage[:__trace]

# We need this one to be visible outside of the extension for dispatching (Turing.jl).
struct LibtaskModel{F,T}
Expand Down
2 changes: 0 additions & 2 deletions src/rng.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,3 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
Increase the model step counter by `n`
"""
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n

function update_rng! end
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ GaussianDistributions = "0.5"
Kalman = "0.1"
HypothesisTests = "0.11"
DynamicIterators = "0.4"
Libtask = "0.8"
Libtask = "0.9"
Random123 = "1.3"
StableRNGs = "1"
julia = "1.3"
julia = "1.10.8"
10 changes: 5 additions & 5 deletions test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
val::Ref{Int}
end

function (model::Model)(rng::Random.AbstractRNG)
function (model::Model)()
t = [0]
while true
model.val[] += 1
Expand All @@ -148,17 +148,17 @@
@test consume(a.model.ctask) == 4
end

@testset "Back-reference" begin
    struct TaskIdModel <: AdvancedPS.AbstractGenericModel end

    function (model::TaskIdModel)()
        # The trace stores itself in the taped globals; produce its objectid
        # so the test can check the backreference points at the right object.
        trace = Libtask.get_taped_globals(Any).other
        id = objectid(trace)
        return Libtask.produce(id)
    end

    trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())

    @test AdvancedPS.advance!(trace, false) === objectid(trace)
end
Expand Down
23 changes: 16 additions & 7 deletions test/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
NormalModel() = new()
end

function (m::NormalModel)(rng::Random.AbstractRNG)
function (m::NormalModel)()
# First latent variable.
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

# First observation.
AdvancedPS.observe(Normal(a, 2), 3)

# Second latent variable.
rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))

# Second observation.
Expand All @@ -52,8 +54,11 @@
FailSMCModel() = new()
end

function (m::FailSMCModel)(rng::Random.AbstractRNG)
function (m::FailSMCModel)()
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))
if a >= 4
AdvancedPS.observe(Normal(b, 2), 1.5)
Expand All @@ -75,8 +80,9 @@
TestModel() = new()
end

function (m::TestModel)(rng::Random.AbstractRNG)
function (m::TestModel)()
# First hidden variables.
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal(0, 1))
m.x = x = rand(rng, Bernoulli(1))
m.b = rand(rng, Gamma(2, 3))
Expand All @@ -85,13 +91,14 @@
AdvancedPS.observe(Bernoulli(x / 2), 1)

# Second hidden variable.
rng = Libtask.get_taped_globals(Any).rng
m.c = rand(rng, Beta())

# Second observation.
return AdvancedPS.observe(Bernoulli(x / 2), 0)
end

chains_smc = sample(TestModel(), AdvancedPS.SMC(100))
chains_smc = sample(TestModel(), AdvancedPS.SMC(100); progress=false)

@test all(isone(particle.x) for particle in chains_smc.trajectories)
@test chains_smc.logevidence ≈ -2 * log(2)
Expand Down Expand Up @@ -145,7 +152,7 @@
return AdvancedPS.observe(Bernoulli(x / 2), 0)
end

chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100)
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100; progress=false)

@test all(isone(p.trajectory.x) for p in chains_pg)
@test mean(x.logevidence for x in chains_pg) ≈ -2 * log(2) atol = 0.01
Expand All @@ -159,16 +166,18 @@
DummyModel() = new()
end

function (m::DummyModel)()
    # The RNG now arrives via the taped globals, not as a function argument.
    rng = Libtask.get_taped_globals(Any).rng
    m.a = rand(rng, Normal())
    AdvancedPS.observe(Normal(), m.a)

    # Re-fetch after the produce point: the sampler may have swapped the RNG.
    rng = Libtask.get_taped_globals(Any).rng
    m.b = rand(rng, Normal())
    return AdvancedPS.observe(Normal(), m.b)
end

pg = AdvancedPS.PG(1)
first, second = sample(DummyModel(), pg, 2)
first, second = sample(DummyModel(), pg, 2; progress=false)

first_model = first.trajectory
second_model = second.trajectory
Expand Down
Loading