Skip to content

Commit ee1052d

Browse files
FredericWantiezyebaimhauru
authored
New libtask interface (#114)
* New libtask interface * Format * Tests * Remove internal call * Format * Comment * Update Project.toml * Update AdvancedPSLibtaskExt.jl * Update Project.toml * Update Project.toml * Update Project.toml * Fix calls to get_dynamic_scope * Fix addreference!, I think * Simplify Libtask extension * Small simplification to Libtask extension * Remove addreference! * Update ext/AdvancedPSLibtaskExt.jl * Update ext/AdvancedPSLibtaskExt.jl * Update ext/AdvancedPSLibtaskExt.jl --------- Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Markus Hauru <[email protected]>
1 parent 1ad89ec commit ee1052d

File tree

7 files changed

+85
-65
lines changed

7 files changed

+85
-65
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AdvancedPS"
22
uuid = "576499cb-2369-40b2-a588-c64705576edc"
33
authors = ["TuringLang"]
4-
version = "0.6.2"
4+
version = "0.7"
55

66
[deps]
77
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -21,13 +21,13 @@ AdvancedPSLibtaskExt = "Libtask"
2121
[compat]
2222
AbstractMCMC = "2, 3, 4, 5"
2323
Distributions = "0.23, 0.24, 0.25"
24-
Libtask = "0.8"
24+
Libtask = "0.9.2"
2525
Random = "<0.0.1, 1"
2626
Random123 = "1.3"
2727
Requires = "1.0"
2828
StatsFuns = "0.9, 1"
2929
SSMProblems = "0.5"
30-
julia = "1.7"
30+
julia = "1.10.8"
3131

3232
[extras]
3333
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"

ext/AdvancedPSLibtaskExt.jl

Lines changed: 59 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,17 @@ else
1616
using ..Libtask: Libtask
1717
end
1818

19+
# In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20+
# and sometimes both an RNG and other information. In Turing.jl the other information
21+
# is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22+
# the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23+
struct TapedGlobals{RngType}
24+
rng::RngType
25+
other::Any
26+
end
27+
28+
TapedGlobals(rng::Random.AbstractRNG) = TapedGlobals(rng, nothing)
29+
1930
"""
2031
LibtaskModel{F}
2132
@@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
2435
function AdvancedPS.LibtaskModel(
2536
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
2637
) # Changed the API, need to take care of the RNG properly
27-
return AdvancedPS.LibtaskModel(
28-
f,
29-
Libtask.TapedTask(
30-
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}
31-
),
32-
)
38+
return AdvancedPS.LibtaskModel(f, Libtask.TapedTask(TapedGlobals(rng), f, args...))
3339
end
3440

3541
"""
@@ -43,64 +49,76 @@ end
4349

4450
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}
4551

46-
function AdvancedPS.Trace(
47-
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
48-
)
49-
return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng)
52+
function Base.copy(trace::LibtaskTrace)
53+
newtrace = AdvancedPS.Trace(copy(trace.model), deepcopy(trace.rng))
54+
set_other_global!(newtrace, newtrace)
55+
return newtrace
5056
end
5157

52-
# step to the next observe statement and
53-
# return the log probability of the transition (or nothing if done)
54-
function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
55-
isref ? AdvancedPS.load_state!(t.rng) : AdvancedPS.save_state!(t.rng)
56-
AdvancedPS.inc_counter!(t.rng)
57-
58-
# Move to next step
59-
return Libtask.consume(t.model.ctask)
58+
"""Get the RNG from a `LibtaskTrace`."""
59+
function get_rng(trace::LibtaskTrace)
60+
return trace.model.ctask.taped_globals.rng
6061
end
6162

62-
# create a backward reference in task_local_storage
63-
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
64-
if task.storage === nothing
65-
task.storage = IdDict()
66-
end
67-
task.storage[:__trace] = trace
63+
"""Set the RNG for a `LibtaskTrace`."""
64+
function set_rng!(trace::LibtaskTrace, rng::Random.AbstractRNG)
65+
other = get_other_global(trace)
66+
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other))
67+
trace.rng = rng
68+
return trace
69+
end
6870

69-
return task
71+
"""Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
72+
function set_other_global!(trace::LibtaskTrace, other)
73+
rng = get_rng(trace)
74+
Libtask.set_taped_globals!(trace.model.ctask, TapedGlobals(rng, other))
75+
return trace
7076
end
7177

72-
function AdvancedPS.update_rng!(trace::LibtaskTrace)
73-
rng, = trace.model.ctask.args
74-
trace.rng = rng
78+
"""Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
79+
get_other_global(trace::LibtaskTrace) = trace.model.ctask.taped_globals.other
80+
81+
function AdvancedPS.Trace(
82+
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
83+
)
84+
trace = AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng)
85+
# Set a backreference so that the TapedTask in `trace` stores the `trace` itself in its
86+
# taped globals.
87+
set_other_global!(trace, trace)
7588
return trace
7689
end
7790

91+
# step to the next observe statement and
92+
# return the log probability of the transition (or nothing if done)
93+
function AdvancedPS.advance!(trace::LibtaskTrace, isref::Bool=false)
94+
rng = get_rng(trace)
95+
isref ? AdvancedPS.load_state!(rng) : AdvancedPS.save_state!(rng)
96+
AdvancedPS.inc_counter!(rng)
97+
# Move to next step
98+
return Libtask.consume(trace.model.ctask)
99+
end
100+
78101
# Task copying version of fork for Trace.
79102
function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
80103
newtrace = copy(trace)
81-
AdvancedPS.update_rng!(newtrace)
104+
set_rng!(newtrace, deepcopy(get_rng(newtrace)))
82105
isref && AdvancedPS.delete_retained!(newtrace.model.f)
83106
isref && delete_seeds!(newtrace)
84-
85-
# add backward reference
86-
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
87107
return newtrace
88108
end
89109

90110
# PG requires keeping all randomness for the reference particle
91111
# Create new task and copy randomness
92112
function AdvancedPS.forkr(trace::LibtaskTrace)
113+
rng = get_rng(trace)
93114
newf = AdvancedPS.reset_model(trace.model.f)
94-
Random123.set_counter!(trace.rng, 1)
115+
Random123.set_counter!(rng, 1)
95116

96-
ctask = Libtask.TapedTask(
97-
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)}
98-
)
117+
ctask = Libtask.TapedTask(TapedGlobals(rng, get_other_global(trace)), newf)
99118
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)
100119

101120
# add backward reference
102-
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
103-
AdvancedPS.addreference!(ctask.task, newtrace)
121+
newtrace = AdvancedPS.Trace(new_tapedmodel, rng)
104122
AdvancedPS.gen_refseed!(newtrace)
105123
return newtrace
106124
end
@@ -113,7 +131,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
113131
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
114132
"""
115133
function AdvancedPS.observe(dist::Distributions.Distribution, x)
116-
return Libtask.produce(Distributions.loglikelihood(dist, x))
134+
Libtask.produce(Distributions.loglikelihood(dist, x))
135+
return nothing
117136
end
118137

119138
"""
@@ -138,7 +157,6 @@ function AbstractMCMC.step(
138157
else
139158
trng = AdvancedPS.TracedRNG()
140159
trace = AdvancedPS.Trace(deepcopy(model), trng)
141-
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
142160
trace
143161
end
144162
end
@@ -153,8 +171,7 @@ function AbstractMCMC.step(
153171
newtrajectory = rand(rng, particles)
154172

155173
replayed = AdvancedPS.replay(newtrajectory)
156-
return AdvancedPS.PGSample(replayed.model.f, logevidence),
157-
AdvancedPS.PGState(newtrajectory)
174+
return AdvancedPS.PGSample(replayed.model.f, logevidence), AdvancedPS.PGState(replayed)
158175
end
159176

160177
function AbstractMCMC.sample(
@@ -176,7 +193,6 @@ function AbstractMCMC.sample(
176193
traces = map(1:(sampler.nparticles)) do i
177194
trng = AdvancedPS.TracedRNG()
178195
trace = AdvancedPS.Trace(deepcopy(model), trng)
179-
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
180196
trace
181197
end
182198

src/model.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,6 @@ end
6262
# in an extension, we just define dummy in the main module and implement them in the extension.
6363
function observe end
6464
function replay end
65-
function addreference! end
66-
67-
current_trace() = current_task().storage[:__trace]
6865

6966
# We need this one to be visible outside of the extension for dispatching (Turing.jl).
7067
struct LibtaskModel{F,T}

src/rng.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,5 +118,3 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
118118
Increase the model step counter by `n`
119119
"""
120120
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n
121-
122-
function update_rng! end

test/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ GaussianDistributions = "0.5"
1919
Kalman = "0.1"
2020
HypothesisTests = "0.11"
2121
DynamicIterators = "0.4"
22-
Libtask = "0.8"
22+
Libtask = "0.9"
2323
Random123 = "1.3"
2424
StableRNGs = "1"
25-
julia = "1.3"
25+
julia = "1.10.8"

test/container.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@
123123
val::Ref{Int}
124124
end
125125

126-
function (model::Model)(rng::Random.AbstractRNG)
126+
function (model::Model)()
127127
t = [0]
128128
while true
129129
model.val[] += 1
@@ -148,17 +148,17 @@
148148
@test consume(a.model.ctask) == 4
149149
end
150150

151-
@testset "current trace" begin
151+
@testset "Back-reference" begin
152152
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end
153153

154-
function (model::TaskIdModel)(rng::Random.AbstractRNG)
154+
function (model::TaskIdModel)()
155155
# Just print the task it's running in
156-
id = objectid(AdvancedPS.current_trace())
156+
trace = Libtask.get_taped_globals(Any).other
157+
id = objectid(trace)
157158
return Libtask.produce(id)
158159
end
159160

160161
trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
161-
AdvancedPS.addreference!(trace.model.ctask.task, trace)
162162

163163
@test AdvancedPS.advance!(trace, false) === objectid(trace)
164164
end

test/smc.jl

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,16 @@
2828
NormalModel() = new()
2929
end
3030

31-
function (m::NormalModel)(rng::Random.AbstractRNG)
31+
function (m::NormalModel)()
3232
# First latent variable.
33+
rng = Libtask.get_taped_globals(Any).rng
3334
m.a = a = rand(rng, Normal(4, 5))
3435

3536
# First observation.
3637
AdvancedPS.observe(Normal(a, 2), 3)
3738

3839
# Second latent variable.
40+
rng = Libtask.get_taped_globals(Any).rng
3941
m.b = b = rand(rng, Normal(a, 1))
4042

4143
# Second observation.
@@ -52,8 +54,11 @@
5254
FailSMCModel() = new()
5355
end
5456

55-
function (m::FailSMCModel)(rng::Random.AbstractRNG)
57+
function (m::FailSMCModel)()
58+
rng = Libtask.get_taped_globals(Any).rng
5659
m.a = a = rand(rng, Normal(4, 5))
60+
61+
rng = Libtask.get_taped_globals(Any).rng
5762
m.b = b = rand(rng, Normal(a, 1))
5863
if a >= 4
5964
AdvancedPS.observe(Normal(b, 2), 1.5)
@@ -75,8 +80,9 @@
7580
TestModel() = new()
7681
end
7782

78-
function (m::TestModel)(rng::Random.AbstractRNG)
83+
function (m::TestModel)()
7984
# First hidden variables.
85+
rng = Libtask.get_taped_globals(Any).rng
8086
m.a = rand(rng, Normal(0, 1))
8187
m.x = x = rand(rng, Bernoulli(1))
8288
m.b = rand(rng, Gamma(2, 3))
@@ -85,13 +91,14 @@
8591
AdvancedPS.observe(Bernoulli(x / 2), 1)
8692

8793
# Second hidden variable.
94+
rng = Libtask.get_taped_globals(Any).rng
8895
m.c = rand(rng, Beta())
8996

9097
# Second observation.
9198
return AdvancedPS.observe(Bernoulli(x / 2), 0)
9299
end
93100

94-
chains_smc = sample(TestModel(), AdvancedPS.SMC(100))
101+
chains_smc = sample(TestModel(), AdvancedPS.SMC(100); progress=false)
95102

96103
@test all(isone(particle.x) for particle in chains_smc.trajectories)
97104
@test chains_smc.logevidence -2 * log(2)
@@ -145,7 +152,7 @@
145152
return AdvancedPS.observe(Bernoulli(x / 2), 0)
146153
end
147154

148-
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100)
155+
chains_pg = sample(TestModel(), AdvancedPS.PG(10), 100; progress=false)
149156

150157
@test all(isone(p.trajectory.x) for p in chains_pg)
151158
@test mean(x.logevidence for x in chains_pg) -2 * log(2) atol = 0.01
@@ -159,16 +166,18 @@
159166
DummyModel() = new()
160167
end
161168

162-
function (m::DummyModel)(rng)
169+
function (m::DummyModel)()
170+
rng = Libtask.get_taped_globals(Any).rng
163171
m.a = rand(rng, Normal())
164172
AdvancedPS.observe(Normal(), m.a)
165173

174+
rng = Libtask.get_taped_globals(Any).rng
166175
m.b = rand(rng, Normal())
167176
return AdvancedPS.observe(Normal(), m.b)
168177
end
169178

170179
pg = AdvancedPS.PG(1)
171-
first, second = sample(DummyModel(), pg, 2)
180+
first, second = sample(DummyModel(), pg, 2; progress=false)
172181

173182
first_model = first.trajectory
174183
second_model = second.trajectory

0 commit comments

Comments
 (0)