New libtask interface #114

Conversation
Thanks for having a look at this!

Does this have any implications for integration with Turing.jl? I.e. does not passing an RNG to the model cause any trouble downstream? (To be clear, I have no idea -- I'm not suggesting that it does or doesn't.)

I agree re not wanting to dig into that. As with the first item, I'm not sure exactly what the requirements are here, so I may have misunderstood something basic about what you need to do.
```julia
using AdvancedPS
using Libtask
using Random
using Distributions
using SSMProblems

mutable struct Model <: AdvancedPS.AbstractGenericModel
    x::Float64
    y::Float64
    Model() = new()
end

function (model::Model)()
    rng = Libtask.get_dynamic_scope()
    model.x = rand(rng, Beta(1, 1))
    Libtask.produce(model.x)
    rng = Libtask.get_dynamic_scope()
    model.y = rand(rng, Normal(0, model.x))
    Libtask.produce(model.y)
end

rng = AdvancedPS.TracedRNG()
Random.seed!(rng, 10)
model = Model()
trace = AdvancedPS.Trace(model, rng)

# Sample `x`
AdvancedPS.advance!(trace)
trace2 = AdvancedPS.fork(trace)

key = AdvancedPS.state(trace.rng.rng)
seeds = AdvancedPS.split(key, 2)
Random.seed!(trace.rng, seeds[1])
Random.seed!(trace2.rng, seeds[2])

# Inherit `x` across independent particles
AdvancedPS.advance!(trace)
AdvancedPS.advance!(trace2)

println("Parent particle")
println(trace.model.f)
println("Child particle")
println(trace2.model.f)

println("Model with actual sampled values is in ctask.fargs")
println(trace2.model.ctask.fargs[1])

# Create a reference particle.
# Suppose we select the previous 'child' particle:
ref = AdvancedPS.forkr(trace2)
println("Did we keep all the generated values?")
println(ref.model.f) # If we just copy the tapedtask, we don't get the sampled values in the `Model`

# Note: this is only a problem when creating a reference trajectory;
# sampled values are properly captured during the execution of the task.
```
@FredericWantiez can we store […]?
@willtebbutt I think 2) might also be a problem for Turing, when looking at this part: […]
Two small issues I found cleaning up the tests.

Libtask returns a value after the last produce statement:

```julia
function f()
    Libtask.produce(1)
    Libtask.produce(2)
end

t1 = TapedTask(nothing, f)
consume(t1) # 1
consume(t1) # 2
consume(t1) # 2 (?)
```

Libtask doesn't catch some of the produce statements:

```julia
mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
    return nothing
end

rng = AdvancedPS.TracedRNG()
t = TapedTask(rng, NormalModel())
consume(t) # some float
consume(t) # 0 (?)
consume(t) # 0 (?)
```

This works fine if I call […].

EDIT: Changing `AdvancedPS.observe` to:

```julia
function AdvancedPS.observe(dist::Distributions.Distribution, x)
    Libtask.produce(Distributions.loglikelihood(dist, x))
    return nothing
end
```
If we store both […]
That should work. I have a branch against Turing that tries to do this, but it seems one of the copies is not quite correct. The other solution is to use `AdvancedPS.replay`:

```julia
new_particle = AdvancedPS.replay(particle)
transition = SMCTransition(model, new_particle.model.f.varinfo, weight)
state = SMCState(particles, 2, logevidence)
return transition, state
```
@willtebbutt running models against this PR I see a large performance drop:

```julia
using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    # First latent variable.
    rng = Libtask.get_dynamic_scope()
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_dynamic_scope()
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)
```

On master: […]

On this PR: […]
Thanks for the data point. The final item on my todo list is essentially sorting out various type-inference issues in the current implementation. Once those are done, we should see substantially improved performance.
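For readers following the thread: the flavour of problem being referred to is presumably along the lines of the following hypothetical sketch (not the actual patch), where a value retrieved from `Any`-typed storage defeats inference until a type assertion pins it down:

```julia
using Random

# `Any`-typed storage, standing in for task-local globals.
globals = Dict{Symbol,Any}(:rng => Random.default_rng())

draw_bad(g) = rand(g[:rng])                        # rand(::Any): not inferable
draw_good(g) = rand(g[:rng]::Random.TaskLocalRNG)  # assertion restores concrete inference
```

This is also, presumably, why the reworked API shown later takes a type argument.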
@FredericWantiez I'm finally looking at sorting out the performance of the Libtask updates. I'm struggling to replicate the performance of your example on the current versions of the packages, because I find that it errors. My environment is:

```
(jl_4fXu3W) pkg> st
Status `/private/var/folders/z7/0fkyw8ms795b7znc_3vbvrsw0000gn/T/jl_4fXu3W/Project.toml`
  [576499cb] AdvancedPS v0.6.1
  [31c24e10] Distributions v0.25.118
  [6f1fad26] Libtask v0.8.8
  [9a3f8284] Random v1.11.0
```

I tried it on LTS and 1.11.4. In particular, I'm seeing the error:

```
ERROR: BoundsError: attempt to access 0-element Vector{Any} at index [1]
Stacktrace:
  [1] throw_boundserror(A::Vector{Any}, I::Tuple{Int64})
    @ Base ./essentials.jl:14
  [2] getindex
    @ ./essentials.jl:916 [inlined]
  [3] _infer(f::NormalModel, args_type::Tuple{DataType})
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:45
  [4] Libtask.TapedFunction{…}(f::NormalModel, args::AdvancedPS.TracedRNG{…}; cache::Bool, deepcopy_types::Type)
    @ Libtask ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:72
  [5] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:62 [inlined]
  [6] _
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [7] TapedFunction
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedfunction.jl:80 [inlined]
  [8] #TapedTask#15
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:76 [inlined]
  [9] TapedTask
    @ ~/.julia/packages/Libtask/bxGQF/src/tapedtask.jl:70 [inlined]
 [10] LibtaskModel
    @ ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:27 [inlined]
 [11] AdvancedPS.Trace(::NormalModel, ::AdvancedPS.TracedRNG{UInt64, 1, Random123.Philox2x{UInt64, 10}})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:49
 [12] (::AdvancedPSLibtaskExt.var"#2#3"{NormalModel, Nothing, Bool, Int64})(i::Int64)
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:140
 [13] iterate
    @ ./generator.jl:48 [inlined]
 [14] _collect(c::UnitRange{…}, itr::Base.Generator{…}, ::Base.EltypeUnknown, isz::Base.HasShape{…})
    @ Base ./array.jl:811
 [15] collect_similar
    @ ./array.jl:720 [inlined]
 [16] map
    @ ./abstractarray.jl:3371 [inlined]
 [17] step(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, state::Nothing; kwargs::@Kwargs{})
    @ AdvancedPSLibtaskExt ~/.julia/packages/AdvancedPS/O1Ftx/ext/AdvancedPSLibtaskExt.jl:134
 [18] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:0 [inlined]
 [19] macro expansion
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/logging.jl:16 [inlined]
 [20] mcmcsample(rng::TaskLocalRNG, model::NormalModel, sampler::AdvancedPS.PG{…}, N::Int64; progress::Bool, progressname::String, callback::Nothing, num_warmup::Int64, discard_initial::Int64, thinning::Int64, chain_type::Type, initial_state::Nothing, kwargs::@Kwargs{})
    @ AbstractMCMC ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:142
 [21] mcmcsample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:107 [inlined]
 [22] #sample#20
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:59 [inlined]
 [23] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:52 [inlined]
 [24] #sample#19
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:21 [inlined]
 [25] sample
    @ ~/.julia/packages/AbstractMCMC/FSyVk/src/sample.jl:18 [inlined]
 [26] macro expansion
    @ ./timing.jl:581 [inlined]
 [27] top-level scope
    @ ./REPL[10]:1
Some type information was truncated. Use `show(err)` to see complete types.
```

Any idea whether I'm doing something wrong?
But, additionally, the latest version of the PR should address the various performance issues we previously had. There is one important change though: you need to pass a type to `get_taped_globals`.
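To make that concrete, here is a minimal sketch of the new calling convention, reusing the concrete `TracedRNG` type from the full example further down; the exact type parameters depend on your sampler configuration:

```julia
# Inside the model body (running as a TapedTask): pass the expected type
# so the retrieved value is concretely typed for the compiler.
T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}
rng = Libtask.get_taped_globals(T)
x = rand(rng, Normal(0, 1))  # rng is now concretely typed
```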
@willtebbutt if you're testing against the released version of Libtask/AdvancedPS you need to explicitly pass the RNG in the model definition, something like this:

```julia
function (model::Model)(rng::Random.AbstractRNG) # Add the RNG as an argument
    model.sig = rand(rng, Beta(1, 1))
    Libtask.produce(model.sig)
    model.mu = rand(rng, Normal())
    Libtask.produce(model.mu)
end
```
This now runs faster with AdvancedPS (dc5e594) and Libtask (8e7f784):

```
# run once to trigger compilation
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  2.986750 seconds (7.31 M allocations: 380.449 MiB, 0.88% gc time, 99.51% compilation time)

# second time runs faster
julia> @time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false);
  0.012714 seconds (32.85 k allocations: 18.581 MiB, 19.87% gc time)
```

Code:

```
(@temp) pkg> add AdvancedPS#fred/libtask-revamp
(@temp) pkg> add Libtask#wct/refactor
```

```julia
using Libtask
using AdvancedPS
using Distributions
using Random

mutable struct NormalModel <: AdvancedPS.AbstractGenericModel
    a::Float64
    b::Float64
    NormalModel() = new()
end

function (m::NormalModel)()
    T = AdvancedPS.TracedRNG{UInt64, 1, AdvancedPS.Random123.Philox2x{UInt64, 10}}
    # First latent variable.
    rng = Libtask.get_taped_globals(T)
    m.a = a = rand(rng, Normal(4, 5))
    # First observation.
    AdvancedPS.observe(Normal(a, 2), 3)
    # Second latent variable.
    rng = Libtask.get_taped_globals(T)
    m.b = b = rand(rng, Normal(a, 1))
    # Second observation.
    AdvancedPS.observe(Normal(b, 2), 1.5)
end

@time sample(NormalModel(), AdvancedPS.PG(10), 20; progress=false)
```
New Libtask has now been released. CI failures reveal that a few more fixes are required on the AdvancedPS side.
AdvancedPS.jl documentation for PR #114 is available at: […]
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main     #114   +/-  ##
=======================================
  Coverage        ?   96.27%
=======================================
  Files           ?        8
  Lines           ?      429
  Branches        ?        0
=======================================
  Hits            ?      413
  Misses          ?       16
  Partials        ?        0
```
Could I get some reviews on this, please? I've now got some of Turing's ParticleGibbs tests passing locally using this, which gives me more confidence that it's largely correct (one bug was found in the process). Making a release of this would help run the full test suite of Turing.jl and see if anything else comes up.
I would request @FredericWantiez's review, but you are the PR owner so I can't.
ext/AdvancedPSLibtaskExt.jl (outdated)

```julia
    rng = task.taped_globals.rng
    Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
    return task
"""
```
It seems that `addreference!` is no longer needed, given that it stores a self-reference? If so, can we remove it?
The last bug that I had here was in fact a missing call to `addreference!`, I think in `fork`. We also need to call it in Turing's `particle_mcmc`, to keep the reference in sync so that we can access the varinfo of the trace from within the TapedTask. There's probably a way to get rid of it, but that would require some refactoring, which would require me to first understand better what is going on. If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?
> If there's a plan to merge AdvancedPS into Turing proper, do you think that's worth the effort now?

I suggest removing it from this PR, since `addreference!` adds another layer of indirection. Given that this PR is already a breaking release and requires updates to Turing, I think it is worth it.

More context: I think we are safe if we always store and retrieve `(rng, varinfo)` via `set_taped_globals!` / `get_taped_globals`. The entire motivation for `addreference!` is to be able to copy `varinfo` external to a live particle (e.g. during the resampling step, multiple child particles share the same `varinfo`). It should produce correct results if particles always retrieve `varinfo` via `get_taped_globals`. Otherwise, something weird is happening.
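A minimal sketch of the pattern being described, using the names that appear elsewhere in this PR (`TapedGlobals`, `set_taped_globals!`, `get_taped_globals`); the `varinfo` value is a stand-in, and the real extension code may differ:

```julia
# Store: attach the particle's rng and varinfo to its own task, so copies
# of the task carry their own globals rather than an external reference.
Libtask.set_taped_globals!(task, TapedGlobals(rng, varinfo))

# Retrieve (inside the running model): the stored value comes back
# untyped, so assert the concrete types at the call site.
globals = Libtask.get_taped_globals(Any)
rng = globals.rng
vi = globals.other  # only ever a varinfo if everything goes through here
```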
Thanks @mhauru, I did a careful pass. It looks correct to me. Let's simplify more here so the code becomes less mysterious.
Thanks @yebai. Sorry about the general messiness here; since my understanding of the code is poor, I only tried to make the minimal edits needed to get it to work.
No worries, I am probably the one to blame for the messy code. My suggestions are only meant to encourage you to simplify the code further, now that some heuristics have become unnecessary.
```julia
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
struct TapedGlobals{RngType}
    rng::RngType
    other::Any
```
If we remove `addreference!`, the field `other` will only store `varinfo` instances.
I've removed `addreference!`.
Looks good to me!

Agreed that the other suggested simplification work (e.g. storing `varinfo` directly in `TapedGlobals`) can be done after we upstream this package to Turing.
Integrate refactor from TuringLang/Libtask.jl#179.

Two things worth noting: […]

…and now: `tapedtask.fargs` (AdvancedPS.jl/ext/AdvancedPSLibtaskExt.jl, lines 89 to 91 in 50d493c)

@willtebbutt