Skip to content

Commit 93e3aa2

Browse files
Split idata utilities into idata.py
1 parent 23140a5 commit 93e3aa2

File tree

3 files changed

+533
-2
lines changed

3 files changed

+533
-2
lines changed
Lines changed: 301 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,301 @@
1+
from functools import reduce
2+
from itertools import product
3+
from typing import Literal
4+
5+
import arviz as az
6+
import numpy as np
7+
import pymc as pm
8+
import xarray as xr
9+
10+
from arviz import dict_to_dataset
11+
from better_optimize.constants import minimize_method
12+
from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations
13+
from pymc.blocking import RaveledVars
14+
from pymc.util import get_default_varnames
15+
from scipy.optimize import OptimizeResult
16+
from scipy.sparse.linalg import LinearOperator
17+
18+
19+
def make_unpacked_variable_names(name, model: pm.Model) -> list[str]:
20+
coords = model.coords
21+
22+
value_to_dim = {
23+
x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None)
24+
for x in model.value_vars
25+
}
26+
value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None}
27+
28+
rv_to_dim = model.named_vars_to_dims
29+
dims_dict = rv_to_dim | value_to_dim
30+
31+
dims = dims_dict.get(name)
32+
if dims is None:
33+
return [name]
34+
labels = product(*(coords[dim] for dim in dims))
35+
return [f"{name}[{','.join(map(str, label))}]" for label in labels]
36+
37+
38+
def laplace_draws_to_inferencedata(
39+
posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None
40+
) -> az.InferenceData:
41+
"""
42+
Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object.
43+
44+
45+
Parameters
46+
----------
47+
posterior_draws: list of np.ndarray
48+
A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where
49+
shape is the shape of the variable in the posterior.
50+
model: Model, optional
51+
A PyMC model. If None, the model is taken from the current model context.
52+
53+
Returns
54+
-------
55+
idata: az.InferenceData
56+
An InferenceData object containing the approximated posterior samples
57+
"""
58+
model = pm.modelcontext(model)
59+
chains, draws, *_ = posterior_draws[0].shape
60+
61+
def make_rv_coords(name):
62+
coords = {"chain": range(chains), "draw": range(draws)}
63+
extra_dims = model.named_vars_to_dims.get(name)
64+
if extra_dims is None:
65+
return coords
66+
return coords | {dim: list(model.coords[dim]) for dim in extra_dims}
67+
68+
def make_rv_dims(name):
69+
dims = ["chain", "draw"]
70+
extra_dims = model.named_vars_to_dims.get(name)
71+
if extra_dims is None:
72+
return dims
73+
return dims + list(extra_dims)
74+
75+
names = [
76+
x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
77+
]
78+
idata = {
79+
name: xr.DataArray(
80+
data=draws,
81+
coords=make_rv_coords(name),
82+
dims=make_rv_dims(name),
83+
name=name,
84+
)
85+
for name, draws in zip(names, posterior_draws)
86+
}
87+
88+
coords, dims = coords_and_dims_for_inferencedata(model)
89+
idata = az.convert_to_inference_data(idata, coords=coords, dims=dims)
90+
91+
return idata
92+
93+
94+
def add_fit_to_inferencedata(
95+
idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
96+
) -> az.InferenceData:
97+
"""
98+
Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
99+
100+
101+
Parameters
102+
----------
103+
idata: az.InfereceData
104+
An InferenceData object containing the approximated posterior samples.
105+
mu: RaveledVars
106+
The MAP estimate of the model parameters.
107+
H_inv: np.ndarray
108+
The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
109+
model: Model, optional
110+
A PyMC model. If None, the model is taken from the current model context.
111+
112+
Returns
113+
-------
114+
idata: az.InferenceData
115+
The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
116+
"""
117+
model = pm.modelcontext(model)
118+
119+
variable_names, *_ = zip(*mu.point_map_info)
120+
121+
unpacked_variable_names = reduce(
122+
lambda lst, name: lst + make_unpacked_variable_names(name, model), variable_names, []
123+
)
124+
125+
mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names})
126+
cov_dataarray = xr.DataArray(
127+
H_inv,
128+
dims=["rows", "columns"],
129+
coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
130+
)
131+
132+
dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
133+
idata.add_groups(fit=dataset)
134+
135+
return idata
136+
137+
138+
def add_data_to_inferencedata(
139+
idata: az.InferenceData,
140+
progressbar: bool = True,
141+
model: pm.Model | None = None,
142+
compile_kwargs: dict | None = None,
143+
) -> az.InferenceData:
144+
"""
145+
Add observed and constant data to an InferenceData object.
146+
147+
Parameters
148+
----------
149+
idata: az.InferenceData
150+
An InferenceData object containing the approximated posterior samples.
151+
progressbar: bool
152+
Whether to display a progress bar during computations. Default is True.
153+
model: Model, optional
154+
A PyMC model. If None, the model is taken from the current model context.
155+
compile_kwargs: dict, optional
156+
Additional keyword arguments to pass to pytensor.function.
157+
158+
Returns
159+
-------
160+
idata: az.InferenceData
161+
The provided InferenceData, with observed and constant data added.
162+
"""
163+
model = pm.modelcontext(model)
164+
165+
if model.deterministics:
166+
idata.posterior = pm.compute_deterministics(
167+
idata.posterior,
168+
model=model,
169+
merge_dataset=True,
170+
progressbar=progressbar,
171+
compile_kwargs=compile_kwargs,
172+
)
173+
174+
coords, dims = coords_and_dims_for_inferencedata(model)
175+
176+
observed_data = dict_to_dataset(
177+
find_observations(model),
178+
library=pm,
179+
coords=coords,
180+
dims=dims,
181+
default_dims=[],
182+
)
183+
184+
constant_data = dict_to_dataset(
185+
find_constants(model),
186+
library=pm,
187+
coords=coords,
188+
dims=dims,
189+
default_dims=[],
190+
)
191+
192+
idata.add_groups(
193+
{"observed_data": observed_data, "constant_data": constant_data},
194+
coords=coords,
195+
dims=dims,
196+
)
197+
198+
return idata
199+
200+
201+
def optimizer_result_to_dataset(
202+
result: OptimizeResult,
203+
method: minimize_method | Literal["basinhopping"],
204+
mu: RaveledVars | None = None,
205+
model: pm.Model | None = None,
206+
) -> xr.Dataset:
207+
"""
208+
Convert an OptimizeResult object to an xarray Dataset object.
209+
210+
Parameters
211+
----------
212+
result: OptimizeResult
213+
The result of the optimization process.
214+
method: minimize_method or "basinhopping"
215+
The optimization method used.
216+
217+
Returns
218+
-------
219+
dataset: xr.Dataset
220+
An xarray Dataset containing the optimization results.
221+
"""
222+
if not isinstance(result, OptimizeResult):
223+
raise TypeError("result must be an instance of OptimizeResult")
224+
model = pm.modelcontext(model) if model is None else model
225+
variable_names, *_ = zip(*mu.point_map_info)
226+
unpacked_variable_names = reduce(
227+
lambda lst, name: lst + make_unpacked_variable_names(name, model), variable_names, []
228+
)
229+
230+
data_vars = {}
231+
232+
if hasattr(result, "x"):
233+
data_vars["x"] = xr.DataArray(
234+
result.x, dims=["variables"], coords={"variables": unpacked_variable_names}
235+
)
236+
if hasattr(result, "fun"):
237+
data_vars["fun"] = xr.DataArray(result.fun, dims=[])
238+
if hasattr(result, "success"):
239+
data_vars["success"] = xr.DataArray(result.success, dims=[])
240+
if hasattr(result, "message"):
241+
data_vars["message"] = xr.DataArray(str(result.message), dims=[])
242+
if hasattr(result, "jac") and result.jac is not None:
243+
jac = np.asarray(result.jac)
244+
if jac.ndim == 1:
245+
data_vars["jac"] = xr.DataArray(
246+
jac, dims=["variables"], coords={"variables": unpacked_variable_names}
247+
)
248+
else:
249+
data_vars["jac"] = xr.DataArray(
250+
jac,
251+
dims=["variables", "variables_aux"],
252+
coords={
253+
"variables": unpacked_variable_names,
254+
"variables_aux": unpacked_variable_names,
255+
},
256+
)
257+
258+
if hasattr(result, "hess_inv") and result.hess_inv is not None:
259+
hess_inv = result.hess_inv
260+
if isinstance(hess_inv, LinearOperator):
261+
n = hess_inv.shape[0]
262+
eye = np.eye(n)
263+
hess_inv_mat = np.column_stack([hess_inv.matvec(eye[:, i]) for i in range(n)])
264+
hess_inv = hess_inv_mat
265+
else:
266+
hess_inv = np.asarray(hess_inv)
267+
data_vars["hess_inv"] = xr.DataArray(
268+
hess_inv,
269+
dims=["variables", "variables_aux"],
270+
coords={"variables": unpacked_variable_names, "variables_aux": unpacked_variable_names},
271+
)
272+
273+
if hasattr(result, "nit"):
274+
data_vars["nit"] = xr.DataArray(result.nit, dims=[])
275+
if hasattr(result, "nfev"):
276+
data_vars["nfev"] = xr.DataArray(result.nfev, dims=[])
277+
if hasattr(result, "njev"):
278+
data_vars["njev"] = xr.DataArray(result.njev, dims=[])
279+
if hasattr(result, "status"):
280+
data_vars["status"] = xr.DataArray(result.status, dims=[])
281+
282+
# Add any other fields present in result
283+
for key, value in result.items():
284+
if key in data_vars:
285+
continue # already added
286+
if value is None:
287+
continue
288+
arr = np.asarray(value)
289+
290+
# TODO: We can probably do something smarter here with a dictionary of all possible values and their expected
291+
# dimensions.
292+
dims = [f"{key}_dim_{i}" for i in range(arr.ndim)]
293+
data_vars[key] = xr.DataArray(
294+
arr,
295+
dims=dims,
296+
coords={f"{key}_dim_{i}": np.arange(arr.shape[i]) for i in range(len(dims))},
297+
)
298+
299+
data_vars["method"] = xr.DataArray(np.array(method), dims=[])
300+
301+
return xr.Dataset(data_vars)

pymc_extras/inference/pathfinder/pathfinder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
# TODO: change to typing.Self after Python versions greater than 3.10
6464
from typing_extensions import Self
6565

66-
from pymc_extras.inference.laplace_approx.idata import add_data_to_inferencedata
66+
from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data
6767
from pymc_extras.inference.pathfinder.importance_sampling import (
6868
importance_sampling as _importance_sampling,
6969
)
@@ -1759,6 +1759,6 @@ def fit_pathfinder(
17591759
importance_sampling=importance_sampling,
17601760
)
17611761

1762-
idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs)
1762+
idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs)
17631763

17641764
return idata

0 commit comments

Comments
 (0)