New libtask interface #114

Merged (21 commits) on Jun 23, 2025
Conversation

@FredericWantiez (Member) commented Mar 23, 2025

Integrate refactor from TuringLang/Libtask.jl#179

Two things worth noting:

  1. Dealing with the RNG will be the user's responsibility. Before:
mutable struct Model <: AdvancedPS.AbstractGenericModel
  mu::Float64
  sig::Float64

  Model() = new()
end


function (model::Model)(rng::Random.AbstractRNG)
  model.sig = rand(rng, Beta(1, 1))  # AdvancedPS took care of syncing these
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

and now:

function (model::Model)()
  rng = Libtask.get_dynamic_scope() # We now need to query the RNG explicitly
  model.sig = rand(rng, Beta(1, 1))
  Libtask.produce(model.sig)

  rng = Libtask.get_dynamic_scope() # and do it every time we want to sample random values
  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end
  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs:
function AdvancedPS.forkr(trace::LibtaskTrace)
    newf = AdvancedPS.reset_model(trace.model.ctask.fargs[1])
    Random123.set_counter!(trace.rng, 1)

@willtebbutt

@FredericWantiez changed the title from "New libtask interface" to "[WIP] New libtask interface" on Mar 23, 2025
@willtebbutt (Member)

Thanks for having a look at this!

  1. Dealing with the RNG will be the user's responsibility. Before

Does this have any implications for integration with Turing.jl? i.e. does not passing in a RNG to the model cause any trouble downstream? (to be clear, I have no idea -- I'm not suggesting that it does / doesn't in particular)

  2. How do we keep track of model state between tasks? Pretty sure we don't want to look inside tapedtask.fargs

I agree about not wanting to dig into tapedtask.fargs. Could you elaborate a little on what is required here? My understanding was that task copying would handle this -- i.e. when you copy a task, all references to the model get updated, so from the perspective of the code inside the task, things just continue as normal.

As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.

@FredericWantiez (Member, Author) commented Mar 25, 2025

  1. We can drop this one, that really only applies when AdvancedPS is used with Libtask outside of Turing. We will probably sunset that (or target people who supposedly know enough about Libtask)

  2. Still not 100% sure about Turing, but we need something like this to manage the reference particle in the Particle Gibbs loop. Here's an MVP that should replicate a simple loop of the algorithm:

using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems


mutable struct Model <: AdvancedPS.AbstractGenericModel
  x::Float64
  y::Float64

  Model() = new()
end


function (model::Model)()
  rng = Libtask.get_dynamic_scope()
  model.x = rand(rng, Beta(1,1))
  Libtask.produce(model.x)

  rng = Libtask.get_dynamic_scope()
  model.y = rand(rng, Normal(0, model.x))
  Libtask.produce(model.y)
end

rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()

trace = AdvancedPS.Trace(model, rng)
# Sample `x`
AdvancedPS.advance!(trace)

trace2 = AdvancedPS.fork(trace)

key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)

Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])

# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)

println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)
println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])

# Create reference particle
# Suppose we select the previous 'child' particle
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`
# Note, this is only a problem when creating a reference trajectory, 
# sampled values are properly captured during the execution of the task

@yebai (Member) commented Mar 26, 2025

println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the Model
Note, this is only a problem when creating a reference trajectory,

@FredericWantiez can we store trace.rng inside TapedTask instead of trace? That way, when copying a TapedTask, we will copy the trace.rng.

@FredericWantiez (Member, Author)

@willtebbutt I think 2) might also be a problem for Turing, when looking at this part:
https://github.com/TuringLang/Turing.jl/blob/afb5c44d6dc1736831f45620328c9d5681748111/src/mcmc/particle_mcmc.jl#L140-L142

@FredericWantiez (Member, Author) commented Apr 5, 2025

Two small issues I found cleaning up the tests.

Libtask returns a value after the last produce statement:

function f()
  Libtask.produce(1)
  Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1)  # 1
consume(t1)  # 2
consume(t1)  # 2 (?)

Libtask doesn't catch some of the produce statements:

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
  a::Float64
  b::Float64

  NormalModel() = new()
end

function (m::NormalModel)()
  # First latent variable.
  rng = Libtask.get_dynamic_scope()
  m.a = a = rand(rng, Normal(4, 5))

  # First observation.
  AdvancedPS.observe(Normal(a, 2), 3)

  # Second latent variable.
  rng = Libtask.get_dynamic_scope()
  m.b = b = rand(rng, Normal(a, 1))

  # Second observation.
  AdvancedPS.observe(Normal(b, 2), 1.5)
  return nothing
end

rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())

consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)

This works fine if I call Libtask.produce explicitly instead of observe.

EDIT: Changing observe to something like this seems to work:

function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end

@yebai (Member) commented Apr 8, 2025

If we store both rng and varinfo in the scoped variable, then the following suggestions will address (2):

  • store varinfo in the Trace struct, then change here to Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))
  • change here and here to rng, varinfo = Libtask.get_dynamic_scope()
  • change here to transition = SMCTransition(model, particle.varinfo, weight)
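Taken together, the three bullets sketch out roughly the following. This is illustrative only: it uses the `set_dynamic_scope!` / `get_dynamic_scope` names from this stage of the PR, and `trace.varinfo` is the proposed new field, not existing API.

```julia
# Sketch of the suggested (rng, varinfo) plumbing; names follow the PR
# discussion and are not final API.

# 1. When (re)starting a particle, put both objects into the scoped variable:
Libtask.set_dynamic_scope!(trace.model.ctask, (trace.rng, trace.varinfo))

# 2. Inside the model body, retrieve both together:
rng, varinfo = Libtask.get_dynamic_scope()

# 3. Build the transition from the particle's own varinfo:
transition = SMCTransition(model, particle.varinfo, weight)
```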

@FredericWantiez (Member, Author) commented Apr 8, 2025

That should work. I have a branch against Turing that tries to do this, but it seems one copy is not quite correct.

The other solution is to use one replay step before the transition, to repopulate the varinfo properly:

    new_particle = AdvancedPS.replay(particle)
    transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
    state = SMCState(particles, 2, logevidence)
    return transition, state

@FredericWantiez (Member, Author) commented Apr 8, 2025

@willtebbutt running models against this PR I see a large performance drop:

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

On master:

1.816623 seconds (5.92 M allocations: 311.647 MiB, 1.52% gc time, 96.09% compilation time)

On this PR:

72.085056 seconds (369.62 M allocations: 17.322 GiB, 2.83% gc time, 77.21% compilation time)

@willtebbutt (Member)

Thanks for the data point. Essentially the final item on my todo list is sorting out various type inference issues in the current implementation. Once they're done, we should see substantially improved performance.

@yebai (Member) commented Apr 9, 2025

That should work. I have a branch against Turing that tries to do this, but it seems one copy is not quite correct.

The varinfo variable is updated during inference. I think we have to carefully ensure the correct varinfo is stored in the scoped variable.

cc @mhauru @FredericWantiez

@willtebbutt (Member)

@willtebbutt running models against this PR I see a large performance drop:

@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to replicate the performance of your example on the current versions of packages, because I find that it errors. My environment is

(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
  [576499cb] AdvancedPS v0.6.1
  [31c24e10] Distributions v0.25.118
  [6f1fad26] Libtask v0.8.8
  [9a3f8284] Random v1.11.0

I tried it on LTS and 1.11.4.

In particular, I'm seeing the error:

ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] _infer(f::NormalModel, args_type::Tuple{DataType})
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
  [4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
  [5] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
  [6] _
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [7] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [8] #TapedTask#15
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
  [9] TapedTask
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
 [10] LibtaskModel
    @ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
 [11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
 [12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [15] collect_similar
    @ ./array.jl:720 [inlined]
 [16] map
    @ ./abstractarray.jl:3371 [inlined]
 [17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
 [18] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
 [20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [21] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
 [22] #sample#20
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
 [23] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
 [24] #sample#19
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
 [25] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
 [26] macro expansion
    @ ./timing.jl:581 [inlined]
 [27] top-level scope
    @ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.

Any idea whether I'm doing something wrong?

@willtebbutt (Member)

But, additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change though: you need to pass a type to Libtask.get_dynamic_scope, which should be the type of the thing that it's going to return. We need this because there's no way to make the container typed (I assume that the previous implementation had a similar limitation). The docstring has been updated to reflect the changes.
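As an illustrative sketch of the new typed query, assuming the `Libtask.get_dynamic_scope(T)` form described above. `TypedModel` and `Random.Xoshiro` are stand-ins here; the RNG type is whatever the caller actually stored in the dynamic scope.

```julia
using Libtask, Random, Distributions

mutable struct TypedModel
    x::Float64
    TypedModel() = new()
end

function (model::TypedModel)()
    # Passing the expected return type lets the compiler infer a concrete
    # type for `rng` instead of `Any`, which is what recovers type stability.
    # Random.Xoshiro is an illustrative stand-in for the stored RNG type.
    rng = Libtask.get_dynamic_scope(Random.Xoshiro)
    model.x = rand(rng, Beta(1, 1))
    Libtask.produce(model.x)
end
```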

@FredericWantiez (Member, Author) commented Apr 15, 2025

@willtebbutt if you're testing against the released version of Libtask/AdvancedPS, you need to explicitly pass the RNG in the model definition, something like this:

function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as argument
  model.sig = rand(rng, Beta(1, 1)) 
  Libtask.produce(model.sig)

  model.mu = rand(rng, Normal())
  Libtask.produce(model.mu)
end

@yebai (Member) commented Apr 22, 2025

This now runs faster with AdvancedPS (dc5e594) and Libtask (8e7f784)

# run once to trigger compilation
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  2.986750 seconds (7.31 M allocations: 380.449 MiB, 0.88% gc time, 99.51% compilation time)

# second time runs faster
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  0.012714 seconds (32.85 k allocations: 18.581 MiB, 19.87% gc time)
Code:
(@temp) pkg> add AdvancedPS#fred/libtask-revamp
(@temp) pkg> add Libtask#wct/refactor

using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64

    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}
    rng = Libtask.get_taped_globals(T)
    m.a = a = rand(rng, Normal(4, 5))

    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)

    # Second latent variable.
    rng = Libtask.get_taped_globals(T)
    m.b = b = rand(rng, Normal(a, 1))

    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)

@yebai changed the title from "[WIP] New libtask interface" to "New libtask interface" on May 9, 2025
@yebai (Member) commented May 9, 2025

New Libtask has now been released. CI failures reveal that a few more fixes are required from AdvancedPS.

github-actions bot commented May 9, 2025

AdvancedPS.jl documentation for PR #114 is available at:
https://TuringLang.github.io/AdvancedPS.jl/previews/PR114/

codecov bot commented Jun 6, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Please upload report for BASE (main@1ad89ec). Learn more about missing BASE report.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #114   +/-   ##
=======================================
  Coverage        ?   96.27%           
=======================================
  Files           ?        8           
  Lines           ?      429           
  Branches        ?        0           
=======================================
  Hits            ?      413           
  Misses          ?       16           
  Partials        ?        0           


@mhauru (Member) commented Jun 19, 2025

Could I get some reviews on this please? I've now got some of Turing's ParticleGibbs tests passing locally using this, which gives me more confidence that it's largely correct (one bug was found in the process). Making a release of this would help run the full test suite of Turing.jl and see if anything else comes up.

@mhauru requested a review from @yebai on June 19, 2025
@mhauru (Member) commented Jun 19, 2025

I would request @FredericWantiez's review, but you are the PR owner so I can't.

rng = task.taped_globals.rng
Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
return task
"""
Review comment (Member):

It seems that addreference! is no longer needed, given that it stores a self-reference? If so, can we remove it?

Review comment (Member):

The last bug that I had here was in fact a missing call to addreference!, I think in fork. We also need to call it in Turing's particle_mcmc, to keep the reference in sync so that we can access the varinfo of the trace from within the TapedTask. There's probably a way to get rid of it, but that would require some refactoring, which would require me learning better what is going on. If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

Review comment (Member):

If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

I suggest removing it from this PR since addreference! adds another layer of indirection. Given that this PR is already a breaking release and requires updates to Turing, I think it is worth it.

More context:

I think we are safe if we always store and retrieve (rng, varinfo) via set_taped_globals! / get_taped_globals. The entire motivation for addreference! is to be able to copy varinfo external to a live particle (e.g. during the resampling step, multiple child particles share the same varinfo). It should produce correct results if particles always retrieve varinfo via get_taped_globals. Otherwise, something weird is happening.
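That invariant, sketched with the `set_taped_globals!` / `get_taped_globals` names from this PR. `MyVarInfo` is a hypothetical stand-in for Turing's varinfo type, and the tuple type passed to `get_taped_globals` mirrors the type-assertion pattern shown elsewhere in this thread.

```julia
# Store (rng, varinfo) together on the task, so that copying the task copies
# the pair and a forked particle sees exactly what it was forked with.
Libtask.set_taped_globals!(trace.model.ctask, (trace.rng, vi))

# Inside the task body, always go through get_taped_globals; never cache the
# rng or varinfo across produce points. The expected type is passed in
# because the underlying container cannot be concretely typed.
rng, vi = Libtask.get_taped_globals(Tuple{typeof(trace.rng), MyVarInfo})
```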

@yebai (Member) commented Jun 19, 2025

Thanks @mhauru, I did a careful pass. It looks correct to me. Let's simplify more here so the code becomes less mysterious.

@mhauru (Member) commented Jun 19, 2025

Thanks @yebai. Sorry about the general messiness here, since my understanding of the code is poor I only tried to do the minimal edits needed to get it to work.

@yebai (Member) commented Jun 19, 2025

Sorry about the general messiness here, since my understanding of the code is poor I only tried to do the minimal edits needed to get it to work.

No worries, I am probably the one to blame for the messy code. My suggestions only encourage you to simplify the code further now that some heuristics have become unnecessary.

# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
    rng::RngType
    other::Any
Review comment (Member):

If we remove addreference!, the field other will only store varinfo instances.

@mhauru (Member) commented Jun 19, 2025

I've removed the addreference! function by having the backreference added when a LibtaskTrace object is created. This makes it harder to forget to add it, and makes our interface simpler. The backreference itself still remains, because we use it to access the VarInfo of the trace. While it's true that we store the trace in the TapedGlobals.other field only to have access to trace.model.f.varinfo, we unfortunately can't simplify that by storing the VarInfo directly, because Turing isn't a dependency of AdvancedPS and thus we have no visibility into things like the structure of a DPPL.Model object.

@yebai left a review comment:

Looks good to me!

Agreed that other suggested simplification work (e.g. storing varinfo directly in tapedglobals) can be done after we upstream this package to Turing.

@yebai merged commit ee1052d into main on Jun 23, 2025 (7 of 9 checks passed)
@yebai deleted the fred/libtask-revamp branch on June 23, 2025 at 20:14