Skip to content

New libtask interface #114

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Jun 23, 2025
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.6.2"
version = "0.7"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -21,13 +21,13 @@ AdvancedPSLibtaskExt = "Libtask"
[compat]
AbstractMCMC = "2, 3, 4, 5"
Distributions = "0.23, 0.24, 0.25"
Libtask = "0.8"
Libtask = "0.9"
Random = "<0.0.1, 1"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
SSMProblems = "0.5"
julia = "1.7"
julia = "1.10.8"

[extras]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Expand Down
71 changes: 38 additions & 33 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ else
using ..Libtask: Libtask
end

# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
# and sometimes both an RNG and other information. In Turing.jl this other information
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
rng::RngType
other::Any
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove addreference!, the field other will only store varinfo instances.

end

TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing)

"""
LibtaskModel{F}

Expand All @@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
function AdvancedPS.LibtaskModel(
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
) # Changed the API, need to take care of the RNG properly
return AdvancedPS.LibtaskModel(
f,
Libtask.TapedTask(
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}
),
)
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
end

"""
Expand All @@ -51,27 +57,31 @@ end

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
isref ? AdvancedPS.load_state!(t.rng) : AdvancedPS.save_state!(t.rng)
AdvancedPS.inc_counter!(t.rng)
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
taped_globals = trace.model.ctask.taped_globals
rng = taped_globals.rng
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
AdvancedPS.inc_counter!(rng)

Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
trace.rng = rng

# Move to next step
return Libtask.consume(t.model.ctask)
return Libtask.consume(trace.model.ctask)
end

# create a backward reference in task_local_storage
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
if task.storage === nothing
task.storage = IdDict()
end
task.storage[:__trace] = trace

function AdvancedPS.addreference!(task::Libtask.TapedTask, trace::LibtaskTrace)
rng = task.taped_globals.rng
Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
return task
end

function AdvancedPS.update_rng!(trace::LibtaskTrace)
rng, = trace.model.ctask.args
trace.rng = rng
taped_globals = trace.model.ctask.taped_globals
new_rng = deepcopy(taped_globals.rng)
trace.rng = new_rng
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))
return trace
end

Expand All @@ -81,26 +91,23 @@ function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
AdvancedPS.update_rng!(newtrace)
isref && AdvancedPS.delete_retained!(newtrace.model.f)
isref && delete_seeds!(newtrace)

# add backward reference
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
end

# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function AdvancedPS.forkr(trace::LibtaskTrace)
taped_globals = trace.model.ctask.taped_globals
rng = taped_globals.rng
newf = AdvancedPS.reset_model(trace.model.f)
Random123.set_counter!(trace.rng, 1)
Random123.set_counter!(rng, 1)
trace.rng = rng

ctask = Libtask.TapedTask(
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)}
)
ctask = Libtask.TapedTask(TapedGlobals(rng, taped_globals.other), newf)
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)

# add backward reference
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
AdvancedPS.addreference!(ctask.task, newtrace)
newtrace = AdvancedPS.Trace(new_tapedmodel, rng)
AdvancedPS.gen_refseed!(newtrace)
return newtrace
end
Expand All @@ -113,11 +120,12 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
"""
function AdvancedPS.observe(dist::Distributions.Distribution, x)
return Libtask.produce(Distributions.loglikelihood(dist, x))
Libtask.produce(Distributions.loglikelihood(dist, x))
return nothing
end

"""
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModel
AbstractMCMC interface. We need libtask to sample from arbitrary callable AbstractModelext
"""

function AbstractMCMC.step(
Expand All @@ -138,7 +146,6 @@ function AbstractMCMC.step(
else
trng = AdvancedPS.TracedRNG()
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
trace
end
end
Expand All @@ -153,8 +160,7 @@ function AbstractMCMC.step(
newtrajectory = rand(rng, particles)

replayed = AdvancedPS.replay(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence),
AdvancedPS.PGState(newtrajectory)
return AdvancedPS.PGSample(replayed.model.f, logevidence), AdvancedPS.PGState(replayed)
end

function AbstractMCMC.sample(
Expand All @@ -176,7 +182,6 @@ function AbstractMCMC.sample(
traces = map(1:(sampler.nparticles)) do i
trng = AdvancedPS.TracedRNG()
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
trace
end

Expand Down
2 changes: 0 additions & 2 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ function observe end
function replay end
function addreference! end

current_trace() = current_task().storage[:__trace]

# We need this one to be visible outside of the extension for dispatching (Turing.jl).
struct LibtaskModel{F,T}
f::F
Expand Down
4 changes: 2 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ GaussianDistributions = "0.5"
Kalman = "0.1"
HypothesisTests = "0.11"
DynamicIterators = "0.4"
Libtask = "0.8"
Libtask = "0.9"
Random123 = "1.3"
StableRNGs = "1"
julia = "1.3"
julia = "1.10.8"
11 changes: 6 additions & 5 deletions test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@
val::Ref{Int}
end

function (model::Model)(rng::Random.AbstractRNG)
function (model::Model)()
t = [0]
while true
model.val[] += 1
Expand All @@ -148,17 +148,18 @@
@test consume(a.model.ctask) == 4
end

@testset "current trace" begin
@testset "Back-reference" begin
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end

function (model::TaskIdModel)(rng::Random.AbstractRNG)
function (model::TaskIdModel)()
# Just print the task it's running in
id = objectid(AdvancedPS.current_trace())
trace = Libtask.get_taped_globals(Any).other
id = objectid(trace)
return Libtask.produce(id)
end

trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
AdvancedPS.addreference!(trace.model.ctask.task, trace)
AdvancedPS.addreference!(trace.model.ctask, trace)

@test AdvancedPS.advance!(trace, false) === objectid(trace)
end
Expand Down
23 changes: 16 additions & 7 deletions test/smc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@
NormalModel() = new()
end

function (m::NormalModel)(rng::Random.AbstractRNG)
function (m::NormalModel)()
# First latent variable.
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

# First observation.
AdvancedPS.observe(Normal(a, 2), 3)

# Second latent variable.
rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))

# Second observation.
Expand All @@ -52,8 +54,11 @@
FailSMCModel() = new()
end

function (m::FailSMCModel)(rng::Random.AbstractRNG)
function (m::FailSMCModel)()
rng = Libtask.get_taped_globals(Any).rng
m.a = a = rand(rng, Normal(4, 5))

rng = Libtask.get_taped_globals(Any).rng
m.b = b = rand(rng, Normal(a, 1))
if a >= 4
AdvancedPS.observe(Normal(b, 2), 1.5)
Expand All @@ -75,8 +80,9 @@
TestModel() = new()
end

function (m::TestModel)(rng::Random.AbstractRNG)
function (m::TestModel)()
# First hidden variables.
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal(0, 1))
m.x = x = rand(rng, Bernoulli(1))
m.b = rand(rng, Gamma(2, 3))
Expand All @@ -85,13 +91,14 @@
AdvancedPS.observe(Bernoulli(x / 2), 1)

# Second hidden variable.
rng = Libtask.get_taped_globals(Any).rng
m.c = rand(rng, Beta())

# Second observation.
return AdvancedPS.observe(Bernoulli(x / 2), 0)
end

chains_smc = sample(TestModel(), AdvancedPS.SMC(100))
chains_smc = sample(TestModel(), AdvancedPS.SMC(100); progress=false)

@test all(isone(particle.x) for particle in chains_smc.trajectories)
@test chains_smc.logevidence ≈ -2 * log(2)
Expand Down Expand Up @@ -145,7 +152,7 @@
return AdvancedPS.observe(Bernoulli(x / 2), 0)
end

chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100)
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100; progress=false)

@test all(isone(p.trajectory.x) for p in chains_pg)
@test mean(x.logevidence for x in chains_pg) ≈ -2 * log(2) atol = 0.01
Expand All @@ -159,16 +166,18 @@
DummyModel() = new()
end

function (m::DummyModel)(rng)
function (m::DummyModel)()
rng = Libtask.get_taped_globals(Any).rng
m.a = rand(rng, Normal())
AdvancedPS.observe(Normal(), m.a)

rng = Libtask.get_taped_globals(Any).rng
m.b = rand(rng, Normal())
return AdvancedPS.observe(Normal(), m.b)
end

pg = AdvancedPS.PG(1)
first, second = sample(DummyModel(), pg, 2)
first, second = sample(DummyModel(), pg, 2; progress=false)

first_model = first.trajectory
second_model = second.trajectory
Expand Down
Loading