16
16
using .. Libtask: Libtask
17
17
end
18
18
19
+ # In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20
+ # and sometimes both an RNG and other information. In Turing.jl the other information
21
+ # is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22
+ # the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23
+ struct TapedGlobals{RngType}
24
+ rng:: RngType
25
+ other:: Any
26
+ end
27
+
28
+ TapedGlobals (rng:: Random.AbstractRNG ) = TapedGlobals (rng, nothing )
29
+
19
30
"""
20
31
LibtaskModel{F}
21
32
@@ -24,12 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
24
35
function AdvancedPS. LibtaskModel (
25
36
f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
26
37
) # Changed the API, need to take care of the RNG properly
27
- return AdvancedPS. LibtaskModel (
28
- f,
29
- Libtask. TapedTask (
30
- f, rng, args... ; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (f)}
31
- ),
32
- )
38
+ return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (TapedGlobals (rng), f, args... ))
33
39
end
34
40
35
41
"""
43
49
44
50
const LibtaskTrace{R} = AdvancedPS. Trace{<: AdvancedPS.LibtaskModel ,R}
45
51
46
- function AdvancedPS . Trace (
47
- model :: AdvancedPS.AbstractGenericModel , rng :: Random.AbstractRNG , args ...
48
- )
49
- return AdvancedPS . Trace (AdvancedPS . LibtaskModel (model, rng, args ... ), rng)
52
+ function Base . copy (trace :: LibtaskTrace )
53
+ newtrace = AdvancedPS. Trace ( copy (trace . model), deepcopy (trace . rng))
54
+ set_other_global! (newtrace, newtrace )
55
+ return newtrace
50
56
end
51
57
52
- # step to the next observe statement and
53
- # return the log probability of the transition (or nothing if done)
54
- function AdvancedPS. advance! (t:: LibtaskTrace , isref:: Bool = false )
55
- isref ? AdvancedPS. load_state! (t. rng) : AdvancedPS. save_state! (t. rng)
56
- AdvancedPS. inc_counter! (t. rng)
57
-
58
- # Move to next step
59
- return Libtask. consume (t. model. ctask)
58
+ """ Get the RNG from a `LibtaskTrace`."""
59
+ function get_rng (trace:: LibtaskTrace )
60
+ return trace. model. ctask. taped_globals. rng
60
61
end
61
62
62
- # create a backward reference in task_local_storage
63
- function AdvancedPS. addreference! (task:: Task , trace:: LibtaskTrace )
64
- if task. storage === nothing
65
- task. storage = IdDict ()
66
- end
67
- task. storage[:__trace ] = trace
63
+ """ Set the RNG for a `LibtaskTrace`."""
64
+ function set_rng! (trace:: LibtaskTrace , rng:: Random.AbstractRNG )
65
+ other = get_other_global (trace)
66
+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, other))
67
+ trace. rng = rng
68
+ return trace
69
+ end
68
70
69
- return task
71
+ """ Set the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
72
+ function set_other_global! (trace:: LibtaskTrace , other)
73
+ rng = get_rng (trace)
74
+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, other))
75
+ return trace
70
76
end
71
77
72
- function AdvancedPS. update_rng! (trace:: LibtaskTrace )
73
- rng, = trace. model. ctask. args
74
- trace. rng = rng
78
+ """ Get the other "taped global" variable of a `LibtaskTrace`, other than the RNG."""
79
+ get_other_global (trace:: LibtaskTrace ) = trace. model. ctask. taped_globals. other
80
+
81
+ function AdvancedPS. Trace (
82
+ model:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
83
+ )
84
+ trace = AdvancedPS. Trace (AdvancedPS. LibtaskModel (model, rng, args... ), rng)
85
+ # Set a backreference so that the TapedTask in `trace` stores the `trace` itself in its
86
+ # taped globals.
87
+ set_other_global! (trace, trace)
75
88
return trace
76
89
end
77
90
91
+ # step to the next observe statement and
92
+ # return the log probability of the transition (or nothing if done)
93
+ function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
94
+ rng = get_rng (trace)
95
+ isref ? AdvancedPS. load_state! (rng) : AdvancedPS. save_state! (rng)
96
+ AdvancedPS. inc_counter! (rng)
97
+ # Move to next step
98
+ return Libtask. consume (trace. model. ctask)
99
+ end
100
+
78
101
# Task copying version of fork for Trace.
79
102
function AdvancedPS. fork (trace:: LibtaskTrace , isref:: Bool = false )
80
103
newtrace = copy (trace)
81
- AdvancedPS . update_rng ! (newtrace)
104
+ set_rng ! (newtrace, deepcopy ( get_rng (newtrace)) )
82
105
isref && AdvancedPS. delete_retained! (newtrace. model. f)
83
106
isref && delete_seeds! (newtrace)
84
-
85
- # add backward reference
86
- AdvancedPS. addreference! (newtrace. model. ctask. task, newtrace)
87
107
return newtrace
88
108
end
89
109
90
110
# PG requires keeping all randomness for the reference particle
91
111
# Create new task and copy randomness
92
112
function AdvancedPS. forkr (trace:: LibtaskTrace )
113
+ rng = get_rng (trace)
93
114
newf = AdvancedPS. reset_model (trace. model. f)
94
- Random123. set_counter! (trace . rng, 1 )
115
+ Random123. set_counter! (rng, 1 )
95
116
96
- ctask = Libtask. TapedTask (
97
- newf, trace. rng; deepcopy_types= Union{AdvancedPS. TracedRNG,typeof (trace. model. f)}
98
- )
117
+ ctask = Libtask. TapedTask (TapedGlobals (rng, get_other_global (trace)), newf)
99
118
new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
100
119
101
120
# add backward reference
102
- newtrace = AdvancedPS. Trace (new_tapedmodel, trace. rng)
103
- AdvancedPS. addreference! (ctask. task, newtrace)
121
+ newtrace = AdvancedPS. Trace (new_tapedmodel, rng)
104
122
AdvancedPS. gen_refseed! (newtrace)
105
123
return newtrace
106
124
end
@@ -113,7 +131,8 @@ AdvancedPS.update_ref!(::LibtaskTrace) = nothing
113
131
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
114
132
"""
115
133
function AdvancedPS. observe (dist:: Distributions.Distribution , x)
116
- return Libtask. produce (Distributions. loglikelihood (dist, x))
134
+ Libtask. produce (Distributions. loglikelihood (dist, x))
135
+ return nothing
117
136
end
118
137
119
138
"""
@@ -138,7 +157,6 @@ function AbstractMCMC.step(
138
157
else
139
158
trng = AdvancedPS. TracedRNG ()
140
159
trace = AdvancedPS. Trace (deepcopy (model), trng)
141
- AdvancedPS. addreference! (trace. model. ctask. task, trace) # TODO : Do we need it here ?
142
160
trace
143
161
end
144
162
end
@@ -153,8 +171,7 @@ function AbstractMCMC.step(
153
171
newtrajectory = rand (rng, particles)
154
172
155
173
replayed = AdvancedPS. replay (newtrajectory)
156
- return AdvancedPS. PGSample (replayed. model. f, logevidence),
157
- AdvancedPS. PGState (newtrajectory)
174
+ return AdvancedPS. PGSample (replayed. model. f, logevidence), AdvancedPS. PGState (replayed)
158
175
end
159
176
160
177
function AbstractMCMC. sample (
@@ -176,7 +193,6 @@ function AbstractMCMC.sample(
176
193
traces = map (1 : (sampler. nparticles)) do i
177
194
trng = AdvancedPS. TracedRNG ()
178
195
trace = AdvancedPS. Trace (deepcopy (model), trng)
179
- AdvancedPS. addreference! (trace. model. ctask. task, trace) # Do we need it here ?
180
196
trace
181
197
end
182
198
0 commit comments