16
16
using .. Libtask: Libtask
17
17
end
18
18
19
+ # In Libtask.TapedTask.taped_globals, this extension sometimes needs to store an RNG,
20
+ # and sometimes both an RNG and other information. In Turing.jl this other information
21
+ # is a VarInfo. This struct puts those in a single struct. Note the abstract type of
22
+ # the second field. This is okay, because `get_taped_globals` needs a type assertion anyway.
23
+ struct TapedGlobals{RngType}
24
+ rng:: RngType
25
+ other:: Any
26
+ end
27
+
28
+ TapedGlobals (rng:: Random.AbstractRNG ) = TapedGlobals (rng, nothing )
29
+
19
30
"""
20
31
LibtaskModel{F}
21
32
@@ -24,7 +35,7 @@ State wrapper to hold `Libtask.CTask` model initiated from `f`.
24
35
function AdvancedPS. LibtaskModel (
25
36
f:: AdvancedPS.AbstractGenericModel , rng:: Random.AbstractRNG , args...
26
37
) # Changed the API, need to take care of the RNG properly
27
- return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (rng, f, args... ))
38
+ return AdvancedPS. LibtaskModel (f, Libtask. TapedTask (TapedGlobals ( rng) , f, args... ))
28
39
end
29
40
30
41
"""
47
58
# step to the next observe statement and
48
59
# return the log probability of the transition (or nothing if done)
49
60
function AdvancedPS. advance! (trace:: LibtaskTrace , isref:: Bool = false )
50
- # Where is the RNG ?
51
- isref ? AdvancedPS. load_state! (trace. rng) : AdvancedPS. save_state! (trace. rng)
52
- AdvancedPS. inc_counter! (trace. rng)
61
+ taped_globals = trace. model. ctask. taped_globals
62
+ rng = taped_globals. rng
63
+ isref ? AdvancedPS. load_state! (rng) : AdvancedPS. save_state! (rng)
64
+ AdvancedPS. inc_counter! (rng)
53
65
54
- Libtask. set_taped_globals! (trace. model. ctask, trace. rng)
66
+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (rng, taped_globals. other))
67
+ trace. rng = rng
55
68
56
69
# Move to next step
57
70
return Libtask. consume (trace. model. ctask)
58
71
end
59
72
60
73
# create a backward reference in task_local_storage
61
- function AdvancedPS. addreference! (task:: Task , trace:: LibtaskTrace )
62
- if task. storage === nothing
63
- task. storage = IdDict ()
64
- end
65
- task. storage[:__trace ] = trace
66
-
74
+ function AdvancedPS. addreference! (task:: Libtask.TapedTask , trace:: LibtaskTrace )
75
+ rng = task. taped_globals. rng
76
+ Libtask. set_taped_globals! (task, TapedGlobals (rng, trace))
67
77
return task
68
78
end
69
79
70
80
function AdvancedPS. update_rng! (trace:: LibtaskTrace )
71
- new_rng = deepcopy (trace. rng)
81
+ taped_globals = trace. model. ctask. taped_globals
82
+ new_rng = deepcopy (taped_globals. rng)
72
83
trace. rng = new_rng
73
- Libtask. set_taped_globals! (trace. model. ctask, trace . rng )
84
+ Libtask. set_taped_globals! (trace. model. ctask, TapedGlobals (new_rng, taped_globals . other) )
74
85
return trace
75
86
end
76
87
86
97
# PG requires keeping all randomness for the reference particle
87
98
# Create new task and copy randomness
88
99
function AdvancedPS. forkr (trace:: LibtaskTrace )
100
+ taped_globals = trace. model. ctask. taped_globals
101
+ rng = taped_globals. rng
89
102
newf = AdvancedPS. reset_model (trace. model. f)
90
- Random123. set_counter! (trace. rng, 1 )
103
+ Random123. set_counter! (rng, 1 )
104
+ trace. rng = rng
91
105
92
- ctask = Libtask. TapedTask (trace . rng, newf)
106
+ ctask = Libtask. TapedTask (TapedGlobals ( rng, taped_globals . other) , newf)
93
107
new_tapedmodel = AdvancedPS. LibtaskModel (newf, ctask)
94
108
95
109
# add backward reference
96
- newtrace = AdvancedPS. Trace (new_tapedmodel, trace . rng)
110
+ newtrace = AdvancedPS. Trace (new_tapedmodel, rng)
97
111
AdvancedPS. gen_refseed! (newtrace)
98
-
99
- Libtask. set_taped_globals! (ctask, trace. rng) # Sync trace and rng
100
112
return newtrace
101
113
end
102
114
0 commit comments