Skip to content

Commit 5ffaed6

Browse files
Respect include_transformed argument (#537)
1 parent 6531cea commit 5ffaed6

File tree

5 files changed

+40
-17
lines changed

5 files changed

+40
-17
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ def find_MAP(
326326
)
327327

328328
raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
329-
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed)
329+
unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
330330
unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")(
331331
DictToArrayBijection.rmap(raveled_optimized)
332332
)
@@ -335,7 +335,7 @@ def find_MAP(
335335
var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
336336
}
337337

338-
idata = map_results_to_inference_data(optimized_point, frozen_model)
338+
idata = map_results_to_inference_data(optimized_point, frozen_model, include_transformed)
339339
idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv)
340340
idata = add_optimizer_result_to_inference_data(
341341
idata, optimizer_result, method, raveled_optimized, model

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]
5959
def map_results_to_inference_data(
6060
map_point: dict[str, float | int | np.ndarray],
6161
model: pm.Model | None = None,
62+
include_transformed: bool = True,
6263
):
6364
"""
6465
Add the MAP point to an InferenceData object in the posterior group.
@@ -68,13 +69,13 @@ def map_results_to_inference_data(
6869
6970
Parameters
7071
----------
71-
idata: az.InferenceData
72-
An InferenceData object to which the MAP point will be added.
7372
map_point: dict
7473
A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
7574
the values should be the corresponding MAP estimates.
7675
model: Model, optional
7776
A PyMC model. If None, the model is taken from the current model context.
77+
include_transformed: bool
78+
Whether to return transformed (unconstrained) variables in the constrained_posterior group. Default is True.
7879
7980
Returns
8081
-------
@@ -118,7 +119,7 @@ def map_results_to_inference_data(
118119
dims=dims,
119120
)
120121

121-
if unconstrained_names:
122+
if unconstrained_names and include_transformed:
122123
unconstrained_posterior = az.from_dict(
123124
posterior={
124125
k: np.expand_dims(v, (0, 1))

pymc_extras/inference/laplace_approx/laplace.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ def fit_laplace(
302302
----------
303303
model : pm.Model
304304
The PyMC model to be fit. If None, the current model context is used.
305-
method : str
305+
optimize_method : str
306306
The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP,
307307
trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping.
308308
@@ -441,9 +441,11 @@ def fit_laplace(
441441
.rename({"temp_chain": "chain", "temp_draw": "draw"})
442442
)
443443

444-
idata.unconstrained_posterior = unstack_laplace_draws(
445-
new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
446-
)
444+
if include_transformed:
445+
idata.unconstrained_posterior = unstack_laplace_draws(
446+
new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
447+
)
448+
447449
idata.posterior = new_posterior.drop_vars(
448450
["laplace_approximation", "unpacked_variable_names"]
449451
)

tests/inference/laplace_approx/test_find_map.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,19 @@ def compute_z(x):
133133
],
134134
)
135135
@pytest.mark.parametrize(
136-
"backend, gradient_backend",
137-
[("jax", "jax"), ("jax", "pytensor")],
136+
"backend, gradient_backend, include_transformed",
137+
[("jax", "jax", True), ("jax", "pytensor", False)],
138138
ids=str,
139139
)
140140
def test_find_MAP(
141-
method, use_grad, use_hess, use_hessp, backend, gradient_backend: GradientBackend, rng
141+
method,
142+
use_grad,
143+
use_hess,
144+
use_hessp,
145+
backend,
146+
gradient_backend: GradientBackend,
147+
include_transformed,
148+
rng,
142149
):
143150
pytest.importorskip("jax")
144151

@@ -154,12 +161,12 @@ def test_find_MAP(
154161
use_hessp=use_hessp,
155162
progressbar=False,
156163
gradient_backend=gradient_backend,
164+
include_transformed=include_transformed,
157165
compile_kwargs={"mode": backend.upper()},
158166
maxiter=5,
159167
)
160168

161169
assert hasattr(idata, "posterior")
162-
assert hasattr(idata, "unconstrained_posterior")
163170
assert hasattr(idata, "fit")
164171
assert hasattr(idata, "optimizer_result")
165172
assert hasattr(idata, "observed_data")
@@ -169,9 +176,13 @@ def test_find_MAP(
169176
assert posterior["mu"].shape == ()
170177
assert posterior["sigma"].shape == ()
171178

172-
unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"])
173-
assert "sigma_log__" in unconstrained_posterior
174-
assert unconstrained_posterior["sigma_log__"].shape == ()
179+
if include_transformed:
180+
assert hasattr(idata, "unconstrained_posterior")
181+
unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"])
182+
assert "sigma_log__" in unconstrained_posterior
183+
assert unconstrained_posterior["sigma_log__"].shape == ()
184+
else:
185+
assert not hasattr(idata, "unconstrained_posterior")
175186

176187

177188
@pytest.mark.parametrize(

tests/inference/laplace_approx/test_laplace.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ def test_fit_laplace_basic(mode, gradient_backend: GradientBackend):
8383
np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3)
8484

8585

86-
def test_fit_laplace_coords(rng):
86+
@pytest.mark.parametrize(
87+
"include_transformed", [True, False], ids=["include_transformed", "no_transformed"]
88+
)
89+
def test_fit_laplace_coords(include_transformed, rng):
8790
coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)}
8891
with pm.Model(coords=coords) as model:
8992
mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"])
@@ -102,6 +105,7 @@ def test_fit_laplace_coords(rng):
102105
chains=1,
103106
draws=1000,
104107
optimizer_kwargs=dict(tol=1e-20),
108+
include_transformed=include_transformed,
105109
)
106110

107111
np.testing.assert_allclose(
@@ -120,6 +124,11 @@ def test_fit_laplace_coords(rng):
120124
"sigma_log__[C]",
121125
]
122126

127+
assert hasattr(idata, "unconstrained_posterior") == include_transformed
128+
if include_transformed:
129+
assert "sigma_log__" in idata.unconstrained_posterior
130+
assert "city" in idata.unconstrained_posterior.coords
131+
123132

124133
def test_fit_laplace_ragged_coords(rng):
125134
coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)}

0 commit comments

Comments
 (0)