Skip to content

Commit c23bc33

Browse files
Refactor find_MAP
1 parent 93e3aa2 commit c23bc33

File tree

8 files changed

+844
-422
lines changed

8 files changed

+844
-422
lines changed

pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 125 additions & 360 deletions
Large diffs are not rendered by default.

pymc_extras/inference/laplace_approx/idata.py

Lines changed: 140 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from functools import reduce
22
from itertools import product
3-
from typing import Literal
3+
from typing import Any, Literal
44

55
import arviz as az
66
import numpy as np
@@ -35,6 +35,29 @@ def make_unpacked_variable_names(name, model: pm.Model) -> list[str]:
3535
return [f"{name}[{','.join(map(str, label))}]" for label in labels]
3636

3737

38+
def map_results_to_inferece_data(results: dict[str, Any], model: pm.Model | None = None):
    """
    Package a dictionary of results into an ArviZ InferenceData object.

    NOTE(review): the function name contains a typo ("inferece"); it is kept
    as-is for backward compatibility with existing callers.

    Parameters
    ----------
    results: dict
        Mapping of names to result values to convert.
    model: Model, optional
        A PyMC model. If None, the model on the current context stack is used.

    Returns
    -------
    idata: az.InferenceData
        An InferenceData object containing the results.
    """
    resolved_model = pm.modelcontext(model)
    coords, dims = coords_and_dims_for_inferencedata(resolved_model)
    return az.convert_to_inference_data(results, coords=coords, dims=dims)
60+
3861
def laplace_draws_to_inferencedata(
3962
posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
4063
) -> az.InferenceData:
@@ -91,13 +114,67 @@ def make_rv_dims(name):
91114
return idata
92115

93116

94-
def add_fit_to_inferencedata(
117+
def add_map_posterior_to_inference_data(
    idata: az.InferenceData,
    map_point: dict[str, float | int | np.ndarray],
    model: pm.Model | None = None,
):
    """
    Add the MAP point to an InferenceData object in the posterior group.

    Unlike a typical posterior, the MAP point is a single point estimate rather than a distribution. As a result, it
    does not have a chain or draw dimension, and is stored as a single point in the posterior group.

    Parameters
    ----------
    idata: az.InferenceData
        An InferenceData object to which the MAP point will be added.
    map_point: dict
        A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
        the values should be the corresponding MAP estimates.
    model: Model, optional
        A PyMC model. If None, the model is taken from the current model context.

    Returns
    -------
    idata: az.InferenceData
        The provided InferenceData, with the MAP point added to the posterior group.
    """
    # pm.modelcontext already falls back to the context-stack model when given
    # None, so the previous `pm.modelcontext(model) if model is None else model`
    # guard was redundant; calling it unconditionally also matches the sibling
    # helpers in this module.
    model = pm.modelcontext(model)
    coords, dims = coords_and_dims_for_inferencedata(model)

    # The MAP point will have both the transformed and untransformed variables, so we need to ensure that
    # we have the correct dimensions for each variable. Give each value-variable
    # the same dims as its corresponding random variable, when known.
    var_name_to_value_name = {rv.name: value.name for rv, value in model.rvs_to_values.items()}
    dims.update(
        {
            value_name: dims[var_name]
            for var_name, value_name in var_name_to_value_name.items()
            if var_name in dims
        }
    )

    posterior_data = {
        name: xr.DataArray(
            data=np.asarray(value),
            # Variables without registered dims get no coords; dims=None lets
            # xarray auto-name any remaining dimensions.
            # NOTE(review): assumes every dim registered for a variable has an
            # entry in coords — confirm against coords_and_dims_for_inferencedata.
            coords={dim: coords[dim] for dim in dims.get(name, [])},
            dims=dims.get(name),
            name=name,
        )
        for name, value in map_point.items()
    }
    idata.add_groups(posterior=xr.Dataset(posterior_data))

    return idata
171+
172+
def add_fit_to_inference_data(
95173
idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
96174
) -> az.InferenceData:
97175
"""
98176
Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
99177
100-
101178
Parameters
102179
----------
103180
idata: az.InfereceData
@@ -123,19 +200,24 @@ def add_fit_to_inferencedata(
123200
)
124201

125202
mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
126-
cov_dataarray = xr.DataArray(
127-
H_inv,
128-
dims=["rows", "columns"],
129-
coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
130-
)
131203

132-
dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
204+
data = {"mean_vector": mean_dataarray}
205+
206+
if H_inv is not None:
207+
cov_dataarray = xr.DataArray(
208+
H_inv,
209+
dims=["rows", "columns"],
210+
coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
211+
)
212+
data["covariance_matrix"] = cov_dataarray
213+
214+
dataset = xr.Dataset(data)
133215
idata.add_groups(fit=dataset)
134216

135217
return idata
136218

137219

138-
def add_data_to_inferencedata(
220+
def add_data_to_inference_data(
139221
idata: az.InferenceData,
140222
progressbar: bool = True,
141223
model: pm.Model | None = None,
@@ -163,8 +245,14 @@ def add_data_to_inferencedata(
163245
model = pm.modelcontext(model)
164246

165247
if model.deterministics:
248+
expand_dims = {}
249+
if "chain" not in idata.posterior.coords:
250+
expand_dims["chain"] = [0]
251+
if "draw" not in idata.posterior.coords:
252+
expand_dims["draw"] = [0]
253+
166254
idata.posterior = pm.compute_deterministics(
167-
idata.posterior,
255+
idata.posterior.expand_dims(expand_dims),
168256
model=model,
169257
merge_dataset=True,
170258
progressbar=progressbar,
@@ -229,6 +317,13 @@ def optimizer_result_to_dataset(
229317

230318
data_vars = {}
231319

320+
if hasattr(result, "lowest_optimization_result"):
321+
# If we did basinhopping, there's a results inside the results. We want to pop this out and collapse them,
322+
# overwriting outer keys with the inner keys
323+
inner_res = result.pop("lowest_optimization_result")
324+
for key in inner_res.keys():
325+
result[key] = inner_res[key]
326+
232327
if hasattr(result, "x"):
233328
data_vars["x"] = xr.DataArray(
234329
result.x, dims=["variables"], coords={"variables": unpacked_variable_names}
@@ -299,3 +394,37 @@ def optimizer_result_to_dataset(
299394
data_vars["method"] = xr.DataArray(np.array(method), dims=[])
300395

301396
return xr.Dataset(data_vars)
397+
398+
399+
def add_optimizer_result_to_inference_data(
    idata: az.InferenceData,
    result: OptimizeResult,
    method: minimize_method | Literal["basinhopping"],
    mu: RaveledVars | None = None,
    model: pm.Model | None = None,
) -> az.InferenceData:
    """
    Add the optimization result to an InferenceData object.

    Parameters
    ----------
    idata: az.InferenceData
        An InferenceData object containing the approximated posterior samples.
    result: OptimizeResult
        The result of the optimization process.
    method: minimize_method or "basinhopping"
        The optimization method used.
    mu: RaveledVars, optional
        The MAP estimate of the model parameters.
    model: Model, optional
        A PyMC model. If None, the model is taken from the current model context.

    Returns
    -------
    idata: az.InferenceData
        The provided InferenceData, with the optimization results added to the "optimizer_result" group.
    """
    # Convert the scipy OptimizeResult into an xarray Dataset, then attach it
    # as its own "optimizer_result" group so it travels with the rest of the
    # inference results.
    dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model)
    idata.add_groups({"optimizer_result": dataset})

    return idata

0 commit comments

Comments
 (0)