Skip to content

AdvancedPS v0.7 (and thus Libtask v0.9) support #2585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ AbstractPPL = "0.11.0"
Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
AdvancedMH = "0.8"
AdvancedPS = "0.6.0"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Expand All @@ -67,7 +67,7 @@ DynamicHMC = "3.4"
DynamicPPL = "0.36.3"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3, 1"
Libtask = "0.8.8"
Libtask = "0.9.2"
LinearAlgebra = "1"
LogDensityProblems = "2"
MCMCChains = "5, 6, 7"
Expand All @@ -85,7 +85,7 @@ Statistics = "1.6"
StatsAPI = "1.6"
StatsBase = "0.32, 0.33, 0.34"
StatsFuns = "0.8, 0.9, 1"
julia = "1.10.2"
julia = "1.10.8"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Libtask requires 1.10.8 at a minimum.


[extras]
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
Expand Down
55 changes: 31 additions & 24 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@
"Sampling with `$(sampler.alg)` does not support models with keyword arguments. See issue #2007 for more details.",
)
end
return TracedModel{AbstractSampler,AbstractVarInfo,Model,Tuple}(
model, sampler, varinfo, (model.f, args...)
)
evaluator = (model.f, args...)
return TracedModel(model, sampler, varinfo, evaluator)
end

function AdvancedPS.advance!(
Expand Down Expand Up @@ -59,20 +58,10 @@
return trace
end

function AdvancedPS.update_rng!(
trace::AdvancedPS.Trace{<:AdvancedPS.LibtaskModel{<:TracedModel}}
)
# Extract the `args`.
args = trace.model.ctask.args
# From `args`, extract the `SamplingContext`, which contains the RNG.
sampling_context = args[3]
rng = sampling_context.rng
trace.rng = rng
return trace
end

function Libtask.TapedTask(model::TracedModel, ::Random.AbstractRNG, args...; kwargs...) # RNG ?
return Libtask.TapedTask(model.evaluator[1], model.evaluator[2:end]...; kwargs...)
function Libtask.TapedTask(taped_globals::Any, model::TracedModel, args...; kwargs...) # RNG ?
return Libtask.TapedTask(
taped_globals, model.evaluator[1], model.evaluator[2:end]...; kwargs...
)
end

abstract type ParticleInference <: InferenceAlgorithm end
Expand Down Expand Up @@ -402,11 +391,11 @@

function trace_local_varinfo_maybe(varinfo)
try
trace = AdvancedPS.current_trace()
return trace.model.f.varinfo
trace = Libtask.get_taped_globals(Any).other
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we change Libtask.get_taped_globals to return nothing if not inside a running TapedTask, the following try .. catch ... end can be removed.

return (trace === nothing ? varinfo : trace.model.f.varinfo)::AbstractVarInfo
catch e
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
if e == KeyError(:__trace) || current_task().storage isa Nothing
if e == KeyError(:task_variable)

Check warning on line 398 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L398

Added line #L398 was not covered by tests
return varinfo
else
rethrow(e)
Expand All @@ -416,11 +405,10 @@

function trace_local_rng_maybe(rng::Random.AbstractRNG)
try
trace = AdvancedPS.current_trace()
return trace.rng
return Libtask.get_taped_globals(Any).rng
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same with above.

catch e
# NOTE: this heuristic allows Libtask evaluating a model outside a `Trace`.
if e == KeyError(:__trace) || current_task().storage isa Nothing
if e == KeyError(:task_variable)

Check warning on line 411 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L411

Added line #L411 was not covered by tests
return rng
else
rethrow(e)
Expand Down Expand Up @@ -481,6 +469,25 @@

tmodel = TracedModel(model, sampler, newvarinfo, rng)
newtrace = AdvancedPS.Trace(tmodel, rng)
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
end

# We need to tell Libtask which calls may have `produce` calls within them. In practice most
# of these won't be needed, because of inlining and the fact that `might_produce` is only
# called on `:invoke` expressions rather than `:call`s, but since those are implementation
# details of the compiler, we set a bunch of methods as might_produce = true. We start with
# `acclogp_observe!!` which is what calls `produce` and go up the call stack.
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.acclogp_observe!!),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.tilde_observe!!),Vararg}}) = true
Libtask.might_produce(::Type{<:Tuple{typeof(DynamicPPL.evaluate!!),Vararg}}) = true
function Libtask.might_produce(

Check warning on line 483 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L480-L483

Added lines #L480 - L483 were not covered by tests
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadsafe!!),Vararg}}
)
return true

Check warning on line 486 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L486

Added line #L486 was not covered by tests
end
function Libtask.might_produce(

Check warning on line 488 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L488

Added line #L488 was not covered by tests
::Type{<:Tuple{typeof(DynamicPPL.evaluate_threadunsafe!!),Vararg}}
)
return true

Check warning on line 491 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L491

Added line #L491 was not covered by tests
end
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

Check warning on line 493 in src/mcmc/particle_mcmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/particle_mcmc.jl#L493

Added line #L493 was not covered by tests
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ ADTypes = "1"
AbstractMCMC = "5"
AbstractPPL = "0.9, 0.10, 0.11"
AdvancedMH = "0.6, 0.7, 0.8"
AdvancedPS = "=0.6.0"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
Aqua = "0.8"
BangBang = "0.4"
Expand Down
1 change: 1 addition & 0 deletions test/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ using Turing

tested = sample(normal(), SMC(), 100)

# TODO(mhauru) This needs an explanation for why it fails.
# failing test
@model function fail_smc()
a ~ Normal(4, 5)
Expand Down
Loading