@@ -114,14 +114,14 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
 
 
 def _compile_grad_and_hess_to_jax(
-    f_loss: Function, use_hess: bool, use_hessp: bool
+    f_fused: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
     Compile loss function gradients using JAX.
 
     Parameters
     ----------
-    f_loss: Function
+    f_fused: Function
         The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss,
         compiled with mode="JAX".
     use_hess: bool
@@ -131,43 +131,40 @@ def _compile_grad_and_hess_to_jax(
 
     Returns
     -------
-    f_loss_and_grad: Callable
-        The compiled loss function and gradient function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
+    f_fused: Callable
+        The compiled loss function and gradient function, which may also compute the hessian if requested.
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
     import jax
 
-    f_hess = None
     f_hessp = None
 
-    orig_loss_fn = f_loss.vm.jit_fn
+    orig_loss_fn = f_fused.vm.jit_fn
 
-    @jax.jit
-    def loss_fn_jax_grad(x):
-        return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+    if use_hess:
+
+        @jax.jit
+        def loss_fn_fused(x):
+            loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
+            hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x)
+            return *loss_and_grad, hess
+
+    else:
 
-    f_loss_and_grad = loss_fn_jax_grad
+        @jax.jit
+        def loss_fn_fused(x):
+            return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     if use_hessp:
 
         def f_hessp_jax(x, p):
-            y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,))
+            y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,))
             return jax.numpy.stack(u)
 
         f_hessp = jax.jit(f_hessp_jax)
 
-    if use_hess:
-        _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1])
-
-        def f_hess_jax(x):
-            return jax.numpy.stack(_f_hess_jax(x))
-
-        f_hess = jax.jit(f_hess_jax)
-
-    return f_loss_and_grad, f_hess, f_hessp
+    return loss_fn_fused, f_hessp
 
 
 def _compile_functions_for_scipy_optimize(
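For reference, a minimal standalone sketch (not part of the diff) of the fused JAX pattern that the new loss_fn_fused relies on, using a hypothetical quadratic_loss in place of the compiled pytensor function:

import jax
import jax.numpy as jnp

def quadratic_loss(x):
    return 0.5 * jnp.sum(x ** 2)

@jax.jit
def loss_grad_hess(x):
    # Fused output mirroring use_hess=True: loss, gradient, and hessian from one call.
    loss, grad = jax.value_and_grad(quadratic_loss)(x)
    hess = jax.hessian(quadratic_loss)(x)
    return loss, grad, hess

@jax.jit
def hessp(x, p):
    # Hessian-vector product mirroring use_hessp=True: differentiate the gradient along p.
    _, hvp = jax.jvp(jax.grad(quadratic_loss), (x,), (p,))
    return hvp

x0 = jnp.ones(3)
loss_grad_hess(x0)                      # (1.5, [1. 1. 1.], 3x3 identity)
hessp(x0, jnp.array([1.0, 0.0, 0.0]))   # first hessian column: [1. 0. 0.]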
@@ -199,33 +196,47 @@ def _compile_functions_for_scipy_optimize(
 
     Returns
     -------
-    f_loss: Function
-
-    f_hess: Function | None
+    f_fused: Function
+        The compiled loss function, which may also include gradients and hessian if requested.
     f_hessp: Function | None
+        The compiled hessian-vector product function, or None if compute_hessp is False.
     """
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+
     loss = pm.pytensorf.rewrite_pregrad(loss)
-    f_hess = None
     f_hessp = None
 
-    if compute_grad:
-        grads = pytensor.gradient.grad(loss, inputs)
-        grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
-    else:
+    # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent
+    # with the case where we also compute gradients, hessians, or hessian-vector products.
+    if not (compute_grad or compute_hess or compute_hessp):
         f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
-    if compute_hess:
-        hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile(inputs, hess, **compile_kwargs)
+    # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single
+    # fused function and return it. If the user also wants the hessian, the fused function will return the loss,
+    # gradients, and hessian. If the user wants gradients and hessp, we return a fused function that returns the loss
+    # and gradients, and a separate function for the hessian-vector product.
 
     if compute_hessp:
+        # Handle this first, since it can be compiled alone.
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
         f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
-    return [f_loss_and_grad, f_hess, f_hessp]
+    outputs = [loss]
+
+    if compute_grad:
+        grads = pytensor.gradient.grad(loss, inputs)
+        grad = pt.concatenate([grad.ravel() for grad in grads])
+        outputs.append(grad)
+
+    if compute_hess:
+        hess = pytensor.gradient.jacobian(grad, inputs)[0]
+        outputs.append(hess)
+
+    f_fused = pm.compile(inputs, outputs, **compile_kwargs)
+
+    return [f_fused, f_hessp]
 
 
 def scipy_optimize_funcs_from_loss(
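For orientation, a minimal sketch (independent of the diff and of the PyMC compile helpers, using a toy quadratic loss) of the fused-output idea, i.e. one pytensor function that returns loss, gradient, and hessian together:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = 0.5 * (x ** 2).sum()

grad = pytensor.gradient.grad(loss, x)        # gradient of the scalar loss
hess = pytensor.gradient.jacobian(grad, x)    # hessian as the jacobian of the gradient

f_fused_sketch = pytensor.function([x], [loss, grad, hess])
loss_val, grad_val, hess_val = f_fused_sketch(np.ones(3))
# loss_val -> 1.5, grad_val -> [1. 1. 1.], hess_val -> 3x3 identity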
@@ -262,10 +273,8 @@ def scipy_optimize_funcs_from_loss(
 
     Returns
     -------
-    f_loss: Callable
-        The compiled loss function.
-    f_hess: Callable | None
-        The compiled hessian function, or None if use_hess is False.
+    f_fused: Callable
+        The compiled loss function, which may also include gradients and hessian if requested.
     f_hessp: Callable | None
         The compiled hessian-vector product function, or None if use_hessp is False.
     """
@@ -322,16 +331,15 @@ def scipy_optimize_funcs_from_loss(
         compile_kwargs=compile_kwargs,
     )
 
-    # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values
-    f_loss = funcs.pop(0)
-    f_hess = funcs.pop(0) if compute_grad else None
-    f_hessp = funcs.pop(0) if compute_grad else None
+    # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients,
+    # or the loss function with gradients and hessian.
+    f_fused = funcs.pop(0)
+    f_hessp = funcs.pop(0) if compute_hessp else None
 
     if use_jax_gradients:
-        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
+        f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp)
 
-    return f_loss, f_hess, f_hessp
+    return f_fused, f_hessp
 
 
 def find_MAP(
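A hypothetical usage sketch (not the PR's code path, which goes through better_optimize) showing how a fused (loss, grad) callable and a separate hessian-vector product plug into plain scipy.optimize.minimize: with jac=True, scipy treats the objective as returning (loss, grad), which matches the compute_grad-only fused function:

import numpy as np
from scipy.optimize import minimize

def f_fused_toy(x):
    # Toy stand-in for the compiled fused function: returns (loss, grad).
    return 0.5 * np.sum(x ** 2), x

def f_hessp_toy(x, p):
    # Toy stand-in for the compiled hessian-vector product (the hessian is the identity here).
    return p

result = minimize(f_fused_toy, x0=np.ones(3), jac=True, hessp=f_hessp_toy, method="trust-ncg")
result.x  # -> approximately [0. 0. 0.]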
@@ -434,7 +442,7 @@ def find_MAP(
         method, use_grad, use_hess, use_hessp
     )
 
-    f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss(
+    f_fused, f_hessp = scipy_optimize_funcs_from_loss(
         loss=-frozen_model.logp(jacobian=False),
         inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars,
         initial_point_dict=start_dict,
@@ -445,23 +453,21 @@ def find_MAP(
         compile_kwargs=compile_kwargs,
     )
 
-    args = optimizer_kwargs.pop("args", None)
+    args = optimizer_kwargs.pop("args", ())
 
     # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument
     # if so. That is why the jac argument is not passed here in either branch.
 
     if do_basinhopping:
         if "args" not in minimizer_kwargs:
             minimizer_kwargs["args"] = args
-        if "hess" not in minimizer_kwargs:
-            minimizer_kwargs["hess"] = f_hess
         if "hessp" not in minimizer_kwargs:
             minimizer_kwargs["hessp"] = f_hessp
         if "method" not in minimizer_kwargs:
             minimizer_kwargs["method"] = method
 
         optimizer_result = basinhopping(
-            func=f_logp,
+            func=f_fused,
             x0=cast(np.ndarray[float], initial_params.data),
             progressbar=progressbar,
             minimizer_kwargs=minimizer_kwargs,
@@ -470,10 +476,9 @@ def find_MAP(
 
     else:
         optimizer_result = minimize(
-            f=f_logp,
+            f=f_fused,
             x0=cast(np.ndarray[float], initial_params.data),
             args=args,
-            hess=f_hess,
             hessp=f_hessp,
             progressbar=progressbar,
             method=method,
@@ -486,6 +491,33 @@ def find_MAP(
         DictToArrayBijection.rmap(raveled_optimized)
     )
 
+    # Downstream computation will probably want the covariance matrix at the optimized point, so we compute it here,
+    # while we still have access to the compiled function.
+    x_star = optimizer_result.x
+    n_vars = len(x_star)
+
+    if method == "BFGS":
+        # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than
+        # re-computing something
+        H_inv = getattr(optimizer_result, "hess_inv", None)
+    elif method == "L-BFGS-B":
+        # Here we will have a LinearOperator representing the inverse Hessian-vector product.
+        f_hessp_inv = optimizer_result.hess_inv
+        basis = np.eye(n_vars)
+        H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1)
+
+    elif f_hessp is not None:
+        # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it
+        # from the hessp function, using Euclidean basis vectors.
+        basis = np.eye(n_vars)
+        H = np.stack([f_hessp(optimizer_result.x, basis[:, i]) for i in range(n_vars)], axis=-1)
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
+    elif use_hess:
+        # If we compiled a hessian function, just use it
+        _, _, H = f_fused(x_star)
+        H_inv = np.linalg.inv(get_nearest_psd(H))
+
     optimized_point = {
         var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values)
     }
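Note: the branches above lean on a get_nearest_psd helper defined elsewhere in the module; its implementation is not shown in this diff. A common approach, sketched here purely as an assumption about what such a helper does (the module's actual version may differ), is to symmetrize the matrix and clip negative eigenvalues before inverting:

import numpy as np

def nearest_psd_sketch(A: np.ndarray) -> np.ndarray:
    # Assumed behavior only: symmetrize, then zero out negative eigenvalues so the
    # result is positive semidefinite and safer to invert.
    A_sym = (A + A.T) / 2
    eigvals, eigvecs = np.linalg.eigh(A_sym)
    return eigvecs @ np.diag(np.clip(eigvals, 0.0, None)) @ eigvecs.T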