From 5f9e579193c01ddb65ffd8ae582e379b011e3688 Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Fri, 11 Jul 2025 00:06:34 +0800 Subject: [PATCH] Respect include_transformed argument --- .../inference/laplace_approx/find_map.py | 4 +-- pymc_extras/inference/laplace_approx/idata.py | 7 +++--- .../inference/laplace_approx/laplace.py | 10 +++++--- .../inference/laplace_approx/test_find_map.py | 25 +++++++++++++------ .../inference/laplace_approx/test_laplace.py | 11 +++++++- 5 files changed, 40 insertions(+), 17 deletions(-) diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py index 930137e2..79a1dea5 100644 --- a/pymc_extras/inference/laplace_approx/find_map.py +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -326,7 +326,7 @@ def find_MAP( ) raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) - unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) + unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True) unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")( DictToArrayBijection.rmap(raveled_optimized) ) @@ -335,7 +335,7 @@ def find_MAP( var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) } - idata = map_results_to_inference_data(optimized_point, frozen_model) + idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed) idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv) idata = add_optimizer_result_to_inference_data( idata, optimizer_result, method, raveled_optimized, model diff --git a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py index edf011dd..0d81d64b 100644 --- a/pymc_extras/inference/laplace_approx/idata.py +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -59,6 +59,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) 
-> list[str] def map_results_to_inference_data( map_point: dict[str, float | int | np.ndarray], model: pm.Model | None = None, + include_transformed: bool = True, ): """ Add the MAP point to an InferenceData object in the posterior group. @@ -68,13 +69,13 @@ def map_results_to_inference_data( Parameters ---------- - idata: az.InferenceData - An InferenceData object to which the MAP point will be added. map_point: dict A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and the values should be the corresponding MAP estimates. model: Model, optional A PyMC model. If None, the model is taken from the current model context. + include_transformed: bool + Whether to return transformed (unconstrained) variables in the unconstrained_posterior group. Default is True. Returns ------- @@ -118,7 +119,7 @@ def map_results_to_inference_data( dims=dims, ) - if unconstrained_names: + if unconstrained_names and include_transformed: unconstrained_posterior = az.from_dict( posterior={ k: np.expand_dims(v, (0, 1)) diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py index 2b5ef6a1..85e4cdda 100644 --- a/pymc_extras/inference/laplace_approx/laplace.py +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -302,7 +302,7 @@ def fit_laplace( ---------- model : pm.Model The PyMC model to be fit. If None, the current model context is used. - method : str + optimize_method : str The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP, trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
@@ -441,9 +441,11 @@ def fit_laplace( .rename({"temp_chain": "chain", "temp_draw": "draw"}) ) - idata.unconstrained_posterior = unstack_laplace_draws( - new_posterior.laplace_approximation.values, model, chains=chains, draws=draws - ) + if include_transformed: + idata.unconstrained_posterior = unstack_laplace_draws( + new_posterior.laplace_approximation.values, model, chains=chains, draws=draws + ) + idata.posterior = new_posterior.drop_vars( ["laplace_approximation", "unpacked_variable_names"] ) diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py index bf0cb292..309876ca 100644 --- a/tests/inference/laplace_approx/test_find_map.py +++ b/tests/inference/laplace_approx/test_find_map.py @@ -133,12 +133,19 @@ def compute_z(x): ], ) @pytest.mark.parametrize( - "backend, gradient_backend", - [("jax", "jax"), ("jax", "pytensor")], + "backend, gradient_backend, include_transformed", + [("jax", "jax", True), ("jax", "pytensor", False)], ids=str, ) def test_find_MAP( - method, use_grad, use_hess, use_hessp, backend, gradient_backend: GradientBackend, rng + method, + use_grad, + use_hess, + use_hessp, + backend, + gradient_backend: GradientBackend, + include_transformed, + rng, ): pytest.importorskip("jax") @@ -154,12 +161,12 @@ def test_find_MAP( use_hessp=use_hessp, progressbar=False, gradient_backend=gradient_backend, + include_transformed=include_transformed, compile_kwargs={"mode": backend.upper()}, maxiter=5, ) assert hasattr(idata, "posterior") - assert hasattr(idata, "unconstrained_posterior") assert hasattr(idata, "fit") assert hasattr(idata, "optimizer_result") assert hasattr(idata, "observed_data") @@ -169,9 +176,13 @@ def test_find_MAP( assert posterior["mu"].shape == () assert posterior["sigma"].shape == () - unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) - assert "sigma_log__" in unconstrained_posterior - assert unconstrained_posterior["sigma_log__"].shape == () + if 
include_transformed: + assert hasattr(idata, "unconstrained_posterior") + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () + else: + assert not hasattr(idata, "unconstrained_posterior") @pytest.mark.parametrize( diff --git a/tests/inference/laplace_approx/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py index be5665d0..ab0ed34b 100644 --- a/tests/inference/laplace_approx/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -83,7 +83,10 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend): np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3) -def test_fit_laplace_coords(rng): +@pytest.mark.parametrize( + "include_transformed", [True, False], ids=["include_transformed", "no_transformed"] +) +def test_fit_laplace_coords(include_transformed, rng): coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as model: mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) @@ -102,6 +105,7 @@ def test_fit_laplace_coords(rng): chains=1, draws=1000, optimizer_kwargs=dict(tol=1e-20), + include_transformed=include_transformed, ) np.testing.assert_allclose( @@ -120,6 +124,11 @@ def test_fit_laplace_coords(rng): "sigma_log__[C]", ] + assert hasattr(idata, "unconstrained_posterior") == include_transformed + if include_transformed: + assert "sigma_log__" in idata.unconstrained_posterior + assert "city" in idata.unconstrained_posterior.coords + def test_fit_laplace_ragged_coords(rng): coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}