Skip to content

Commit 6dff5f8

Browse files
committed
Fix addreference!, I think
1 parent 41125c4 commit 6dff5f8

File tree

4 files changed

+42
-31
lines changed

4 files changed

+42
-31
lines changed

ext/AdvancedPSLibtaskExt.jl

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ else
1616
using ..Libtask: Libtask
1717
end
1818

19+
# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20+
# and sometimes both an RNG and other information. In Turing.jl this other information
21+
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22+
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23+
struct TapedGlobals{RngType}
24+
rng::RngType
25+
other::Any
26+
end
27+
28+
TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing)
29+
1930
"""
2031
LibtaskModel{F}
2132
@@ -24,7 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2435
function AdvancedPS.LibtaskModel(
2536
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
2637
) # Changed the API, need to take care of the RNG properly
27-
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(rng, f, args...))
38+
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
2839
end
2940

3041
"""
@@ -47,30 +58,30 @@ end
4758
# step to the next observe statement and
4859
# return the log probability of the transition (or nothing if done)
4960
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
50-
# Where is the RNG ?
51-
isref ? AdvancedPS.load_state!(trace.rng) : AdvancedPS.save_state!(trace.rng)
52-
AdvancedPS.inc_counter!(trace.rng)
61+
taped_globals = trace.model.ctask.taped_globals
62+
rng = taped_globals.rng
63+
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
64+
AdvancedPS.inc_counter!(rng)
5365

54-
Libtask.set_taped_globals!(trace.model.ctask, trace.rng)
66+
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, taped_globals.other))
67+
trace.rng = rng
5568

5669
# Move to next step
5770
return Libtask.consume(trace.model.ctask)
5871
end
5972

6073
# create a backward reference in task_local_storage
61-
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
62-
if task.storage === nothing
63-
task.storage = IdDict()
64-
end
65-
task.storage[:__trace] = trace
66-
74+
function AdvancedPS.addreference!(task::Libtask.TapedTask, trace::LibtaskTrace)
75+
rng = task.taped_globals.rng
76+
Libtask.set_taped_globals!(task, TapedGlobals(rng, trace))
6777
return task
6878
end
6979

7080
function AdvancedPS.update_rng!(trace::LibtaskTrace)
71-
new_rng = deepcopy(trace.rng)
81+
taped_globals = trace.model.ctask.taped_globals
82+
new_rng = deepcopy(taped_globals.rng)
7283
trace.rng = new_rng
73-
Libtask.set_taped_globals!(trace.model.ctask, trace.rng)
84+
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(new_rng, taped_globals.other))
7485
return trace
7586
end
7687

@@ -86,17 +97,18 @@ end
8697
# PG requires keeping all randomness for the reference particle
8798
# Create new task and copy randomness
8899
function AdvancedPS.forkr(trace::LibtaskTrace)
100+
taped_globals = trace.model.ctask.taped_globals
101+
rng = taped_globals.rng
89102
newf = AdvancedPS.reset_model(trace.model.f)
90-
Random123.set_counter!(trace.rng, 1)
103+
Random123.set_counter!(rng, 1)
104+
trace.rng = rng
91105

92-
ctask = Libtask.TapedTask(trace.rng, newf)
106+
ctask = Libtask.TapedTask(TapedGlobals(rng, taped_globals.other), newf)
93107
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
94108

95109
# add backward reference
96-
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
110+
newtrace = AdvancedPS.Trace(new_tapedmodel, rng)
97111
AdvancedPS.gen_refseed!(newtrace)
98-
99-
Libtask.set_taped_globals!(ctask, trace.rng) # Sync trace and rng
100112
return newtrace
101113
end
102114

src/model.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ function observe end
6464
function replay end
6565
function addreference! end
6666

67-
current_trace() = current_task().storage[:__trace]
68-
6967
# We need this one to be visible outside of the extension for dispatching (Turing.jl).
7068
struct LibtaskModel{F,T}
7169
f::F

test/container.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,17 +148,18 @@
148148
@test consume(a.model.ctask) == 4
149149
end
150150

151-
@testset "current trace" begin
151+
@testset "Back-reference" begin
152152
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end
153153

154154
function (model::TaskIdModel)()
155155
# Just print the task it's running in
156-
id = objectid(AdvancedPS.current_trace())
156+
trace = Libtask.get_taped_globals(Any).other
157+
id = objectid(trace)
157158
return Libtask.produce(id)
158159
end
159160

160161
trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
161-
AdvancedPS.addreference!(trace.model.ctask.task, trace)
162+
AdvancedPS.addreference!(trace.model.ctask, trace)
162163

163164
@test AdvancedPS.advance!(trace, false) === objectid(trace)
164165
end

test/smc.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@
3030

3131
function (m::NormalModel)()
3232
# First latent variable.
33-
rng = Libtask.get_taped_globals(Any)
33+
rng = Libtask.get_taped_globals(Any).rng
3434
m.a = a = rand(rng, Normal(4, 5))
3535

3636
# First observation.
3737
AdvancedPS.observe(Normal(a, 2), 3)
3838

3939
# Second latent variable.
40-
rng = Libtask.get_taped_globals(Any)
40+
rng = Libtask.get_taped_globals(Any).rng
4141
m.b = b = rand(rng, Normal(a, 1))
4242

4343
# Second observation.
@@ -55,10 +55,10 @@
5555
end
5656

5757
function (m::FailSMCModel)()
58-
rng = Libtask.get_taped_globals(Any)
58+
rng = Libtask.get_taped_globals(Any).rng
5959
m.a = a = rand(rng, Normal(4, 5))
6060

61-
rng = Libtask.get_taped_globals(Any)
61+
rng = Libtask.get_taped_globals(Any).rng
6262
m.b = b = rand(rng, Normal(a, 1))
6363
if a >= 4
6464
AdvancedPS.observe(Normal(b, 2), 1.5)
@@ -82,7 +82,7 @@
8282

8383
function (m::TestModel)()
8484
# First hidden variables.
85-
rng = Libtask.get_taped_globals(Any)
85+
rng = Libtask.get_taped_globals(Any).rng
8686
m.a = rand(rng, Normal(0, 1))
8787
m.x = x = rand(rng, Bernoulli(1))
8888
m.b = rand(rng, Gamma(2, 3))
@@ -91,7 +91,7 @@
9191
AdvancedPS.observe(Bernoulli(x / 2), 1)
9292

9393
# Second hidden variable.
94-
rng = Libtask.get_taped_globals(Any)
94+
rng = Libtask.get_taped_globals(Any).rng
9595
m.c = rand(rng, Beta())
9696

9797
# Second observation.
@@ -167,11 +167,11 @@
167167
end
168168

169169
function (m::DummyModel)()
170-
rng = Libtask.get_taped_globals(Any)
170+
rng = Libtask.get_taped_globals(Any).rng
171171
m.a = rand(rng, Normal())
172172
AdvancedPS.observe(Normal(), m.a)
173173

174-
rng = Libtask.get_taped_globals(Any)
174+
rng = Libtask.get_taped_globals(Any).rng
175175
m.b = rand(rng, Normal())
176176
return AdvancedPS.observe(Normal(), m.b)
177177
end

0 commit comments

Comments
 (0)