
Commit 23140a5

Move laplace and find_map to submodule
1 parent 862e52d commit 23140a5

File tree

10 files changed: +136, −268 lines

pymc_extras/inference/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.fit import fit
-from pymc_extras.inference.laplace import fit_laplace
+from pymc_extras.inference.laplace_approx.find_map import find_MAP
+from pymc_extras.inference.laplace_approx.laplace import fit_laplace
 from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
 
 __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"]
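For reference, the re-export above keeps the public import path stable while the implementation moves into the laplace_approx submodule. A quick sketch of both the public and the new internal import paths (the names come directly from this diff):

# Public API: unchanged by this commit.
from pymc_extras.inference import find_MAP, fit_laplace

# New internal layout introduced by this commit.
from pymc_extras.inference.laplace_approx.find_map import find_MAP
from pymc_extras.inference.laplace_approx.laplace import fit_laplace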

pymc_extras/inference/fit.py

Lines changed: 1 addition & 1 deletion

@@ -37,6 +37,6 @@ def fit(method: str, **kwargs) -> az.InferenceData:
         return fit_pathfinder(**kwargs)
 
     if method == "laplace":
-        from pymc_extras.inference.laplace import fit_laplace
+        from pymc_extras.inference import fit_laplace
 
         return fit_laplace(**kwargs)
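As a usage sketch (not part of this commit), the generic fit entry point now resolves fit_laplace through the package root. The toy model below assumes fit_laplace picks up the model from the active context, as the other pymc_extras fitting routines do:

import pymc as pm
from pymc_extras.inference import fit

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.2, 0.7, -0.3])

    # Dispatches to fit_laplace(**kwargs) via the branch shown above.
    idata = fit(method="laplace")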

pymc_extras/inference/laplace_approx/__init__.py

Whitespace-only changes.

pymc_extras/inference/find_map.py renamed to pymc_extras/inference/laplace_approx/find_map.py

Lines changed: 85 additions & 53 deletions
@@ -114,14 +114,14 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
 
 
 def _compile_grad_and_hess_to_jax(
-    f_loss: Function, use_hess: bool, use_hessp: bool
+    f_fused: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
     Compile loss function gradients using JAX.
 
     Parameters
     ----------
-    f_loss: Function
+    f_fused: Function
         The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
         compiled with mode="JAX".
     use_hess: bool
@@ -131,43 +131,40 @@ def _compile_grad_and_hess_to_jax(
 
     Returns
     -------
-    f_loss_and_grad: Callable
-        The compiled loss function and gradient function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
+    f_fused: Callable
+        The compiled loss function and gradient function, which may also compute the hessian if requested.
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
     import jax
 
-    f_hess = None
     f_hessp = None
 
-    orig_loss_fn = f_loss.vm.jit_fn
+    orig_loss_fn = f_fused.vm.jit_fn
 
-    @jax.jit
-    def loss_fn_jax_grad(x):
-        return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+    if use_hess:
+
+        @jax.jit
+        def loss_fn_fused(x):
+            loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+            hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x)
+            return *loss_and_grad, hess
+
+    else:
 
-    f_loss_and_grad = loss_fn_jax_grad
+        @jax.jit
+        def loss_fn_fused(x):
+            return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     if use_hessp:
 
         def f_hessp_jax(x, p):
-            y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
+            y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,))
             return jax.numpy.stack(u)
 
         f_hessp = jax.jit(f_hessp_jax)
 
-    if use_hess:
-        _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])
-
-        def f_hess_jax(x):
-            return jax.numpy.stack(_f_hess_jax(x))
-
-        f_hess = jax.jit(f_hess_jax)
-
-    return f_loss_and_grad, f_hess, f_hessp
+    return loss_fn_fused, f_hessp
 
 
 def _compile_functions_for_scipy_optimize(
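The refactor above folds the loss, gradient, and (optionally) Hessian into a single jitted callable instead of compiling a separate Hessian function. Below is a standalone sketch of that pattern with a toy quadratic standing in for the compiled pytensor JAX function (orig_loss_fn); it is not code from this commit:

import jax
import jax.numpy as jnp

def orig_loss_fn(x):
    # Mimics the compiled pytensor JAX function: a tuple whose first element is the scalar loss.
    return (jnp.sum(x**2) + jnp.sum(jnp.sin(x)),)

@jax.jit
def loss_fn_fused(x):
    # Loss and gradient in one pass, plus the Hessian (here computed unconditionally).
    loss, grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
    hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x)
    return loss, grad, hess

# Hessian-vector product built from the fused function's gradient output, as in f_hessp_jax.
f_hessp = jax.jit(lambda x, p: jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,))[1])

x0 = jnp.ones(3)
loss, grad, hess = loss_fn_fused(x0)
print(loss.shape, grad.shape, hess.shape)  # (), (3,), (3, 3)
print(f_hessp(x0, jnp.ones(3)).shape)      # (3,)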
@@ -199,33 +196,47 @@ def _compile_functions_for_scipy_optimize(
 
     Returns
     -------
-    f_loss: Function
-
-    f_hess: Function | None
+    f_fused: Function
+        The compiled loss function, which may also include gradients and hessian if requested.
     f_hessp: Function | None
+        The compiled hessian-vector product function, or None if compute_hessp is False.
     """
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
     loss = pm.pytensorf.rewrite_pregrad(loss)
-    f_hess = None
     f_hessp = None
 
-    if compute_grad:
-        grads = pytensor.gradient.grad(loss, inputs)
-        grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
-    else:
+    # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent
+    # with the case where we also compute gradients, hessians, or hessian-vector products.
+    if not (compute_grad or compute_hess or compute_hessp):
         f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
-    if compute_hess:
-        hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile(inputs, hess, **compile_kwargs)
+    # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single
+    # fused function and return it. If the user also wants the hessian, the fused function will return the loss,
+    # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss
+    # and gradients, and a separate function for the hessian-vector product.
 
     if compute_hessp:
+        # Handle this first, since it can be compiled alone.
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
         f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
-    return [f_loss_and_grad, f_hess, f_hessp]
+    outputs = [loss]
+
+    if compute_grad:
+        grads = pytensor.gradient.grad(loss, inputs)
+        grad = pt.concatenate([grad.ravel() for grad in grads])
+        outputs.append(grad)
+
+    if compute_hess:
+        hess = pytensor.gradient.jacobian(grad, inputs)[0]
+        outputs.append(hess)
+
+    f_fused = pm.compile(inputs, outputs, **compile_kwargs)
+
+    return [f_fused, f_hessp]
 
 
 def scipy_optimize_funcs_from_loss(
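For context, here is a minimal sketch of the fused-compile idea used above: build the gradient and Hessian graphs and compile them together with the loss as one multi-output function. It uses plain pytensor.function on a toy loss rather than pm.compile on a model's logp, so the loss and variable names are illustrative only:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = pt.sum(x**2) + pt.sum(pt.sin(x))

# Flattened gradient and its Jacobian (the Hessian), mirroring outputs = [loss, grad, hess].
grad = pt.concatenate([g.ravel() for g in pytensor.gradient.grad(loss, [x])])
hess = pytensor.gradient.jacobian(grad, [x])[0]

f_fused = pytensor.function([x], [loss, grad, hess])
loss_val, grad_val, hess_val = f_fused(np.ones(3))
print(grad_val.shape, hess_val.shape)  # (3,), (3, 3)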
@@ -262,10 +273,8 @@ def scipy_optimize_funcs_from_loss(
 
     Returns
     -------
-    f_loss: Callable
-        The compiled loss function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
+    f_fused: Callable
+        The compiled loss function, which may also include gradients and hessian if requested.
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
@@ -322,16 +331,15 @@ def scipy_optimize_funcs_from_loss(
         compile_kwargs=compile_kwargs,
     )
 
-    # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
-    f_loss = funcs.pop(0)
-    f_hess = funcs.pop(0) if compute_grad else None
-    f_hessp = funcs.pop(0) if compute_grad else None
+    # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients,
+    # or the loss function with gradients and hessian.
+    f_fused = funcs.pop(0)
+    f_hessp = funcs.pop(0) if compute_hessp else None
 
     if use_jax_gradients:
-        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
+        f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp)
 
-    return f_loss, f_hess, f_hessp
+    return f_fused, f_hessp
 
 
 def find_MAP(
@@ -434,7 +442,7 @@ def find_MAP(
         method, use_grad, use_hess, use_hessp
     )
 
-    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
         loss=-frozen_model.logp(jacobian=False),
         inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
         initial_point_dict=start_dict,
@@ -445,23 +453,21 @@ def find_MAP(
         compile_kwargs=compile_kwargs,
     )
 
-    args = optimizer_kwargs.pop("args", None)
+    args = optimizer_kwargs.pop("args", ())
 
     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
     # if so. That is why the jac argument is not passed here in either branch.
 
     if do_basinhopping:
         if "args" not in minimizer_kwargs:
             minimizer_kwargs["args"] = args
-        if "hess" not in minimizer_kwargs:
-            minimizer_kwargs["hess"] = f_hess
         if "hessp" not in minimizer_kwargs:
             minimizer_kwargs["hessp"] = f_hessp
         if "method" not in minimizer_kwargs:
             minimizer_kwargs["method"] = method
 
         optimizer_result = basinhopping(
-            func=f_logp,
+            func=f_fused,
             x0=cast(np.ndarray[float], initial_params.data),
             progressbar=progressbar,
             minimizer_kwargs=minimizer_kwargs,
@@ -470,10 +476,9 @@
 
     else:
         optimizer_result = minimize(
-            f=f_logp,
+            f=f_fused,
             x0=cast(np.ndarray[float], initial_params.data),
             args=args,
-            hess=f_hess,
             hessp=f_hessp,
             progressbar=progressbar,
             method=method,
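As the comment in the diff notes, better_optimize detects a fused loss-and-gradient callable and wires up the jac argument automatically. With plain scipy.optimize, the equivalent pattern is jac=True, which tells minimize that the objective returns (loss, gradient). A toy sketch (not from this commit), using a hessp callable the way find_MAP does:

import numpy as np
from scipy.optimize import minimize

def f_fused(x):
    # Fused objective: returns (loss, gradient) in a single call.
    return np.sum(x**2), 2.0 * x

def f_hessp(x, p):
    # Hessian-vector product: the Hessian of x @ x is 2 * I, so H @ p = 2 * p.
    return 2.0 * p

result = minimize(f_fused, x0=np.ones(3), jac=True, hessp=f_hessp, method="trust-ncg")
print(result.x)  # approximately [0, 0, 0]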
@@ -486,6 +491,33 @@
         DictToArrayBijection.rmap(raveled_optimized)
     )
 
+    # Downstream computation will probably want the covariance matrix at the optimized point, so we compute it here,
+    # while we still have access to the compiled function.
+    x_star = optimizer_result.x
+    n_vars = len(x_star)
+
+    if method == "BFGS":
+        # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than
+        # re-computing something
+        H_inv = getattr(optimizer_result, "hess_inv", None)
+    elif method == "L-BFGS-B":
+        # Here we will have a LinearOperator representing the inverse Hessian-vector product.
+        f_hessp_inv = optimizer_result.hess_inv
+        basis = np.eye(n_vars)
+        H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1)
+
+    elif f_hessp is not None:
+        # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it
+        # from the hessp function, using Euclidean basis vectors.
+        basis = np.eye(n_vars)
+        H = np.stack([f_hessp(optimizer_result.x, basis[:, i]) for i in range(n_vars)], axis=-1)
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
+    elif use_hess:
+        # If we compiled a hessian function, just use it
+        _, _, H = f_fused(x_star)
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
     optimized_point = {
         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
     }
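The new block above assembles the Hessian one column at a time from Hessian-vector products against the standard basis, projects it to the nearest positive semi-definite matrix, and inverts it to get the covariance. Here is a standalone sketch of that idea; the eigenvalue-clipping helper below is a stand-in for get_nearest_psd, whose implementation is not shown in this diff, and the hessp function is a toy:

import numpy as np

def nearest_psd(H, eps=1e-12):
    # Stand-in for get_nearest_psd: symmetrize, then clip negative eigenvalues.
    H = (H + H.T) / 2.0
    eigvals, eigvecs = np.linalg.eigh(H)
    return eigvecs @ np.diag(np.clip(eigvals, eps, None)) @ eigvecs.T

def f_hessp(x, p):
    # Toy Hessian-vector product for loss(x) = 0.5 * x @ diag(1, 2, 3) @ x.
    return np.array([1.0, 2.0, 3.0]) * p

x_star = np.zeros(3)
n_vars = x_star.size
basis = np.eye(n_vars)

# Column i of H is H @ e_i, so one hessp call per basis vector recovers the full Hessian.
H = np.stack([f_hessp(x_star, basis[:, i]) for i in range(n_vars)], axis=-1)
H_inv = np.linalg.inv(nearest_psd(H))
print(np.diag(H_inv))  # [1.0, 0.5, 0.333...]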
