|
34 | 34 |
|
35 | 35 | from pymc_extras.inference.laplace_approx.find_map import (
|
36 | 36 | _compute_inverse_hessian,
|
37 |
| - _make_inital_point, |
| 37 | + _make_initial_point, |
38 | 38 | find_MAP,
|
39 | 39 | )
|
40 | 40 | from pymc_extras.inference.laplace_approx.scipy_interface import scipy_optimize_funcs_from_loss
|
@@ -228,7 +228,6 @@ def fit_laplace(
|
228 | 228 | use_hess: bool | None = None,
|
229 | 229 | initvals: dict | None = None,
|
230 | 230 | random_seed: int | np.random.Generator | None = None,
|
231 |
| - return_raw: bool = False, |
232 | 231 | jitter_rvs: list[pt.TensorVariable] | None = None,
|
233 | 232 | progressbar: bool = True,
|
234 | 233 | include_transformed: bool = True,
|
@@ -268,23 +267,13 @@ def fit_laplace(
|
268 | 267 | If None, the model's default initial values are used.
|
269 | 268 | random_seed : None | int | np.random.Generator, optional
|
270 | 269 | Seed for the random number generator or a numpy Generator for reproducibility
|
271 |
| - return_raw: bool | False, optinal |
272 |
| - Whether to also return the full output of `scipy.optimize.minimize` |
273 | 270 | jitter_rvs : list of TensorVariables, optional
|
274 | 271 | Variables whose initial values should be jittered. If None, all variables are jittered.
|
275 | 272 | progressbar : bool, optional
|
276 | 273 | Whether to display a progress bar during optimization. Defaults to True.
|
277 |
| - fit_in_unconstrained_space: bool, default False |
278 |
| - Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn |
279 |
| - from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will |
280 |
| - then be transformed back to the original parameter space. This will guarantee that the samples will respect |
281 |
| - the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0 |
282 |
| - and 1). |
283 |
| -
|
284 |
| - .. warning:: |
285 |
| - This argument should be considered highly experimental. It has not been verified if this method produces |
286 |
| - valid draws from the posterior. **Use at your own risk**. |
287 |
| -
|
| 274 | + include_transformed: bool, default True |
| 275 | + Whether to include transformed variables in the output. If True, transformed variables will be included in the |
| 276 | + output InferenceData object. If False, only the original variables will be included. |
288 | 277 | gradient_backend: str, default "pytensor"
|
289 | 278 | The backend to use for gradient computations. Must be one of "pytensor" or "jax".
|
290 | 279 | chains: int, default: 2
|
@@ -365,7 +354,7 @@ def fit_laplace(
|
365 | 354 | # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
|
366 | 355 | # we have to go back and compute the Hessian at the MAP point now.
|
367 | 356 | frozen_model = freeze_dims_and_data(model)
|
368 |
| - initial_params = _make_inital_point(frozen_model, initvals, random_seed, jitter_rvs) |
| 357 | + initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs) |
369 | 358 |
|
370 | 359 | _, f_hessp = scipy_optimize_funcs_from_loss(
|
371 | 360 | loss=-frozen_model.logp(jacobian=False),
|
@@ -405,9 +394,9 @@ def fit_laplace(
|
405 | 394 | .rename({"temp_chain": "chain", "temp_draw": "draw"})
|
406 | 395 | )
|
407 | 396 |
|
408 |
| - new_posterior.update(unstack_laplace_draws(new_posterior, model)).drop_vars( |
409 |
| - "laplace_approximation" |
410 |
| - ) |
| 397 | + new_posterior.update(unstack_laplace_draws(new_posterior, model)) |
| 398 | + new_posterior = new_posterior.drop_vars("laplace_approximation") |
| 399 | + |
411 | 400 | idata.posterior.update(new_posterior)
|
412 | 401 |
|
413 | 402 | return idata
|
0 commit comments