Skip to content

Commit a5099c9

Browse files
Respond to robot feedback
1 parent f2d652f commit a5099c9

File tree

3 files changed

+16
-26
lines changed

3 files changed

+16
-26
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
)
2929

3030

31-
def _make_inital_point(model, initvals=None, random_seed=None, jitter_rvs=None):
31+
def _make_initial_point(model, initvals=None, random_seed=None, jitter_rvs=None):
3232
jitter_rvs = [] if jitter_rvs is None else jitter_rvs
3333

3434
ipfn = make_initial_point_fn(
@@ -201,7 +201,7 @@ def find_MAP(
201201
frozen_model = freeze_dims_and_data(model)
202202
compile_kwargs = {} if compile_kwargs is None else compile_kwargs
203203

204-
initial_params = _make_inital_point(frozen_model, initvals, random_seed, jitter_rvs)
204+
initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
205205

206206
do_basinhopping = method == "basinhopping"
207207
minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def make_unpacked_variable_names(name, model: pm.Model) -> list[str]:
3535
return [f"{name}[{','.join(map(str, label))}]" for label in labels]
3636

3737

38-
def map_results_to_inferece_data(results: dict[str, Any], model: pm.Model | None = None):
38+
def map_results_to_inference_data(results: dict[str, Any], model: pm.Model | None = None):
3939
"""
4040
Convert a dictionary of results to an InferenceData object.
4141
@@ -51,7 +51,7 @@ def map_results_to_inferece_data(results: dict[str, Any], model: pm.Model | None
5151
idata: az.InferenceData
5252
An InferenceData object containing the results.
5353
"""
54-
model = pm.modelcontext(model)
54+
model = pm.modelcontext(model) if model is None else model
5555
coords, dims = coords_and_dims_for_inferencedata(model)
5656

5757
idata = az.convert_to_inference_data(results, coords=coords, dims=dims)
@@ -78,7 +78,7 @@ def laplace_draws_to_inferencedata(
7878
idata: az.InferenceData
7979
An InferenceData object containing the approximated posterior samples
8080
"""
81-
model = pm.modelcontext(model)
81+
model = pm.modelcontext(model) if model is None else model
8282
chains, draws, *_ = posterior_draws[0].shape
8383

8484
def make_rv_coords(name):
@@ -191,7 +191,7 @@ def add_fit_to_inference_data(
191191
idata: az.InferenceData
192192
The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
193193
"""
194-
model = pm.modelcontext(model)
194+
model = pm.modelcontext(model) if model is None else model
195195

196196
variable_names, *_ = zip(*mu.point_map_info)
197197

@@ -242,7 +242,7 @@ def add_data_to_inference_data(
242242
idata: az.InferenceData
243243
The provided InferenceData, with observed and constant data added.
244244
"""
245-
model = pm.modelcontext(model)
245+
model = pm.modelcontext(model) if model is None else model
246246

247247
if model.deterministics:
248248
expand_dims = {}
@@ -309,6 +309,7 @@ def optimizer_result_to_dataset(
309309
"""
310310
if not isinstance(result, OptimizeResult):
311311
raise TypeError("result must be an instance of OptimizeResult")
312+
312313
model = pm.modelcontext(model) if model is None else model
313314
variable_names, *_ = zip(*mu.point_map_info)
314315
unpacked_variable_names = reduce(

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from pymc_extras.inference.laplace_approx.find_map import (
3636
_compute_inverse_hessian,
37-
_make_inital_point,
37+
_make_initial_point,
3838
find_MAP,
3939
)
4040
from pymc_extras.inference.laplace_approx.scipy_interface import scipy_optimize_funcs_from_loss
@@ -228,7 +228,6 @@ def fit_laplace(
228228
use_hess: bool | None = None,
229229
initvals: dict | None = None,
230230
random_seed: int | np.random.Generator | None = None,
231-
return_raw: bool = False,
232231
jitter_rvs: list[pt.TensorVariable] | None = None,
233232
progressbar: bool = True,
234233
include_transformed: bool = True,
@@ -268,23 +267,13 @@ def fit_laplace(
268267
If None, the model's default initial values are used.
269268
random_seed : None | int | np.random.Generator, optional
270269
Seed for the random number generator or a numpy Generator for reproducibility
271-
return_raw: bool | False, optinal
272-
Whether to also return the full output of `scipy.optimize.minimize`
273270
jitter_rvs : list of TensorVariables, optional
274271
Variables whose initial values should be jittered. If None, all variables are jittered.
275272
progressbar : bool, optional
276273
Whether to display a progress bar during optimization. Defaults to True.
277-
fit_in_unconstrained_space: bool, default False
278-
Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn
279-
from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will
280-
then be transformed back to the original parameter space. This will guarantee that the samples will respect
281-
the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0
282-
and 1).
283-
284-
.. warning::
285-
This argument should be considered highly experimental. It has not been verified if this method produces
286-
valid draws from the posterior. **Use at your own risk**.
287-
274+
include_transformed: bool, default True
275+
Whether to include transformed variables in the output. If True, transformed variables will be included in the
276+
output InferenceData object. If False, only the original variables will be included.
288277
gradient_backend: str, default "pytensor"
289278
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
290279
chains: int, default: 2
@@ -365,7 +354,7 @@ def fit_laplace(
365354
# The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
366355
# we have to go back and compute the Hessian at the MAP point now.
367356
frozen_model = freeze_dims_and_data(model)
368-
initial_params = _make_inital_point(frozen_model, initvals, random_seed, jitter_rvs)
357+
initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
369358

370359
_, f_hessp = scipy_optimize_funcs_from_loss(
371360
loss=-frozen_model.logp(jacobian=False),
@@ -405,9 +394,9 @@ def fit_laplace(
405394
.rename({"temp_chain": "chain", "temp_draw": "draw"})
406395
)
407396

408-
new_posterior.update(unstack_laplace_draws(new_posterior, model)).drop_vars(
409-
"laplace_approximation"
410-
)
397+
new_posterior.update(unstack_laplace_draws(new_posterior, model))
398+
new_posterior = new_posterior.drop_vars("laplace_approximation")
399+
411400
idata.posterior.update(new_posterior)
412401

413402
return idata

0 commit comments

Comments (0)