diff --git a/Project.toml b/Project.toml index 7cf44203..2bde35e8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AdvancedPS" uuid = "576499cb-2369-40b2-a588-c64705576edc" authors = ["TuringLang"] -version = "0.6.2" +version = "0.7" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -21,13 +21,13 @@ AdvancedPSLibtaskExt = "Libtask" [compat] AbstractMCMC = "2, 3, 4, 5" Distributions = "0.23, 0.24, 0.25" -Libtask = "0.8" +Libtask = "0.9.2" Random = "<0.0.1, 1" Random123 = "1.3" Requires = "1.0" StatsFuns = "0.9, 1" SSMProblems = "0.5" -julia = "1.7" +julia = "1.10.8" [extras] Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index f4029282..d3ba9e02 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -16,6 +16,17 @@ else using ..Libtask: Libtask end +# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG, +# and sometimes both an RNG and other information. In Turing.jl the other information +# is a VarInfo. This struct puts those in a single struct. Note the abstract type of +# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway. +struct TapedGlobals{RngType} + rng::RngType + other::Any +end + +TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing) + """ LibtaskModel{F} @@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`. function AdvancedPS.LibtaskModel( f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... ) # Changed the API, need to take care of the RNG properly - return AdvancedPS.LibtaskModel( - f, - Libtask.TapedTask( - f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)} - ), - ) + return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...)) end """ @@ -43,64 +49,76 @@ end const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} -function AdvancedPS.Trace( - model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... -) - return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng) +function Base.copy(trace::LibtaskTrace) + newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng)) + set_other_global!(newtrace, newtrace) + return newtrace end -# step to the next observe statement and -# return the log probability of the transition (or nothing if done) -function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false) - isref ? AdvancedPS.load_state!(t.rng) : AdvancedPS.save_state!(t.rng) - AdvancedPS.inc_counter!(t.rng) - - # Move to next step - return Libtask.consume(t.model.ctask) +"""Get the RNG from a `LibtaskTrace`.""" +function get_rng(trace::LibtaskTrace) + return trace.model.ctask.taped_globals.rng end -# create a backward reference in task_local_storage -function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace) - if task.storage === nothing - task.storage = IdDict() - end - task.storage[:__trace] = trace +"""Set the RNG for a `LibtaskTrace`.""" +function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG) + other = get_other_global(trace) + Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other)) + trace.rng = rng + return trace +end - return task +"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG.""" +function set_other_global!(trace::LibtaskTrace, other) + rng = get_rng(trace) + Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other)) + return trace end -function AdvancedPS.update_rng!(trace::LibtaskTrace) - rng, = trace.model.ctask.args - trace.rng = rng +"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG.""" +get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other + +function AdvancedPS.Trace( + model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... +) + trace = AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng) + # Set a backreference so that the TapedTask in `trace` stores the `trace` itself in its + # taped globals. + set_other_global!(trace, trace) return trace end +# step to the next observe statement and +# return the log probability of the transition (or nothing if done) +function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false) + rng = get_rng(trace) + isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng) + AdvancedPS.inc_counter!(rng) + # Move to next step + return Libtask.consume(trace.model.ctask) +end + # Task copying version of fork for Trace. function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false) newtrace = copy(trace) - AdvancedPS.update_rng!(newtrace) + set_rng!(newtrace, deepcopy(get_rng(newtrace))) isref && AdvancedPS.delete_retained!(newtrace.model.f) isref && delete_seeds!(newtrace) - - # add backward reference - AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) return newtrace end # PG requires keeping all randomness for the reference particle # Create new task and copy randomness function AdvancedPS.forkr(trace::LibtaskTrace) + rng = get_rng(trace) newf = AdvancedPS.reset_model(trace.model.f) - Random123.set_counter!(trace.rng, 1) + Random123.set_counter!(rng, 1) - ctask = Libtask.TapedTask( - newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)} - ) + ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf) new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) # add backward reference - newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng) - AdvancedPS.addreference!(ctask.task, newtrace) + newtrace = AdvancedPS.Trace(new_tapedmodel, rng) AdvancedPS.gen_refseed!(newtrace) return newtrace end @@ -113,7 +131,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing Observe sample `x` from distribution `dist` and yield its log-likelihood value. """ function AdvancedPS.observe(dist::Distributions.Distribution, x) - return Libtask.produce(Distributions.loglikelihood(dist, x)) + Libtask.produce(Distributions.loglikelihood(dist, x)) + return nothing end """ @@ -138,7 +157,6 @@ function AbstractMCMC.step( else trng = AdvancedPS.TracedRNG() trace = AdvancedPS.Trace(deepcopy(model), trng) - AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ? trace end end @@ -153,8 +171,7 @@ function AbstractMCMC.step( newtrajectory = rand(rng, particles) replayed = AdvancedPS.replay(newtrajectory) - return AdvancedPS.PGSample(replayed.model.f, logevidence), - AdvancedPS.PGState(newtrajectory) + return AdvancedPS.PGSample(replayed.model.f, logevidence), AdvancedPS.PGState(replayed) end function AbstractMCMC.sample( @@ -176,7 +193,6 @@ function AbstractMCMC.sample( traces = map(1:(sampler.nparticles)) do i trng = AdvancedPS.TracedRNG() trace = AdvancedPS.Trace(deepcopy(model), trng) - AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ? trace end diff --git a/src/model.jl b/src/model.jl index c836c74f..1b89198d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -62,9 +62,6 @@ end # in an extension, we just define dummy in the main module and implement them in the extension. function observe end function replay end -function addreference! end - -current_trace() = current_task().storage[:__trace] # We need this one to be visible outside of the extension for dispatching (Turing.jl). struct LibtaskModel{F,T} diff --git a/src/rng.jl b/src/rng.jl index 86109299..55e05da2 100644 --- a/src/rng.jl +++ b/src/rng.jl @@ -118,5 +118,3 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n Increase the model step counter by `n` """ inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n - -function update_rng! end diff --git a/test/Project.toml b/test/Project.toml index cac881c2..44c865ad 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -19,7 +19,7 @@ GaussianDistributions = "0.5" Kalman = "0.1" HypothesisTests = "0.11" DynamicIterators = "0.4" -Libtask = "0.8" +Libtask = "0.9" Random123 = "1.3" StableRNGs = "1" -julia = "1.3" +julia = "1.10.8" diff --git a/test/container.jl b/test/container.jl index f5434d61..07308e21 100644 --- a/test/container.jl +++ b/test/container.jl @@ -123,7 +123,7 @@ val::Ref{Int} end - function (model::Model)(rng::Random.AbstractRNG) + function (model::Model)() t = [0] while true model.val[] += 1 @@ -148,17 +148,17 @@ @test consume(a.model.ctask) == 4 end - @testset "current trace" begin + @testset "Back-reference" begin struct TaskIdModel <: AdvancedPS.AbstractGenericModel end - function (model::TaskIdModel)(rng::Random.AbstractRNG) + function (model::TaskIdModel)() # Just print the task it's running in - id = objectid(AdvancedPS.current_trace()) + trace = Libtask.get_taped_globals(Any).other + id = objectid(trace) return Libtask.produce(id) end trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG()) - AdvancedPS.addreference!(trace.model.ctask.task, trace) @test AdvancedPS.advance!(trace, false) === objectid(trace) end diff --git a/test/smc.jl b/test/smc.jl index 1fb4392c..e33d3716 100644 --- a/test/smc.jl +++ b/test/smc.jl @@ -28,14 +28,16 @@ NormalModel() = new() end - function (m::NormalModel)(rng::Random.AbstractRNG) + function (m::NormalModel)() # First latent variable. + rng = Libtask.get_taped_globals(Any).rng m.a = a = rand(rng, Normal(4, 5)) # First observation. AdvancedPS.observe(Normal(a, 2), 3) # Second latent variable. + rng = Libtask.get_taped_globals(Any).rng m.b = b = rand(rng, Normal(a, 1)) # Second observation. @@ -52,8 +54,11 @@ FailSMCModel() = new() end - function (m::FailSMCModel)(rng::Random.AbstractRNG) + function (m::FailSMCModel)() + rng = Libtask.get_taped_globals(Any).rng m.a = a = rand(rng, Normal(4, 5)) + + rng = Libtask.get_taped_globals(Any).rng m.b = b = rand(rng, Normal(a, 1)) if a >= 4 AdvancedPS.observe(Normal(b, 2), 1.5) @@ -75,8 +80,9 @@ TestModel() = new() end - function (m::TestModel)(rng::Random.AbstractRNG) + function (m::TestModel)() # First hidden variables. + rng = Libtask.get_taped_globals(Any).rng m.a = rand(rng, Normal(0, 1)) m.x = x = rand(rng, Bernoulli(1)) m.b = rand(rng, Gamma(2, 3)) @@ -85,13 +91,14 @@ AdvancedPS.observe(Bernoulli(x / 2), 1) # Second hidden variable. + rng = Libtask.get_taped_globals(Any).rng m.c = rand(rng, Beta()) # Second observation. return AdvancedPS.observe(Bernoulli(x / 2), 0) end - chains_smc = sample(TestModel(), AdvancedPS.SMC(100)) + chains_smc = sample(TestModel(), AdvancedPS.SMC(100); progress=false) @test all(isone(particle.x) for particle in chains_smc.trajectories) @test chains_smc.logevidence ≈ -2 * log(2) @@ -145,7 +152,7 @@ return AdvancedPS.observe(Bernoulli(x / 2), 0) end - chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100) + chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100; progress=false) @test all(isone(p.trajectory.x) for p in chains_pg) @test mean(x.logevidence for x in chains_pg) ≈ -2 * log(2) atol = 0.01 @@ -159,16 +166,18 @@ DummyModel() = new() end - function (m::DummyModel)(rng) + function (m::DummyModel)() + rng = Libtask.get_taped_globals(Any).rng m.a = rand(rng, Normal()) AdvancedPS.observe(Normal(), m.a) + rng = Libtask.get_taped_globals(Any).rng m.b = rand(rng, Normal()) return AdvancedPS.observe(Normal(), m.b) end pg = AdvancedPS.PG(1) - first, second = sample(DummyModel(), pg, 2) + first, second = sample(DummyModel(), pg, 2; progress=false) first_model = first.trajectory second_model = second.trajectory