Skip to content

Commit 85bf5fe

Browse files
Use n_adapts instead of nadapts (#375)
* Use `n_adapts` instead of `nadapts` * Update Project.toml * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent f1251f5 commit 85bf5fe

File tree

5 files changed

+62
-6
lines changed

5 files changed

+62
-6
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.6.1"
3+
version = "0.6.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ samples = AbstractMCMC.sample(
128128
model,
129129
sampler,
130130
n_adapts + n_samples;
131-
nadapts = n_adapts,
131+
n_adapts = n_adapts,
132132
initial_params = initial_θ,
133133
)
134134
```

src/abstractmcmc.jl

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,14 @@ function AbstractMCMC.sample(
4747
callback = nothing,
4848
kwargs...,
4949
)
50+
if haskey(kwargs, :nadapts)
51+
throw(
52+
ArgumentError(
53+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
54+
),
55+
)
56+
end
57+
5058
if callback === nothing
5159
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
5260
progress = false # don't use AMCMC's progress-funtionality
@@ -78,6 +86,13 @@ function AbstractMCMC.sample(
7886
callback = nothing,
7987
kwargs...,
8088
)
89+
if haskey(kwargs, :nadapts)
90+
throw(
91+
ArgumentError(
92+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
93+
),
94+
)
95+
end
8196

8297
if callback === nothing
8398
callback = HMCProgressCallback(N, progress = progress, verbose = verbose)
@@ -144,6 +159,14 @@ function AbstractMCMC.step(
144159
n_adapts::Int = 0,
145160
kwargs...,
146161
)
162+
if haskey(kwargs, :nadapts)
163+
throw(
164+
ArgumentError(
165+
"keyword argument `nadapts` is unsupported. Please use `n_adapts` to specify the number of adaptation steps.",
166+
),
167+
)
168+
end
169+
147170
# Compute transition.
148171
i = state.i + 1
149172
t_old = state.transition
@@ -200,7 +223,16 @@ function HMCProgressCallback(n_samples; progress = true, verbose = false)
200223
HMCProgressCallback(pm, progress, verbose, Ref(0), Ref(0))
201224
end
202225

203-
function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kwargs...)
226+
function (cb::HMCProgressCallback)(
227+
rng,
228+
model,
229+
spl,
230+
t,
231+
state,
232+
i;
233+
n_adapts::Int = 0,
234+
kwargs...,
235+
)
204236
progress = cb.progress
205237
verbose = cb.verbose
206238
pm = cb.pm
@@ -243,8 +275,8 @@ function (cb::HMCProgressCallback)(rng, model, spl, t, state, i; nadapts = 0, kw
243275
),
244276
)
245277
# Report finish of adapation
246-
elseif verbose && isadapted && i == nadapts
247-
@info "Finished $nadapts adapation steps" adaptor κ.τ.integrator metric
278+
elseif verbose && isadapted && i == n_adapts
279+
@info "Finished $(n_adapts) adapation steps" adaptor κ.τ.integrator metric
248280
end
249281
end
250282

test/abstractmcmc.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,30 @@ using Statistics: mean
3232
verbose = false,
3333
)
3434

35+
# Error if keyword argument `nadapts` is used
36+
@test_throws ArgumentError AbstractMCMC.sample(
37+
rng,
38+
model,
39+
nuts,
40+
n_adapts + n_samples;
41+
nadapts = n_adapts,
42+
initial_params = θ_init,
43+
progress = false,
44+
verbose = false,
45+
)
46+
@test_throws ArgumentError AbstractMCMC.sample(
47+
rng,
48+
model,
49+
nuts,
50+
MCMCThreads(),
51+
n_adapts + n_samples,
52+
2;
53+
nadapts = n_adapts,
54+
initial_params = θ_init,
55+
progress = false,
56+
verbose = false,
57+
)
58+
3559
# Transform back to original space.
3660
# NOTE: We're not correcting for the `logabsdetjac` here since, but
3761
# we're only interested in the mean it doesn't matter.

test/mcmcchains.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ using Statistics: mean
2323
model,
2424
sampler,
2525
n_adapts + n_samples;
26-
nadapts = n_adapts,
26+
n_adapts = n_adapts,
2727
initial_params = θ_init,
2828
chain_type = Chains,
2929
progress = false,

0 commit comments

Comments
 (0)