-
Notifications
You must be signed in to change notification settings - Fork 227
AdvancedPS v0.7 (and thus Libtask v0.9) support #2585
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
The tests that I had the patience to run locally now pass. Waiting for the AdvancedPS release to be able to run the full test suite on CI. Some indicators of speed: julia> module MWE
using Turing
@model function gdemo(x, y)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
x ~ Normal(m, sqrt(s))
y ~ Normal(m, sqrt(s))
return s, m
end
@time chn = sample(gdemo(2.5, 1.0), PG(10), 10_000)
describe(chn)
end On main:
On this branch:
julia> module MWE
using Turing
@model function f(dim=20, ::Type{T}=Float64) where T
s = Vector{Bool}(undef, dim)
x = Vector{T}(undef, dim)
for i in 1:dim
s[i] ~ Bernoulli()
if s[i]
x[i] ~ Normal()
else
x[i] ~ Beta()
end
0.0 ~ Normal(x[i])
end
return nothing
end
alg = Gibbs(
@varname(s)=>PG(10),
@varname(x)=>HMC(0.1, 5),
)
@time chn = sample(f(), alg, 1_000)
end On main:
On this branch:
Obviously the speed gains are all due to @willtebbutt's fantastic work on Libtask, everything else is just wrapping that work. |
Turing.jl documentation for PR #2585 is available at: |
@@ -85,7 +85,7 @@ Statistics = "1.6" | |||
StatsAPI = "1.6" | |||
StatsBase = "0.32, 0.33, 0.34" | |||
StatsFuns = "0.8, 0.9, 1" | |||
julia = "1.10.2" | |||
julia = "1.10.8" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Libtask requires 1.10.8 at a minimum.
@@ -402,11 +391,11 @@ end | |||
|
|||
function trace_local_varinfo_maybe(varinfo) | |||
try | |||
trace = AdvancedPS.current_trace() | |||
return trace.model.f.varinfo | |||
trace = Libtask.get_taped_globals(Any).other |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we change Libtask.get_taped_globals
to return nothing
if not inside a running TapedTask
, the following try .. catch ... end
can be removed.
@@ -416,11 +405,10 @@ end | |||
|
|||
function trace_local_rng_maybe(rng::Random.AbstractRNG) | |||
try | |||
trace = AdvancedPS.current_trace() | |||
return trace.rng | |||
return Libtask.get_taped_globals(Any).rng |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same with above.
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2585 +/- ##
===========================================
- Coverage 85.57% 50.44% -35.13%
===========================================
Files 22 22
Lines 1456 1447 -9
===========================================
- Hits 1246 730 -516
- Misses 210 717 +507 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Pull Request Test Coverage Report for Build 15835391573Details
💛 - Coveralls |
Is this reviewable? The tests are failing, there's a method ambiguity that Aqua complains about, there's a Gibbs failure on 1.12 which should be disabled with
I don't want to speak for @mhauru in his absence but last time we spoke about this PR, it was clear that there were still a few gaps to bridge. If I were to review it at this stage, my sole comment would be to fix the tests. |
The complement PR of TuringLang/AdvancedPS.jl#114, which adds support for the newly rewritten Libtask.
Work in progress, currently blocked by TuringLang/Libtask.jl#186