diff --git a/pymc_extras/inference/__init__.py b/pymc_extras/inference/__init__.py index a01fdd5c3..a536f91e6 100644 --- a/pymc_extras/inference/__init__.py +++ b/pymc_extras/inference/__init__.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pymc_extras.inference.find_map import find_MAP from pymc_extras.inference.fit import fit -from pymc_extras.inference.laplace import fit_laplace +from pymc_extras.inference.laplace_approx.find_map import find_MAP +from pymc_extras.inference.laplace_approx.laplace import fit_laplace from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder __all__ = ["fit", "fit_pathfinder", "fit_laplace", "find_MAP"] diff --git a/pymc_extras/inference/find_map.py b/pymc_extras/inference/find_map.py deleted file mode 100644 index a4d664789..000000000 --- a/pymc_extras/inference/find_map.py +++ /dev/null @@ -1,496 +0,0 @@ -import logging - -from collections.abc import Callable -from importlib.util import find_spec -from typing import Literal, cast, get_args - -import numpy as np -import pymc as pm -import pytensor -import pytensor.tensor as pt - -from better_optimize import basinhopping, minimize -from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method -from pymc.blocking import DictToArrayBijection, RaveledVars -from pymc.initial_point import make_initial_point_fn -from pymc.model.transform.optimization import freeze_dims_and_data -from pymc.pytensorf import join_nonshared_inputs -from pymc.util import get_default_varnames -from pytensor.compile import Function -from pytensor.compile.mode import Mode -from pytensor.tensor import TensorVariable -from scipy.optimize import OptimizeResult - -_log = logging.getLogger(__name__) - -GradientBackend = Literal["pytensor", "jax"] -VALID_BACKENDS = get_args(GradientBackend) - - -def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): - method_info = MINIMIZE_MODE_KWARGS[method].copy() - - if use_hess and use_hessp: - _log.warning( - 'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the ' - 'same time. When possible "use_hessp" is preferred because its is computationally more efficient. ' - 'Setting "use_hess" to False.' - ) - use_hess = False - - use_grad = use_grad if use_grad is not None else method_info["uses_grad"] - - if use_hessp is not None and use_hess is None: - use_hess = not use_hessp - - elif use_hess is not None and use_hessp is None: - use_hessp = not use_hess - - elif use_hessp is None and use_hess is None: - use_hessp = method_info["uses_hessp"] - use_hess = method_info["uses_hess"] - if use_hessp and use_hess: - # If a method could use either hess or hessp, we default to using hessp - use_hess = False - - return use_grad, use_hess, use_hessp - - -def get_nearest_psd(A: np.ndarray) -> np.ndarray: - """ - Compute the nearest positive semi-definite matrix to a given matrix. - - This function takes a square matrix and returns the nearest positive semi-definite matrix using - eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms - of the Frobenius norm. - - Parameters - ---------- - A : np.ndarray - Input square matrix. - - Returns - ------- - np.ndarray - The nearest positive semi-definite matrix to the input matrix. 
- """ - C = (A + A.T) / 2 - eigval, eigvec = np.linalg.eigh(C) - eigval[eigval < 0] = 0 - - return eigvec @ np.diag(eigval) @ eigvec.T - - -def _unconstrained_vector_to_constrained_rvs(model): - constrained_rvs, unconstrained_vector = join_nonshared_inputs( - model.initial_point(), - inputs=model.value_vars, - outputs=get_default_varnames(model.unobserved_value_vars, include_transformed=False), - ) - - unconstrained_vector.name = "unconstrained_vector" - return constrained_rvs, unconstrained_vector - - -def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model, chains, draws): - X = pt.tensor("transformed_draws", shape=(chains, draws, H_inv.shape[0])) - out = [] - for rv, idx in slices.items(): - f = model.rvs_to_transforms[rv] - untransformed_X = f.backward(X[..., idx]) if f is not None else X[..., idx] - - if rv in out_shapes: - new_shape = (chains, draws) + out_shapes[rv] - untransformed_X = untransformed_X.reshape(new_shape) - - out.append(untransformed_X) - - f_untransform = pytensor.function( - inputs=[pytensor.In(X, borrow=True)], - outputs=pytensor.Out(out, borrow=True), - mode=Mode(linker="py", optimizer="FAST_COMPILE"), - ) - return f_untransform(posterior_draws) - - -def _compile_grad_and_hess_to_jax( - f_loss: Function, use_hess: bool, use_hessp: bool -) -> tuple[Callable | None, Callable | None]: - """ - Compile loss function gradients using JAX. - - Parameters - ---------- - f_loss: Function - The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, - compiled with mode="JAX". - use_hess: bool - Whether to compile a function to compute the hessian of the loss function. - use_hessp: bool - Whether to compile a function to compute the hessian-vector product of the loss function. - - Returns - ------- - f_loss_and_grad: Callable - The compiled loss function and gradient function. - f_hess: Callable | None - The compiled hessian function, or None if use_hess is False. - f_hessp: Callable | None - The compiled hessian-vector product function, or None if use_hessp is False. - """ - import jax - - f_hess = None - f_hessp = None - - orig_loss_fn = f_loss.vm.jit_fn - - @jax.jit - def loss_fn_jax_grad(x): - return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) - - f_loss_and_grad = loss_fn_jax_grad - - if use_hessp: - - def f_hessp_jax(x, p): - y, u = jax.jvp(lambda x: f_loss_and_grad(x)[1], (x,), (p,)) - return jax.numpy.stack(u) - - f_hessp = jax.jit(f_hessp_jax) - - if use_hess: - _f_hess_jax = jax.jacfwd(lambda x: f_loss_and_grad(x)[1]) - - def f_hess_jax(x): - return jax.numpy.stack(_f_hess_jax(x)) - - f_hess = jax.jit(f_hess_jax) - - return f_loss_and_grad, f_hess, f_hessp - - -def _compile_functions_for_scipy_optimize( - loss: TensorVariable, - inputs: list[TensorVariable], - compute_grad: bool, - compute_hess: bool, - compute_hessp: bool, - compile_kwargs: dict | None = None, -) -> list[Function] | list[Function, Function | None, Function | None]: - """ - Compile loss functions for use with scipy.optimize.minimize. - - Parameters - ---------- - loss: TensorVariable - The loss function to compile. - inputs: list[TensorVariable] - A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines - expect the function signature to be f(x, *args), where x is a 1D array of parameters. - compute_grad: bool - Whether to compile a function that computes the gradients of the loss function. 
- compute_hess: bool - Whether to compile a function that computes the Hessian of the loss function. - compute_hessp: bool - Whether to compile a function that computes the Hessian-vector product of the loss function. - compile_kwargs: dict, optional - Additional keyword arguments to pass to the ``pm.compile`` function. - - Returns - ------- - f_loss: Function - - f_hess: Function | None - f_hessp: Function | None - """ - loss = pm.pytensorf.rewrite_pregrad(loss) - f_hess = None - f_hessp = None - - if compute_grad: - grads = pytensor.gradient.grad(loss, inputs) - grad = pt.concatenate([grad.ravel() for grad in grads]) - f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs) - else: - f_loss = pm.compile(inputs, loss, **compile_kwargs) - return [f_loss] - - if compute_hess: - hess = pytensor.gradient.jacobian(grad, inputs)[0] - f_hess = pm.compile(inputs, hess, **compile_kwargs) - - if compute_hessp: - p = pt.tensor("p", shape=inputs[0].type.shape) - hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) - f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs) - - return [f_loss_and_grad, f_hess, f_hessp] - - -def scipy_optimize_funcs_from_loss( - loss: TensorVariable, - inputs: list[TensorVariable], - initial_point_dict: dict[str, np.ndarray | float | int], - use_grad: bool, - use_hess: bool, - use_hessp: bool, - gradient_backend: GradientBackend = "pytensor", - compile_kwargs: dict | None = None, -) -> tuple[Callable, ...]: - """ - Compile loss functions for use with scipy.optimize.minimize. - - Parameters - ---------- - loss: TensorVariable - The loss function to compile. - inputs: list[TensorVariable] - The input variables to the loss function. - initial_point_dict: dict[str, np.ndarray | float | int] - Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables. - use_grad: bool - Whether to compile a function that computes the gradients of the loss function. - use_hess: bool - Whether to compile a function that computes the Hessian of the loss function. - use_hessp: bool - Whether to compile a function that computes the Hessian-vector product of the loss function. - gradient_backend: str, default "pytensor" - Which backend to use to compute gradients. Must be one of "jax" or "pytensor" - compile_kwargs: - Additional keyword arguments to pass to the ``pm.compile`` function. - - Returns - ------- - f_loss: Callable - The compiled loss function. - f_hess: Callable | None - The compiled hessian function, or None if use_hess is False. - f_hessp: Callable | None - The compiled hessian-vector product function, or None if use_hessp is False. - """ - - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - - if (use_hess or use_hessp) and not use_grad: - raise ValueError( - "Cannot compute hessian or hessian-vector product without also computing the gradient" - ) - - if gradient_backend not in VALID_BACKENDS: - raise ValueError( - f"Invalid gradient backend: {gradient_backend}. 
Must be one of {VALID_BACKENDS}" - ) - - use_jax_gradients = (gradient_backend == "jax") and use_grad - if use_jax_gradients and not find_spec("jax"): - raise ImportError("JAX must be installed to use JAX gradients") - - mode = compile_kwargs.get("mode", None) - if mode is None and use_jax_gradients: - compile_kwargs["mode"] = "JAX" - elif mode != "JAX" and use_jax_gradients: - raise ValueError( - 'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"' - ) - - if not isinstance(inputs, list): - inputs = [inputs] - - [loss], flat_input = join_nonshared_inputs( - point=initial_point_dict, outputs=[loss], inputs=inputs - ) - - # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When - # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them - # away. - if use_jax_gradients: - from pymc.sampling.jax import _replace_shared_variables - - [loss] = _replace_shared_variables([loss]) - - compute_grad = use_grad and not use_jax_gradients - compute_hess = use_hess and not use_jax_gradients - compute_hessp = use_hessp and not use_jax_gradients - - funcs = _compile_functions_for_scipy_optimize( - loss=loss, - inputs=[flat_input], - compute_grad=compute_grad, - compute_hess=compute_hess, - compute_hessp=compute_hessp, - compile_kwargs=compile_kwargs, - ) - - # f_loss here is f_loss_and_grad if compute_grad = True. The name is unchanged to simplify the return values - f_loss = funcs.pop(0) - f_hess = funcs.pop(0) if compute_grad else None - f_hessp = funcs.pop(0) if compute_grad else None - - if use_jax_gradients: - # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values - f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp) - - return f_loss, f_hess, f_hessp - - -def find_MAP( - method: minimize_method | Literal["basinhopping"], - *, - model: pm.Model | None = None, - use_grad: bool | None = None, - use_hessp: bool | None = None, - use_hess: bool | None = None, - initvals: dict | None = None, - random_seed: int | np.random.Generator | None = None, - return_raw: bool = False, - jitter_rvs: list[TensorVariable] | None = None, - progressbar: bool = True, - include_transformed: bool = True, - gradient_backend: GradientBackend = "pytensor", - compile_kwargs: dict | None = None, - **optimizer_kwargs, -) -> dict[str, np.ndarray] | tuple[dict[str, np.ndarray], OptimizeResult]: - """ - Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize. - - Parameters - ---------- - model : pm.Model - The PyMC model to be fit. If None, the current model context is used. - method : str - The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP, - trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping. - - See scipy.optimize.minimize documentation for details. - use_grad : bool | None, optional - Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on - the ``method``. - use_hessp : bool | None, optional - Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on - the ``method``. - use_hess : bool | None, optional - Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on - the ``method``. 
- initvals : None | dict, optional - Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. - If None, the model's default initial values are used. - random_seed : None | int | np.random.Generator, optional - Seed for the random number generator or a numpy Generator for reproducibility - return_raw: bool | False, optinal - Whether to also return the full output of `scipy.optimize.minimize` - jitter_rvs : list of TensorVariables, optional - Variables whose initial values should be jittered. If None, all variables are jittered. - progressbar : bool, optional - Whether to display a progress bar during optimization. Defaults to True. - include_transformed: bool, optional - Whether to include transformed variable values in the returned dictionary. Defaults to True. - gradient_backend: str, default "pytensor" - Which backend to use to compute gradients. Must be one of "pytensor" or "jax". - compile_kwargs: dict, optional - Additional options to pass to the ``pytensor.function`` function when compiling loss functions. - **optimizer_kwargs - Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless - ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``, - ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details. - - Returns - ------- - optimizer_result: dict[str, np.ndarray] or tuple[dict[str, np.ndarray], OptimizerResult] - Dictionary with names of random variables as keys, and optimization results as values. If return_raw is True, - also returns the object returned by ``scipy.optimize.minimize``. - """ - model = pm.modelcontext(model) - frozen_model = freeze_dims_and_data(model) - - jitter_rvs = [] if jitter_rvs is None else jitter_rvs - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - - ipfn = make_initial_point_fn( - model=frozen_model, - jitter_rvs=set(jitter_rvs), - return_transformed=True, - overrides=initvals, - ) - - start_dict = ipfn(random_seed) - vars_dict = {var.name: var for var in frozen_model.continuous_value_vars} - initial_params = DictToArrayBijection.map( - {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} - ) - - do_basinhopping = method == "basinhopping" - minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {}) - - if do_basinhopping: - # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need - # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default - # if one isn't provided. - - method = minimizer_kwargs.pop("method", "L-BFGS-B") - minimizer_kwargs["method"] = method - - use_grad, use_hess, use_hessp = set_optimizer_function_defaults( - method, use_grad, use_hess, use_hessp - ) - - f_logp, f_hess, f_hessp = scipy_optimize_funcs_from_loss( - loss=-frozen_model.logp(jacobian=False), - inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, - initial_point_dict=start_dict, - use_grad=use_grad, - use_hess=use_hess, - use_hessp=use_hessp, - gradient_backend=gradient_backend, - compile_kwargs=compile_kwargs, - ) - - args = optimizer_kwargs.pop("args", None) - - # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument - # if so. That is why the jac argument is not passed here in either branch. 
- - if do_basinhopping: - if "args" not in minimizer_kwargs: - minimizer_kwargs["args"] = args - if "hess" not in minimizer_kwargs: - minimizer_kwargs["hess"] = f_hess - if "hessp" not in minimizer_kwargs: - minimizer_kwargs["hessp"] = f_hessp - if "method" not in minimizer_kwargs: - minimizer_kwargs["method"] = method - - optimizer_result = basinhopping( - func=f_logp, - x0=cast(np.ndarray[float], initial_params.data), - progressbar=progressbar, - minimizer_kwargs=minimizer_kwargs, - **optimizer_kwargs, - ) - - else: - optimizer_result = minimize( - f=f_logp, - x0=cast(np.ndarray[float], initial_params.data), - args=args, - hess=f_hess, - hessp=f_hessp, - progressbar=progressbar, - method=method, - **optimizer_kwargs, - ) - - raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) - unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) - unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")( - DictToArrayBijection.rmap(raveled_optimized) - ) - - optimized_point = { - var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) - } - - if return_raw: - return optimized_point, optimizer_result - - return optimized_point diff --git a/pymc_extras/inference/fit.py b/pymc_extras/inference/fit.py index 5b83ff1f3..ac51e76bb 100644 --- a/pymc_extras/inference/fit.py +++ b/pymc_extras/inference/fit.py @@ -37,6 +37,6 @@ def fit(method: str, **kwargs) -> az.InferenceData: return fit_pathfinder(**kwargs) if method == "laplace": - from pymc_extras.inference.laplace import fit_laplace + from pymc_extras.inference import fit_laplace return fit_laplace(**kwargs) diff --git a/pymc_extras/inference/laplace.py b/pymc_extras/inference/laplace.py deleted file mode 100644 index d64d2adab..000000000 --- a/pymc_extras/inference/laplace.py +++ /dev/null @@ -1,685 +0,0 @@ -# Copyright 2024 The PyMC Developers -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -import logging - -from collections.abc import Callable -from functools import reduce -from importlib.util import find_spec -from itertools import product -from typing import Literal - -import arviz as az -import numpy as np -import pymc as pm -import pytensor -import pytensor.tensor as pt -import xarray as xr - -from arviz import dict_to_dataset -from better_optimize.constants import minimize_method -from numpy.typing import ArrayLike -from pymc import DictToArrayBijection -from pymc.backends.arviz import ( - coords_and_dims_for_inferencedata, - find_constants, - find_observations, -) -from pymc.blocking import RaveledVars -from pymc.model.transform.conditioning import remove_value_transforms -from pymc.model.transform.optimization import freeze_dims_and_data -from pymc.util import get_default_varnames -from pytensor.tensor import TensorVariable -from pytensor.tensor.optimize import minimize -from scipy import stats - -from pymc_extras.inference.find_map import ( - GradientBackend, - _unconstrained_vector_to_constrained_rvs, - find_MAP, - get_nearest_psd, - scipy_optimize_funcs_from_loss, -) - -_log = logging.getLogger(__name__) - - -def get_conditional_gaussian_approximation( - x: TensorVariable, - Q: TensorVariable | ArrayLike, - mu: TensorVariable | ArrayLike, - args: list[TensorVariable] | None = None, - model: pm.Model | None = None, - method: minimize_method = "BFGS", - use_jac: bool = True, - use_hess: bool = False, - optimizer_kwargs: dict | None = None, -) -> Callable: - """ - Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation. - - That is: - y | x, sigma ~ N(Ax, sigma^2 W) - x | params ~ N(mu, Q(params)^-1) - - We seek to estimate log(p(x | y, params)): - - log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const - - Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). - - This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. - - Thus: - - 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. - - 2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q). - - Parameters - ---------- - x: TensorVariable - The parameter with which to maximize wrt (that is, find the mode in x). In INLA, this is the latent field x~N(mu,Q^-1). - Q: TensorVariable | ArrayLike - The precision matrix of the latent field x. - mu: TensorVariable | ArrayLike - The mean of the latent field x. - args: list[TensorVariable] - Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If set to None, assumes the model RVs are args. - model: Model - PyMC model to use. - method: minimize_method - Which minimization algorithm to use. - use_jac: bool - If true, the minimizer will compute the gradient of log(p(x | y, params)). - use_hess: bool - If true, the minimizer will compute the Hessian log(p(x | y, params)). - optimizer_kwargs: dict - Kwargs to pass to scipy.optimize.minimize. - - Returns - ------- - f: Callable - A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. 
x is currently both the point at which to evaluate logp and the initial guess for the minimizer. - """ - model = pm.modelcontext(model) - - if args is None: - args = model.continuous_value_vars + model.discrete_value_vars - - # f = log(p(y | x, params)) - f_x = model.logp() - jac = pytensor.gradient.grad(f_x, x) - hess = pytensor.gradient.jacobian(jac.flatten(), x) - - # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x) - log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu) - - # Maximize log(p(x | y, params)) wrt x to find mode x0 - x0, _ = minimize( - objective=-log_x_posterior, - x=x, - method=method, - jac=use_jac, - hess=use_hess, - optimizer_kwargs=optimizer_kwargs, - ) - - # require f'(x0) and f''(x0) for Laplace approx - jac = pytensor.graph.replace.graph_replace(jac, {x: x0}) - hess = pytensor.graph.replace.graph_replace(hess, {x: x0}) - - # Full log(p(x | y, params)) using the Laplace approximation (up to a constant) - _, logdetQ = pt.nlinalg.slogdet(Q) - conditional_gaussian_approx = ( - -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ - ) - - # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is - # far from the mode x0 or in a neighbourhood which results in poor convergence. - return pytensor.function(args, [x0, conditional_gaussian_approx]) - - -def laplace_draws_to_inferencedata( - posterior_draws: list[np.ndarray[float | int]], model: pm.Model | None = None -) -> az.InferenceData: - """ - Convert draws from a posterior estimated with the Laplace approximation to an InferenceData object. - - - Parameters - ---------- - posterior_draws: list of np.ndarray - A list of arrays containing the posterior draws. Each array should have shape (chains, draws, *shape), where - shape is the shape of the variable in the posterior. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - - Returns - ------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples - """ - model = pm.modelcontext(model) - chains, draws, *_ = posterior_draws[0].shape - - def make_rv_coords(name): - coords = {"chain": range(chains), "draw": range(draws)} - extra_dims = model.named_vars_to_dims.get(name) - if extra_dims is None: - return coords - return coords | {dim: list(model.coords[dim]) for dim in extra_dims} - - def make_rv_dims(name): - dims = ["chain", "draw"] - extra_dims = model.named_vars_to_dims.get(name) - if extra_dims is None: - return dims - return dims + list(extra_dims) - - names = [ - x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) - ] - idata = { - name: xr.DataArray( - data=draws, - coords=make_rv_coords(name), - dims=make_rv_dims(name), - name=name, - ) - for name, draws in zip(names, posterior_draws) - } - - coords, dims = coords_and_dims_for_inferencedata(model) - idata = az.convert_to_inference_data(idata, coords=coords, dims=dims) - - return idata - - -def add_fit_to_inferencedata( - idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None -) -> az.InferenceData: - """ - Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object. - - - Parameters - ---------- - idata: az.InfereceData - An InferenceData object containing the approximated posterior samples. 
- mu: RaveledVars - The MAP estimate of the model parameters. - H_inv: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - - Returns - ------- - idata: az.InferenceData - The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group. - """ - model = pm.modelcontext(model) - coords = model.coords - - variable_names, *_ = zip(*mu.point_map_info) - - def make_unpacked_variable_names(name): - value_to_dim = { - x.name: model.named_vars_to_dims.get(model.values_to_rvs[x].name, None) - for x in model.value_vars - } - value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None} - - rv_to_dim = model.named_vars_to_dims - dims_dict = rv_to_dim | value_to_dim - - dims = dims_dict.get(name) - if dims is None: - return [name] - labels = product(*(coords[dim] for dim in dims)) - return [f"{name}[{','.join(map(str, label))}]" for label in labels] - - unpacked_variable_names = reduce( - lambda lst, name: lst + make_unpacked_variable_names(name), variable_names, [] - ) - - mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names}) - cov_dataarray = xr.DataArray( - H_inv, - dims=["rows", "columns"], - coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, - ) - - dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray}) - idata.add_groups(fit=dataset) - - return idata - - -def add_data_to_inferencedata( - idata: az.InferenceData, - progressbar: bool = True, - model: pm.Model | None = None, - compile_kwargs: dict | None = None, -) -> az.InferenceData: - """ - Add observed and constant data to an InferenceData object. - - Parameters - ---------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples. - progressbar: bool - Whether to display a progress bar during computations. Default is True. - model: Model, optional - A PyMC model. If None, the model is taken from the current model context. - compile_kwargs: dict, optional - Additional keyword arguments to pass to pytensor.function. - - Returns - ------- - idata: az.InferenceData - The provided InferenceData, with observed and constant data added. 
- """ - model = pm.modelcontext(model) - - if model.deterministics: - idata.posterior = pm.compute_deterministics( - idata.posterior, - model=model, - merge_dataset=True, - progressbar=progressbar, - compile_kwargs=compile_kwargs, - ) - - coords, dims = coords_and_dims_for_inferencedata(model) - - observed_data = dict_to_dataset( - find_observations(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - - constant_data = dict_to_dataset( - find_constants(model), - library=pm, - coords=coords, - dims=dims, - default_dims=[], - ) - - idata.add_groups( - {"observed_data": observed_data, "constant_data": constant_data}, - coords=coords, - dims=dims, - ) - - return idata - - -def fit_mvn_at_MAP( - optimized_point: dict[str, np.ndarray], - model: pm.Model | None = None, - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - transform_samples: bool = False, - gradient_backend: GradientBackend = "pytensor", - zero_tol: float = 1e-8, - diag_jitter: float | None = 1e-8, - compile_kwargs: dict | None = None, -) -> tuple[RaveledVars, np.ndarray]: - """ - Create a multivariate normal distribution using the inverse of the negative Hessian matrix of the log-posterior - evaluated at the MAP estimate. This is the basis of the Laplace approximation. - - Parameters - ---------- - optimized_point : dict[str, np.ndarray] - Local maximum a posteriori (MAP) point returned from pymc.find_MAP or jax_tools.fit_map - model : Model, optional - A PyMC model. If None, the model is taken from the current model context. - on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' - What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. - If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. - If 'error', an error will be raised. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. - gradient_backend: str, default "pytensor" - The backend to use for gradient computations. Must be one of "pytensor" or "jax". - zero_tol: float - Value below which an element of the Hessian matrix is counted as 0. - This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. - diag_jitter: float | None - A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. - If None, no jitter is added. Default is 1e-8. - compile_kwargs: dict, optional - Additional keyword arguments to pass to pytensor.function when compiling loss functions - - Returns - ------- - map_estimate: RaveledVars - The MAP estimate of the model parameters, raveled into a 1D array. - - inverse_hessian: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. 
- """ - if gradient_backend == "jax" and not find_spec("jax"): - raise ImportError("JAX must be installed to use JAX gradients") - - model = pm.modelcontext(model) - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - frozen_model = freeze_dims_and_data(model) - - if not transform_samples: - untransformed_model = remove_value_transforms(frozen_model) - logp = untransformed_model.logp(jacobian=False) - variables = untransformed_model.continuous_value_vars - else: - logp = frozen_model.logp(jacobian=True) - variables = frozen_model.continuous_value_vars - - variable_names = {var.name for var in variables} - optimized_free_params = {k: v for k, v in optimized_point.items() if k in variable_names} - mu = DictToArrayBijection.map(optimized_free_params) - - _, f_hess, _ = scipy_optimize_funcs_from_loss( - loss=-logp, - inputs=variables, - initial_point_dict=optimized_free_params, - use_grad=True, - use_hess=True, - use_hessp=False, - gradient_backend=gradient_backend, - compile_kwargs=compile_kwargs, - ) - - H = -f_hess(mu.data) - if H.ndim == 1: - H = np.expand_dims(H, axis=1) - H_inv = np.linalg.pinv(np.where(np.abs(H) < zero_tol, 0, -H)) - - def stabilize(x, jitter): - return x + np.eye(x.shape[0]) * jitter - - H_inv = H_inv if diag_jitter is None else stabilize(H_inv, diag_jitter) - - try: - np.linalg.cholesky(H_inv) - except np.linalg.LinAlgError: - if on_bad_cov == "error": - raise np.linalg.LinAlgError( - "Inverse Hessian not positive-semi definite at the provided point" - ) - H_inv = get_nearest_psd(H_inv) - if on_bad_cov == "warn": - _log.warning( - "Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD " - "matrix in L1-norm instead" - ) - - return mu, H_inv - - -def sample_laplace_posterior( - mu: RaveledVars, - H_inv: np.ndarray, - model: pm.Model | None = None, - chains: int = 2, - draws: int = 500, - transform_samples: bool = False, - progressbar: bool = True, - random_seed: int | np.random.Generator | None = None, - compile_kwargs: dict | None = None, -) -> az.InferenceData: - """ - Generate samples from a multivariate normal distribution with mean `mu` and inverse covariance matrix `H_inv`. - - Parameters - ---------- - mu: RaveledVars - The MAP estimate of the model parameters. - H_inv: np.ndarray - The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate. - model : Model - A PyMC model - chains : int - The number of sampling chains running in parallel. Default is 2. - draws : int - The number of samples to draw from the approximated posterior. Default is 500. - transform_samples : bool - Whether to transform the samples back to the original parameter space. Default is True. - progressbar : bool - Whether to display a progress bar during computations. Default is True. - random_seed: int | np.random.Generator | None - Seed for the random number generator or a numpy Generator for reproducibility - - Returns - ------- - idata: az.InferenceData - An InferenceData object containing the approximated posterior samples. 
- """ - model = pm.modelcontext(model) - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - rng = np.random.default_rng(random_seed) - - posterior_dist = stats.multivariate_normal( - mean=mu.data, cov=H_inv, allow_singular=True, seed=rng - ) - - posterior_draws = posterior_dist.rvs(size=(chains, draws)) - if mu.data.shape == (1,): - posterior_draws = np.expand_dims(posterior_draws, -1) - - if transform_samples: - constrained_rvs, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(model) - batched_values = pt.tensor( - "batched_values", - shape=(chains, draws, *unconstrained_vector.type.shape), - dtype=unconstrained_vector.type.dtype, - ) - batched_rvs = pytensor.graph.vectorize_graph( - constrained_rvs, replace={unconstrained_vector: batched_values} - ) - - f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs) - posterior_draws = f_constrain(posterior_draws) - - else: - info = mu.point_map_info - flat_shapes = [size for _, _, size, _ in info] - slices = [ - slice(sum(flat_shapes[:i]), sum(flat_shapes[: i + 1])) for i in range(len(flat_shapes)) - ] - - posterior_draws = [ - posterior_draws[..., idx].reshape((chains, draws, *shape)).astype(dtype) - for idx, (name, shape, _, dtype) in zip(slices, info) - ] - - idata = laplace_draws_to_inferencedata(posterior_draws, model) - idata = add_fit_to_inferencedata(idata, mu, H_inv) - idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs) - - return idata - - -def fit_laplace( - optimize_method: minimize_method | Literal["basinhopping"] = "BFGS", - *, - model: pm.Model | None = None, - use_grad: bool | None = None, - use_hessp: bool | None = None, - use_hess: bool | None = None, - initvals: dict | None = None, - random_seed: int | np.random.Generator | None = None, - return_raw: bool = False, - jitter_rvs: list[pt.TensorVariable] | None = None, - progressbar: bool = True, - include_transformed: bool = True, - gradient_backend: GradientBackend = "pytensor", - chains: int = 2, - draws: int = 500, - on_bad_cov: Literal["warn", "error", "ignore"] = "ignore", - fit_in_unconstrained_space: bool = False, - zero_tol: float = 1e-8, - diag_jitter: float | None = 1e-8, - optimizer_kwargs: dict | None = None, - compile_kwargs: dict | None = None, -) -> az.InferenceData: - """ - Create a Laplace (quadratic) approximation for a posterior distribution. - - This function generates a Laplace approximation for a given posterior distribution using a specified - number of draws. This is useful for obtaining a parametric approximation to the posterior distribution - that can be used for further analysis. - - Parameters - ---------- - model : pm.Model - The PyMC model to be fit. If None, the current model context is used. - method : str - The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP, - trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping. - - See scipy.optimize.minimize documentation for details. - use_grad : bool | None, optional - Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on - the ``method``. - use_hessp : bool | None, optional - Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on - the ``method``. - use_hess : bool | None, optional - Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on - the ``method``. 
- initvals : None | dict, optional - Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. - If None, the model's default initial values are used. - random_seed : None | int | np.random.Generator, optional - Seed for the random number generator or a numpy Generator for reproducibility - return_raw: bool | False, optinal - Whether to also return the full output of `scipy.optimize.minimize` - jitter_rvs : list of TensorVariables, optional - Variables whose initial values should be jittered. If None, all variables are jittered. - progressbar : bool, optional - Whether to display a progress bar during optimization. Defaults to True. - fit_in_unconstrained_space: bool, default False - Whether to fit the Laplace approximation in the unconstrained parameter space. If True, samples will be drawn - from a mean and covariance matrix computed at a point in the **unconstrained** parameter space. Samples will - then be transformed back to the original parameter space. This will guarantee that the samples will respect - the domain of prior distributions (for exmaple, samples from a Beta distribution will be strictly between 0 - and 1). - - .. warning:: - This argument should be considered highly experimental. It has not been verified if this method produces - valid draws from the posterior. **Use at your own risk**. - - gradient_backend: str, default "pytensor" - The backend to use for gradient computations. Must be one of "pytensor" or "jax". - chains: int, default: 2 - The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel, - because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are - compatible with the ArviZ library. - draws: int, default: 500 - The number of samples to draw from the approximated posterior. Totals samples will be chains * draws. - on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore' - What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite. - If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned. - If 'error', an error will be raised. - zero_tol: float - Value below which an element of the Hessian matrix is counted as 0. - This is used to stabilize the computation of the inverse Hessian matrix. Default is 1e-8. - diag_jitter: float | None - A small value added to the diagonal of the inverse Hessian matrix to ensure it is positive semi-definite. - If None, no jitter is added. Default is 1e-8. - optimizer_kwargs - Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless - ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``, - ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details. - compile_kwargs: dict, optional - Additional keyword arguments to pass to pytensor.function. - - Returns - ------- - :class:`~arviz.InferenceData` - An InferenceData object containing the approximated posterior samples. 
- - Examples - -------- - >>> from pymc_extras.inference.laplace import fit_laplace - >>> import numpy as np - >>> import pymc as pm - >>> import arviz as az - >>> y = np.array([2642, 3503, 4358]*10) - >>> with pm.Model() as m: - >>> logsigma = pm.Uniform("logsigma", 1, 100) - >>> mu = pm.Uniform("mu", -10000, 10000) - >>> yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - >>> idata = fit_laplace() - - Notes - ----- - This method of approximation may not be suitable for all types of posterior distributions, - especially those with significant skewness or multimodality. - - See Also - -------- - fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m) - will forward the call to 'fit_laplace'. - - """ - compile_kwargs = {} if compile_kwargs is None else compile_kwargs - optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs - - optimized_point = find_MAP( - method=optimize_method, - model=model, - use_grad=use_grad, - use_hessp=use_hessp, - use_hess=use_hess, - initvals=initvals, - random_seed=random_seed, - return_raw=return_raw, - jitter_rvs=jitter_rvs, - progressbar=progressbar, - include_transformed=include_transformed, - gradient_backend=gradient_backend, - compile_kwargs=compile_kwargs, - **optimizer_kwargs, - ) - - mu, H_inv = fit_mvn_at_MAP( - optimized_point=optimized_point, - model=model, - on_bad_cov=on_bad_cov, - transform_samples=fit_in_unconstrained_space, - gradient_backend=gradient_backend, - zero_tol=zero_tol, - diag_jitter=diag_jitter, - compile_kwargs=compile_kwargs, - ) - - return sample_laplace_posterior( - mu=mu, - H_inv=H_inv, - model=model, - chains=chains, - draws=draws, - transform_samples=fit_in_unconstrained_space, - progressbar=progressbar, - random_seed=random_seed, - compile_kwargs=compile_kwargs, - ) diff --git a/pymc_extras/inference/laplace_approx/__init__.py b/pymc_extras/inference/laplace_approx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymc_extras/inference/laplace_approx/find_map.py b/pymc_extras/inference/laplace_approx/find_map.py new file mode 100644 index 000000000..930137e22 --- /dev/null +++ b/pymc_extras/inference/laplace_approx/find_map.py @@ -0,0 +1,347 @@ +import logging + +from collections.abc import Callable +from typing import Literal, cast + +import numpy as np +import pymc as pm + +from better_optimize import basinhopping, minimize +from better_optimize.constants import MINIMIZE_MODE_KWARGS, minimize_method +from pymc.blocking import DictToArrayBijection, RaveledVars +from pymc.initial_point import make_initial_point_fn +from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.util import get_default_varnames +from pytensor.tensor import TensorVariable +from scipy.optimize import OptimizeResult + +from pymc_extras.inference.laplace_approx.idata import ( + add_data_to_inference_data, + add_fit_to_inference_data, + add_optimizer_result_to_inference_data, + map_results_to_inference_data, +) +from pymc_extras.inference.laplace_approx.scipy_interface import ( + GradientBackend, + scipy_optimize_funcs_from_loss, +) + +_log = logging.getLogger(__name__) + + +def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp): + method_info = MINIMIZE_MODE_KWARGS[method].copy() + + if use_hess and use_hessp: + _log.warning( + 'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the ' + 'same time. When possible "use_hessp" is preferred because its is computationally more efficient. 
' + 'Setting "use_hess" to False.' + ) + use_hess = False + + use_grad = use_grad if use_grad is not None else method_info["uses_grad"] + + if use_hessp is not None and use_hess is None: + use_hess = not use_hessp + + elif use_hess is not None and use_hessp is None: + use_hessp = not use_hess + + elif use_hessp is None and use_hess is None: + use_hessp = method_info["uses_hessp"] + use_hess = method_info["uses_hess"] + if use_hessp and use_hess: + # If a method could use either hess or hessp, we default to using hessp + use_hess = False + + return use_grad, use_hess, use_hessp + + +def get_nearest_psd(A: np.ndarray) -> np.ndarray: + """ + Compute the nearest positive semi-definite matrix to a given matrix. + + This function takes a square matrix and returns the nearest positive semi-definite matrix using + eigenvalue decomposition. It ensures all eigenvalues are non-negative. The "nearest" matrix is defined in terms + of the Frobenius norm. + + Parameters + ---------- + A : np.ndarray + Input square matrix. + + Returns + ------- + np.ndarray + The nearest positive semi-definite matrix to the input matrix. + """ + C = (A + A.T) / 2 + eigval, eigvec = np.linalg.eigh(C) + eigval[eigval < 0] = 0 + + return eigvec @ np.diag(eigval) @ eigvec.T + + +def _make_initial_point(model, initvals=None, random_seed=None, jitter_rvs=None): + jitter_rvs = [] if jitter_rvs is None else jitter_rvs + + ipfn = make_initial_point_fn( + model=model, + jitter_rvs=set(jitter_rvs), + return_transformed=True, + overrides=initvals, + ) + + start_dict = ipfn(random_seed) + vars_dict = {var.name: var for var in model.continuous_value_vars} + initial_params = DictToArrayBijection.map( + {var_name: value for var_name, value in start_dict.items() if var_name in vars_dict} + ) + + return initial_params + + +def _compute_inverse_hessian( + optimizer_result: OptimizeResult | None, + optimal_point: np.ndarray | None, + f_fused: Callable | None, + f_hessp: Callable | None, + use_hess: bool, + method: minimize_method | Literal["BFGS", "L-BFGS-B"], +): + """ + Compute the Hessian matrix or its inverse based on the optimization result and the method used. + + Downstream functions (e.g. laplace approximation) will need the inverse Hessian matrix. This function computes it + in the cheapest way possible, depending on the optimization method used and the available compiled functions. + + Parameters + ---------- + optimizer_result: OptimizeResult, optional + The result of the optimization, containing the optimized parameters and possibly an approximate inverse Hessian. + optimal_point: np.ndarray, optional + The optimal point found by the optimizer, used to compute the Hessian if necessary. If not provided, it will be + extracted from the optimizer result. + f_fused: callable, optional + The compiled function representing the loss and possibly its gradient and Hessian. + f_hessp: callable, optional + The compiled function for Hessian-vector products, if available. + use_hess: bool + Whether the Hessian matrix was used in the optimization. + method: minimize_method + The optimization method used, which determines how the Hessian is computed. + + Returns + ------- + H_inv: np.ndarray + The inverse Hessian matrix, computed based on the optimization method and available functions. 
+ """ + if optimal_point is None and optimizer_result is None: + raise ValueError("At least one of `optimal_point` or `optimizer_result` must be provided.") + + x_star = optimizer_result.x if optimizer_result is not None else optimal_point + n_vars = len(x_star) + + if method == "BFGS" and optimizer_result is not None: + # If we used BFGS, the optimizer result will contain the inverse Hessian -- we can just use that rather than + # re-computing something + if hasattr(optimizer_result, "lowest_optimization_result"): + # We did basinhopping, need to get the inner optimizer results + H_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None) + else: + H_inv = getattr(optimizer_result, "hess_inv", None) + + elif method == "L-BFGS-B" and optimizer_result is not None: + # Here we will have a LinearOperator representing the inverse Hessian-Vector product. + if hasattr(optimizer_result, "lowest_optimization_result"): + # We did basinhopping, need to get the inner optimizer results + f_hessp_inv = getattr(optimizer_result.lowest_optimization_result, "hess_inv", None) + else: + f_hessp_inv = getattr(optimizer_result, "hess_inv", None) + + if f_hessp_inv is not None: + basis = np.eye(n_vars) + H_inv = np.stack([f_hessp_inv(basis[:, i]) for i in range(n_vars)], axis=-1) + else: + H_inv = None + + elif f_hessp is not None: + # In the case that hessp was used, the results object will not save the inverse Hessian, so we can compute it from + # the hessp function, using euclidian basis vector. + basis = np.eye(n_vars) + H = np.stack([f_hessp(x_star, basis[:, i]) for i in range(n_vars)], axis=-1) + H_inv = np.linalg.inv(get_nearest_psd(H)) + + elif use_hess and f_fused is not None: + # If we compiled a hessian function, just use it + _, _, H = f_fused(x_star) + H_inv = np.linalg.inv(get_nearest_psd(H)) + + else: + H_inv = None + + return H_inv + + +def find_MAP( + method: minimize_method | Literal["basinhopping"] = "L-BFGS-B", + *, + model: pm.Model | None = None, + use_grad: bool | None = None, + use_hessp: bool | None = None, + use_hess: bool | None = None, + initvals: dict | None = None, + random_seed: int | np.random.Generator | None = None, + jitter_rvs: list[TensorVariable] | None = None, + progressbar: bool = True, + include_transformed: bool = True, + gradient_backend: GradientBackend = "pytensor", + compile_kwargs: dict | None = None, + **optimizer_kwargs, +) -> ( + dict[str, np.ndarray] + | tuple[dict[str, np.ndarray], np.ndarray] + | tuple[dict[str, np.ndarray], OptimizeResult] + | tuple[dict[str, np.ndarray], OptimizeResult, np.ndarray] +): + """ + Fit a PyMC model via maximum a posteriori (MAP) estimation using JAX and scipy.optimize. + + Parameters + ---------- + model : pm.Model + The PyMC model to be fit. If None, the current model context is used. + method : str + The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP, + trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping. + + See scipy.optimize.minimize documentation for details. + use_grad : bool | None, optional + Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hessp : bool | None, optional + Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hess : bool | None, optional + Whether to use the Hessian matrix in the optimization. 
Defaults to None, which determines this automatically based on + the ``method``. + initvals : None | dict, optional + Initial values for the model parameters, as str:ndarray key-value pairs. Partial initialization is permitted. + If None, the model's default initial values are used. + random_seed : None | int | np.random.Generator, optional + Seed for the random number generator or a numpy Generator for reproducibility + jitter_rvs : list of TensorVariables, optional + Variables whose initial values should be jittered. If None, all variables are jittered. + progressbar : bool, optional + Whether to display a progress bar during optimization. Defaults to True. + include_transformed: bool, optional + Whether to include transformed variable values in the returned dictionary. Defaults to True. + gradient_backend: str, default "pytensor" + Which backend to use to compute gradients. Must be one of "pytensor" or "jax". + compile_kwargs: dict, optional + Additional options to pass to the ``pytensor.function`` function when compiling loss functions. + **optimizer_kwargs + Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless + ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``, + ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details. + + Returns + ------- + map_result: az.InferenceData + Results of Maximum A Posteriori (MAP) estimation, including the optimized point, inverse Hessian, transformed + latent variables, and optimizer results. + """ + model = pm.modelcontext(model) if model is None else model + frozen_model = freeze_dims_and_data(model) + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs) + + do_basinhopping = method == "basinhopping" + minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {}) + + if do_basinhopping: + # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need + # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default + # if one isn't provided. + + method = minimizer_kwargs.pop("method", "L-BFGS-B") + minimizer_kwargs["method"] = method + + use_grad, use_hess, use_hessp = set_optimizer_function_defaults( + method, use_grad, use_hess, use_hessp + ) + + f_fused, f_hessp = scipy_optimize_funcs_from_loss( + loss=-frozen_model.logp(), + inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, + initial_point_dict=DictToArrayBijection.rmap(initial_params), + use_grad=use_grad, + use_hess=use_hess, + use_hessp=use_hessp, + gradient_backend=gradient_backend, + compile_kwargs=compile_kwargs, + ) + + args = optimizer_kwargs.pop("args", ()) + + # better_optimize.minimize will check if f_logp is a fused loss+grad Op, and automatically assign the jac argument + # if so. That is why the jac argument is not passed here in either branch. 
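# --- Editor's illustrative sketch (not part of the patch) ---------------------
# The comment above refers to the fused value-and-gradient convention: a single
# callable returns (loss, grad), and the optimizer is told jac=True so the
# gradient is reused instead of being recomputed. Shown here with plain
# scipy.optimize.minimize and a toy quadratic loss standing in for the compiled
# negative log-probability; all names below are illustrative assumptions.
import numpy as np
from scipy.optimize import minimize as sp_minimize

def fused_loss_and_grad(z):
    # returns (loss, gradient of the loss) as a single fused evaluation
    return 0.5 * float(np.dot(z, z)), z

res = sp_minimize(fused_loss_and_grad, x0=np.ones(4), jac=True, method="L-BFGS-B")
assert np.allclose(res.x, 0.0, atol=1e-4)  # quadratic loss is minimized at the origin
# ------------------------------------------------------------------------------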
+ + if do_basinhopping: + if "args" not in minimizer_kwargs: + minimizer_kwargs["args"] = args + if "hessp" not in minimizer_kwargs: + minimizer_kwargs["hessp"] = f_hessp + if "method" not in minimizer_kwargs: + minimizer_kwargs["method"] = method + + optimizer_result = basinhopping( + func=f_fused, + x0=cast(np.ndarray[float], initial_params.data), + progressbar=progressbar, + minimizer_kwargs=minimizer_kwargs, + **optimizer_kwargs, + ) + + else: + optimizer_result = minimize( + f=f_fused, + x0=cast(np.ndarray[float], initial_params.data), + args=args, + hessp=f_hessp, + progressbar=progressbar, + method=method, + **optimizer_kwargs, + ) + + H_inv = _compute_inverse_hessian( + optimizer_result=optimizer_result, + optimal_point=None, + f_fused=f_fused, + f_hessp=f_hessp, + use_hess=use_hess, + method=method, + ) + + raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info) + unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed) + unobserved_vars_values = model.compile_fn(unobserved_vars, mode="FAST_COMPILE")( + DictToArrayBijection.rmap(raveled_optimized) + ) + + optimized_point = { + var.name: value for var, value in zip(unobserved_vars, unobserved_vars_values) + } + + idata = map_results_to_inference_data(optimized_point, frozen_model) + idata = add_fit_to_inference_data(idata, raveled_optimized, H_inv) + idata = add_optimizer_result_to_inference_data( + idata, optimizer_result, method, raveled_optimized, model + ) + idata = add_data_to_inference_data( + idata, progressbar=False, model=model, compile_kwargs=compile_kwargs + ) + + return idata diff --git a/pymc_extras/inference/laplace_approx/idata.py b/pymc_extras/inference/laplace_approx/idata.py new file mode 100644 index 000000000..edf011dd4 --- /dev/null +++ b/pymc_extras/inference/laplace_approx/idata.py @@ -0,0 +1,392 @@ +from itertools import product +from typing import Literal + +import arviz as az +import numpy as np +import pymc as pm +import xarray as xr + +from arviz import dict_to_dataset +from better_optimize.constants import minimize_method +from pymc.backends.arviz import coords_and_dims_for_inferencedata, find_constants, find_observations +from pymc.blocking import RaveledVars +from pymc.util import get_default_varnames +from scipy.optimize import OptimizeResult +from scipy.sparse.linalg import LinearOperator + + +def make_default_labels(name: str, shape: tuple[int, ...]) -> list: + if len(shape) == 0: + return [name] + + return [list(range(dim)) for dim in shape] + + +def make_unpacked_variable_names(names: list[str], model: pm.Model) -> list[str]: + coords = model.coords + initial_point = model.initial_point() + + value_to_dim = { + value.name: model.named_vars_to_dims.get(model.values_to_rvs[value].name, None) + for value in model.value_vars + } + value_to_dim = {k: v for k, v in value_to_dim.items() if v is not None} + + rv_to_dim = model.named_vars_to_dims + dims_dict = rv_to_dim | value_to_dim + + unpacked_variable_names = [] + for name in names: + shape = initial_point[name].shape + if shape: + dims = dims_dict.get(name) + if dims: + labels_by_dim = [ + coords[dim] if shape[i] == len(coords[dim]) else np.arange(shape[i]) + for i, dim in enumerate(dims) + ] + else: + labels_by_dim = make_default_labels(name, shape) + labels = product(*labels_by_dim) + unpacked_variable_names.extend( + [f"{name}[{','.join(map(str, label))}]" for label in labels] + ) + else: + unpacked_variable_names.extend([name]) + return unpacked_variable_names + + +def 
map_results_to_inference_data(
+    map_point: dict[str, float | int | np.ndarray],
+    model: pm.Model | None = None,
+):
+    """
+    Create an InferenceData object whose posterior group holds the MAP point.
+
+    Unlike a typical posterior, the MAP point is a single point estimate rather than a distribution. As a result, it
+    does not have a chain or draw dimension, and is stored as a single point in the posterior group.
+
+    Parameters
+    ----------
+    map_point: dict
+        A dictionary containing the MAP point estimates for each variable. The keys should be the variable names, and
+        the values should be the corresponding MAP estimates.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        A new InferenceData object with the MAP point stored in the posterior group and, when the model has
+        transformed variables, in an additional "unconstrained_posterior" group.
+    """
+
+    model = pm.modelcontext(model) if model is None else model
+    coords, dims = coords_and_dims_for_inferencedata(model)
+    initial_point = model.initial_point()
+
+    # The MAP point will have both the transformed and untransformed variables, so we need to ensure that
+    # we have the correct dimensions for each variable.
+    var_name_to_value_name = {
+        rv.name: value.name
+        for rv, value in model.rvs_to_values.items()
+        if rv not in model.observed_RVs
+    }
+    dims.update(
+        {
+            value_name: dims[var_name]
+            for var_name, value_name in var_name_to_value_name.items()
+            if var_name in dims and (initial_point[value_name].shape == map_point[var_name].shape)
+        }
+    )
+
+    constrained_names = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
+    ]
+    all_varnames = [
+        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
+    ]
+
+    unconstrained_names = set(all_varnames) - set(constrained_names)
+
+    idata = az.from_dict(
+        posterior={
+            k: np.expand_dims(v, (0, 1)) for k, v in map_point.items() if k in constrained_names
+        },
+        coords=coords,
+        dims=dims,
+    )
+
+    if unconstrained_names:
+        unconstrained_posterior = az.from_dict(
+            posterior={
+                k: np.expand_dims(v, (0, 1))
+                for k, v in map_point.items()
+                if k in unconstrained_names
+            },
+            coords=coords,
+            dims=dims,
+        )
+
+        idata["unconstrained_posterior"] = unconstrained_posterior.posterior
+
+    return idata
+
+
+def add_fit_to_inference_data(
+    idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
+) -> az.InferenceData:
+    """
+    Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
+
+    Parameters
+    ----------
+    idata: az.InferenceData
+        An InferenceData object containing the approximated posterior samples.
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
+    model: Model, optional
+        A PyMC model. If None, the model is taken from the current model context.
+
+    Returns
+    -------
+    idata: az.InferenceData
+        The provided InferenceData, with the mean vector and covariance matrix added to the "fit" group.
+ """ + model = pm.modelcontext(model) if model is None else model + + variable_names, *_ = zip(*mu.point_map_info) + + unpacked_variable_names = make_unpacked_variable_names(variable_names, model) + + mean_dataarray = xr.DataArray(mu.data, dims=["rows"], coords={"rows": unpacked_variable_names}) + + data = {"mean_vector": mean_dataarray} + + if H_inv is not None: + cov_dataarray = xr.DataArray( + H_inv, + dims=["rows", "columns"], + coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, + ) + data["covariance_matrix"] = cov_dataarray + + dataset = xr.Dataset(data) + idata.add_groups(fit=dataset) + + return idata + + +def add_data_to_inference_data( + idata: az.InferenceData, + progressbar: bool = True, + model: pm.Model | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Add observed and constant data to an InferenceData object. + + Parameters + ---------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. + progressbar: bool + Whether to display a progress bar during computations. Default is True. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + compile_kwargs: dict, optional + Additional keyword arguments to pass to pytensor.function. + + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with observed and constant data added. + """ + model = pm.modelcontext(model) if model is None else model + + if model.deterministics: + expand_dims = {} + if "chain" not in idata.posterior.coords: + expand_dims["chain"] = [0] + if "draw" not in idata.posterior.coords: + expand_dims["draw"] = [0] + + idata.posterior = pm.compute_deterministics( + idata.posterior.expand_dims(expand_dims), + model=model, + merge_dataset=True, + progressbar=progressbar, + compile_kwargs=compile_kwargs, + ) + + coords, dims = coords_and_dims_for_inferencedata(model) + + observed_data = dict_to_dataset( + find_observations(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + + constant_data = dict_to_dataset( + find_constants(model), + library=pm, + coords=coords, + dims=dims, + default_dims=[], + ) + + idata.add_groups( + {"observed_data": observed_data, "constant_data": constant_data}, + coords=coords, + dims=dims, + ) + + return idata + + +def optimizer_result_to_dataset( + result: OptimizeResult, + method: minimize_method | Literal["basinhopping"], + mu: RaveledVars | None = None, + model: pm.Model | None = None, +) -> xr.Dataset: + """ + Convert an OptimizeResult object to an xarray Dataset object. + + Parameters + ---------- + result: OptimizeResult + The result of the optimization process. + method: minimize_method or "basinhopping" + The optimization method used. + + Returns + ------- + dataset: xr.Dataset + An xarray Dataset containing the optimization results. + """ + if not isinstance(result, OptimizeResult): + raise TypeError("result must be an instance of OptimizeResult") + + model = pm.modelcontext(model) if model is None else model + variable_names, *_ = zip(*mu.point_map_info) + unpacked_variable_names = make_unpacked_variable_names(variable_names, model) + + data_vars = {} + + if hasattr(result, "lowest_optimization_result"): + # If we did basinhopping, there's a results inside the results. 
We want to pop this out and collapse them, + # overwriting outer keys with the inner keys + inner_res = result.pop("lowest_optimization_result") + for key in inner_res.keys(): + result[key] = inner_res[key] + + if hasattr(result, "x"): + data_vars["x"] = xr.DataArray( + result.x, dims=["variables"], coords={"variables": unpacked_variable_names} + ) + if hasattr(result, "fun"): + data_vars["fun"] = xr.DataArray(result.fun, dims=[]) + if hasattr(result, "success"): + data_vars["success"] = xr.DataArray(result.success, dims=[]) + if hasattr(result, "message"): + data_vars["message"] = xr.DataArray(str(result.message), dims=[]) + if hasattr(result, "jac") and result.jac is not None: + jac = np.asarray(result.jac) + if jac.ndim == 1: + data_vars["jac"] = xr.DataArray( + jac, dims=["variables"], coords={"variables": unpacked_variable_names} + ) + else: + data_vars["jac"] = xr.DataArray( + jac, + dims=["variables", "variables_aux"], + coords={ + "variables": unpacked_variable_names, + "variables_aux": unpacked_variable_names, + }, + ) + + if hasattr(result, "hess_inv") and result.hess_inv is not None: + hess_inv = result.hess_inv + if isinstance(hess_inv, LinearOperator): + n = hess_inv.shape[0] + eye = np.eye(n) + hess_inv_mat = np.column_stack([hess_inv.matvec(eye[:, i]) for i in range(n)]) + hess_inv = hess_inv_mat + else: + hess_inv = np.asarray(hess_inv) + data_vars["hess_inv"] = xr.DataArray( + hess_inv, + dims=["variables", "variables_aux"], + coords={"variables": unpacked_variable_names, "variables_aux": unpacked_variable_names}, + ) + + if hasattr(result, "nit"): + data_vars["nit"] = xr.DataArray(result.nit, dims=[]) + if hasattr(result, "nfev"): + data_vars["nfev"] = xr.DataArray(result.nfev, dims=[]) + if hasattr(result, "njev"): + data_vars["njev"] = xr.DataArray(result.njev, dims=[]) + if hasattr(result, "status"): + data_vars["status"] = xr.DataArray(result.status, dims=[]) + + # Add any other fields present in result + for key, value in result.items(): + if key in data_vars: + continue # already added + if value is None: + continue + arr = np.asarray(value) + + # TODO: We can probably do something smarter here with a dictionary of all possible values and their expected + # dimensions. + dims = [f"{key}_dim_{i}" for i in range(arr.ndim)] + data_vars[key] = xr.DataArray( + arr, + dims=dims, + coords={f"{key}_dim_{i}": np.arange(arr.shape[i]) for i in range(len(dims))}, + ) + + data_vars["method"] = xr.DataArray(np.array(method), dims=[]) + + return xr.Dataset(data_vars) + + +def add_optimizer_result_to_inference_data( + idata: az.InferenceData, + result: OptimizeResult, + method: minimize_method | Literal["basinhopping"], + mu: RaveledVars | None = None, + model: pm.Model | None = None, +) -> az.InferenceData: + """ + Add the optimization result to an InferenceData object. + + Parameters + ---------- + idata: az.InferenceData + An InferenceData object containing the approximated posterior samples. + result: OptimizeResult + The result of the optimization process. + method: minimize_method or "basinhopping" + The optimization method used. + mu: RaveledVars, optional + The MAP estimate of the model parameters. + model: Model, optional + A PyMC model. If None, the model is taken from the current model context. + + Returns + ------- + idata: az.InferenceData + The provided InferenceData, with the optimization results added to the "optimizer" group. 
+ """ + dataset = optimizer_result_to_dataset(result, method=method, mu=mu, model=model) + idata.add_groups({"optimizer_result": dataset}) + + return idata diff --git a/pymc_extras/inference/laplace_approx/laplace.py b/pymc_extras/inference/laplace_approx/laplace.py new file mode 100644 index 000000000..2b5ef6a16 --- /dev/null +++ b/pymc_extras/inference/laplace_approx/laplace.py @@ -0,0 +1,451 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging + +from collections.abc import Callable +from functools import partial +from typing import Literal +from typing import cast as type_cast + +import arviz as az +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt +import xarray as xr + +from better_optimize.constants import minimize_method +from numpy.typing import ArrayLike +from pymc.blocking import DictToArrayBijection +from pymc.model.transform.optimization import freeze_dims_and_data +from pymc.pytensorf import join_nonshared_inputs +from pymc.util import get_default_varnames +from pytensor.graph import vectorize_graph +from pytensor.tensor import TensorVariable +from pytensor.tensor.optimize import minimize +from pytensor.tensor.type import Variable + +from pymc_extras.inference.laplace_approx.find_map import ( + _compute_inverse_hessian, + _make_initial_point, + find_MAP, +) +from pymc_extras.inference.laplace_approx.scipy_interface import ( + GradientBackend, + scipy_optimize_funcs_from_loss, +) + +_log = logging.getLogger(__name__) + + +def get_conditional_gaussian_approximation( + x: TensorVariable, + Q: TensorVariable | ArrayLike, + mu: TensorVariable | ArrayLike, + args: list[TensorVariable] | None = None, + model: pm.Model | None = None, + method: minimize_method = "BFGS", + use_jac: bool = True, + use_hess: bool = False, + optimizer_kwargs: dict | None = None, +) -> Callable: + """ + Returns a function to estimate the a posteriori log probability of a latent Gaussian field x and its mode x0 using the Laplace approximation. + + That is: + y | x, sigma ~ N(Ax, sigma^2 W) + x | params ~ N(mu, Q(params)^-1) + + We seek to estimate log(p(x | y, params)): + + log(p(x | y, params)) = log(p(y | x, params)) + log(p(x | params)) + const + + Let f(x) = log(p(y | x, params)). From the definition of our model above, we have log(p(x | params)) = -0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). + + This gives log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) + 0.5*logdet(Q). We will estimate this using the Laplace approximation by Taylor expanding f(x) about the mode. + + Thus: + + 1. Maximize log(p(x | y, params)) = f(x) - 0.5*(x - mu).T Q (x - mu) wrt x (note that logdet(Q) does not depend on x) to find the mode x0. + + 2. Substitute x0 into the Laplace approximation expanded about the mode: log(p(x | y, params)) ~= -0.5*x.T (-f''(x0) + Q) x + x.T (Q.mu + f'(x0) - f''(x0).x0) + 0.5*logdet(Q). + + Parameters + ---------- + x: TensorVariable + The parameter with which to maximize wrt (that is, find the mode in x). 
In INLA, this is the latent field x~N(mu,Q^-1).
+    Q: TensorVariable | ArrayLike
+        The precision matrix of the latent field x.
+    mu: TensorVariable | ArrayLike
+        The mean of the latent field x.
+    args: list[TensorVariable]
+        Args to supply to the compiled function. That is, (x0, logp) = f(x, *args). If None, the model's continuous
+        and discrete value variables are used.
+    model: Model
+        PyMC model to use.
+    method: minimize_method
+        Which minimization algorithm to use.
+    use_jac: bool
+        If True, the minimizer will compute the gradient of log(p(x | y, params)).
+    use_hess: bool
+        If True, the minimizer will compute the Hessian of log(p(x | y, params)).
+    optimizer_kwargs: dict
+        Kwargs to pass to scipy.optimize.minimize.
+
+    Returns
+    -------
+    f: Callable
+        A function which accepts a value of x and args and returns [x0, log(p(x | y, params))], where x0 is the mode. x is currently both the point at which to evaluate logp and the initial guess for the minimizer.
+    """
+    model = pm.modelcontext(model)
+
+    if args is None:
+        args = model.continuous_value_vars + model.discrete_value_vars
+
+    # f = log(p(y | x, params))
+    f_x = model.logp()
+    jac = pytensor.gradient.grad(f_x, x)
+    hess = pytensor.gradient.jacobian(jac.flatten(), x)
+
+    # log(p(x | y, params)) only including terms that depend on x for the minimization step (logdet(Q) ignored as it is a constant wrt x)
+    log_x_posterior = f_x - 0.5 * (x - mu).T @ Q @ (x - mu)
+
+    # Maximize log(p(x | y, params)) wrt x to find mode x0
+    x0, _ = minimize(
+        objective=-log_x_posterior,
+        x=x,
+        method=method,
+        jac=use_jac,
+        hess=use_hess,
+        optimizer_kwargs=optimizer_kwargs,
+    )
+
+    # require f'(x0) and f''(x0) for Laplace approx
+    jac = pytensor.graph.replace.graph_replace(jac, {x: x0})
+    hess = pytensor.graph.replace.graph_replace(hess, {x: x0})
+
+    # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
+    _, logdetQ = pt.nlinalg.slogdet(Q)
+    conditional_gaussian_approx = (
+        -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
+    )
+
+    # Currently x is passed both as the query point for f(x, args) = logp(x | y, params) AND as an initial guess for x0. This may cause issues if the query point is
+    # far from the mode x0 or in a neighbourhood which results in poor convergence.
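+    # Hypothetical usage sketch (variable names below are illustrative, not part of the API):
+    #     f = get_conditional_gaussian_approximation(x, Q, mu, model=m)
+    #     x0, log_p = f(x_value, *arg_values)
+    # As noted above, x_value is both the evaluation point and the optimizer's starting point.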
+ return pytensor.function(args, [x0, conditional_gaussian_approx]) + + +def _unconstrained_vector_to_constrained_rvs(model): + outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True) + constrained_names = [ + x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False) + ] + names = [x.name for x in outputs] + + unconstrained_names = [name for name in names if name not in constrained_names] + + new_outputs, unconstrained_vector = join_nonshared_inputs( + model.initial_point(), + inputs=model.value_vars, + outputs=outputs, + ) + + constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names] + value_rvs = [x for x in new_outputs if x not in constrained_rvs] + + unconstrained_vector.name = "unconstrained_vector" + + # Redo the names list to ensure it is sorted to match the return order + names = [*constrained_names, *unconstrained_names] + + return names, constrained_rvs, value_rvs, unconstrained_vector + + +def model_to_laplace_approx( + model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500 +): + initial_point = model.initial_point() + raveled_vars = DictToArrayBijection.map(initial_point) + raveled_shape = raveled_vars.data.shape[0] + + # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov, + # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved. + names, constrained_rvs, value_rvs, unconstrained_vector = ( + _unconstrained_vector_to_constrained_rvs(model) + ) + + coords = model.coords | { + "temp_chain": np.arange(chains), + "temp_draw": np.arange(draws), + "unpacked_variable_names": unpacked_variable_names, + } + + with pm.Model(coords=coords, model=None) as laplace_model: + mu = pm.Flat("mean_vector", shape=(raveled_shape,)) + cov = pm.Flat("covariance_matrix", shape=(raveled_shape, raveled_shape)) + laplace_approximation = pm.MvNormal( + "laplace_approximation", + mu=mu, + cov=cov, + dims=["temp_chain", "temp_draw", "unpacked_variable_names"], + method="svd", + ) + + cast_to_var = partial(type_cast, Variable) + batched_rvs = vectorize_graph( + type_cast(list[Variable], constrained_rvs), + replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)}, + ) + + for name, batched_rv in zip(names, batched_rvs): + batch_dims = ("temp_chain", "temp_draw") + if batched_rv.ndim == 2: + dims = batch_dims + elif name in model.named_vars_to_dims: + dims = (*batch_dims, *model.named_vars_to_dims[name]) + else: + dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)]) + initval = initial_point.get(name, None) + dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:] + laplace_model.add_coords( + {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)} + ) + + pm.Deterministic(name, batched_rv, dims=dims) + + return laplace_model + + +def unstack_laplace_draws(laplace_data, model, chains=2, draws=500): + """ + The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are + in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a + single vector, it's not easy to work with. + + This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and + coordinates, where possible. 
+ """ + initial_point = DictToArrayBijection.map(model.initial_point()) + + cursor = 0 + unstacked_laplace_draws = {} + coords = model.coords | {"chain": range(chains), "draw": range(draws)} + + # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g. + # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just + # add an arviz-style default dim and label. + for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info): + rv_dims = [] + for i, dim in enumerate( + model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))]) + ): + if coords.get(dim) and shape[i] == len(coords[dim]): + rv_dims.append(dim) + else: + rv_dims.append(f"{name}_dim_{i}") + coords[f"{name}_dim_{i}"] = np.arange(shape[i]) + + dims = ("chain", "draw", *rv_dims) + + values = ( + laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype) + ) + unstacked_laplace_draws[name] = xr.DataArray( + values, dims=dims, coords={dim: list(coords[dim]) for dim in dims} + ) + + cursor += size + + unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws) + + return unstacked_laplace_draws + + +def fit_laplace( + optimize_method: minimize_method | Literal["basinhopping"] = "BFGS", + *, + model: pm.Model | None = None, + use_grad: bool | None = None, + use_hessp: bool | None = None, + use_hess: bool | None = None, + initvals: dict | None = None, + random_seed: int | np.random.Generator | None = None, + jitter_rvs: list[pt.TensorVariable] | None = None, + progressbar: bool = True, + include_transformed: bool = True, + gradient_backend: GradientBackend = "pytensor", + chains: int = 2, + draws: int = 500, + optimizer_kwargs: dict | None = None, + compile_kwargs: dict | None = None, +) -> az.InferenceData: + """ + Create a Laplace (quadratic) approximation for a posterior distribution. + + This function generates a Laplace approximation for a given posterior distribution using a specified + number of draws. This is useful for obtaining a parametric approximation to the posterior distribution + that can be used for further analysis. + + Parameters + ---------- + model : pm.Model + The PyMC model to be fit. If None, the current model context is used. + method : str + The optimization method to use. Valid choices are: Nelder-Mead, Powell, CG, BFGS, L-BFGS-B, TNC, SLSQP, + trust-constr, dogleg, trust-ncg, trust-exact, trust-krylov, and basinhopping. + + See scipy.optimize.minimize documentation for details. + use_grad : bool | None, optional + Whether to use gradients in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hessp : bool | None, optional + Whether to use Hessian-vector products in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + use_hess : bool | None, optional + Whether to use the Hessian matrix in the optimization. Defaults to None, which determines this automatically based on + the ``method``. + initvals : None | dict, optional + Initial values for the model parameters, as str:ndarray key-value pairs. Paritial initialization is permitted. + If None, the model's default initial values are used. + random_seed : None | int | np.random.Generator, optional + Seed for the random number generator or a numpy Generator for reproducibility + jitter_rvs : list of TensorVariables, optional + Variables whose initial values should be jittered. 
If None, all variables are jittered.
+    progressbar : bool, optional
+        Whether to display a progress bar during optimization. Defaults to True.
+    include_transformed: bool, default True
+        Whether to include transformed variables in the output. If True, transformed variables will be included in the
+        output InferenceData object. If False, only the original variables will be included.
+    gradient_backend: str, default "pytensor"
+        The backend to use for gradient computations. Must be one of "pytensor" or "jax".
+    chains: int, default: 2
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
+    draws: int, default: 500
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
+    optimizer_kwargs
+        Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
+        ``optimize_method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
+        ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    compile_kwargs: dict, optional
+        Additional keyword arguments to pass to pytensor.function.
+
+    Returns
+    -------
+    :class:`~arviz.InferenceData`
+        An InferenceData object containing the approximated posterior samples.
+
+    Examples
+    --------
+    >>> from pymc_extras.inference import fit_laplace
+    >>> import numpy as np
+    >>> import pymc as pm
+    >>> import arviz as az
+    >>> y = np.array([2642, 3503, 4358]*10)
+    >>> with pm.Model() as m:
+    ...     logsigma = pm.Uniform("logsigma", 1, 100)
+    ...     mu = pm.Uniform("mu", -10000, 10000)
+    ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+    ...     idata = fit_laplace()
+
+    Notes
+    -----
+    This method of approximation may not be suitable for all types of posterior distributions,
+    especially those with significant skewness or multimodality.
+
+    See Also
+    --------
+    fit : Calling the inference function 'fit' like pmx.fit(method="laplace", model=m)
+        will forward the call to 'fit_laplace'.
+
+    """
+    compile_kwargs = {} if compile_kwargs is None else compile_kwargs
+    optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
+    model = pm.modelcontext(model) if model is None else model
+
+    idata = find_MAP(
+        method=optimize_method,
+        model=model,
+        use_grad=use_grad,
+        use_hessp=use_hessp,
+        use_hess=use_hess,
+        initvals=initvals,
+        random_seed=random_seed,
+        jitter_rvs=jitter_rvs,
+        progressbar=progressbar,
+        include_transformed=include_transformed,
+        gradient_backend=gradient_backend,
+        compile_kwargs=compile_kwargs,
+        **optimizer_kwargs,
+    )
+
+    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
+
+    if "covariance_matrix" not in idata.fit:
+        # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
+        # we have to go back and compute the Hessian at the MAP point now.
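+        # Only a Hessian-vector product is compiled below; `_compute_inverse_hessian` is then expected to rebuild
+        # the full matrix from it (e.g. by applying f_hessp to basis vectors), so the dense Hessian graph never
+        # needs to be compiled.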
+ frozen_model = freeze_dims_and_data(model) + initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs) + + _, f_hessp = scipy_optimize_funcs_from_loss( + loss=-frozen_model.logp(jacobian=False), + inputs=frozen_model.continuous_value_vars + frozen_model.discrete_value_vars, + initial_point_dict=DictToArrayBijection.rmap(initial_params), + use_grad=False, + use_hess=False, + use_hessp=True, + gradient_backend=gradient_backend, + compile_kwargs=compile_kwargs, + ) + H_inv = _compute_inverse_hessian( + optimizer_result=None, + optimal_point=idata.fit.mean_vector.values, + f_fused=None, + f_hessp=f_hessp, + use_hess=False, + method=optimize_method, + ) + + idata.fit["covariance_matrix"] = xr.DataArray( + H_inv, + dims=("rows", "columns"), + coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names}, + ) + + with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model: + new_posterior = ( + pm.sample_posterior_predictive( + idata.fit.expand_dims(chain=[0], draw=[0]), + extend_inferencedata=False, + random_seed=random_seed, + var_names=[ + "laplace_approximation", + *[x.name for x in laplace_model.deterministics], + ], + ) + .posterior_predictive.squeeze(["chain", "draw"]) + .drop_vars(["chain", "draw"]) + .rename({"temp_chain": "chain", "temp_draw": "draw"}) + ) + + idata.unconstrained_posterior = unstack_laplace_draws( + new_posterior.laplace_approximation.values, model, chains=chains, draws=draws + ) + idata.posterior = new_posterior.drop_vars( + ["laplace_approximation", "unpacked_variable_names"] + ) + + return idata diff --git a/pymc_extras/inference/laplace_approx/scipy_interface.py b/pymc_extras/inference/laplace_approx/scipy_interface.py new file mode 100644 index 000000000..a7489be3e --- /dev/null +++ b/pymc_extras/inference/laplace_approx/scipy_interface.py @@ -0,0 +1,242 @@ +from collections.abc import Callable +from importlib.util import find_spec +from typing import Literal, get_args + +import numpy as np +import pymc as pm +import pytensor + +from pymc import join_nonshared_inputs +from pytensor import tensor as pt +from pytensor.compile import Function +from pytensor.tensor import TensorVariable + +GradientBackend = Literal["pytensor", "jax"] +VALID_BACKENDS = get_args(GradientBackend) + + +def _compile_grad_and_hess_to_jax( + f_fused: Function, use_hess: bool, use_hessp: bool +) -> tuple[Callable | None, Callable | None]: + """ + Compile loss function gradients using JAX. + + Parameters + ---------- + f_fused: Function + The loss function to compile gradients for. Expected to be a pytensor function that returns a scalar loss, + compiled with mode="JAX". + use_hess: bool + Whether to compile a function to compute the hessian of the loss function. + use_hessp: bool + Whether to compile a function to compute the hessian-vector product of the loss function. + + Returns + ------- + f_fused: Callable + The compiled loss function and gradient function, which may also compute the hessian if requested. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. 
+ """ + import jax + + f_hessp = None + + orig_loss_fn = f_fused.vm.jit_fn + + if use_hess: + + @jax.jit + def loss_fn_fused(x): + loss_and_grad = jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + hess = jax.hessian(lambda x: orig_loss_fn(x)[0])(x) + return *loss_and_grad, hess + + else: + + @jax.jit + def loss_fn_fused(x): + return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x) + + if use_hessp: + + def f_hessp_jax(x, p): + y, u = jax.jvp(lambda x: loss_fn_fused(x)[1], (x,), (p,)) + return jax.numpy.stack(u) + + f_hessp = jax.jit(f_hessp_jax) + + return loss_fn_fused, f_hessp + + +def _compile_functions_for_scipy_optimize( + loss: TensorVariable, + inputs: list[TensorVariable], + compute_grad: bool, + compute_hess: bool, + compute_hessp: bool, + compile_kwargs: dict | None = None, +) -> list[Function] | list[Function, Function | None, Function | None]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + A single flat vector input variable, collecting all inputs to the loss function. Scipy optimize routines + expect the function signature to be f(x, *args), where x is a 1D array of parameters. + compute_grad: bool + Whether to compile a function that computes the gradients of the loss function. + compute_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + compute_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + compile_kwargs: dict, optional + Additional keyword arguments to pass to the ``pm.compile`` function. + + Returns + ------- + f_fused: Function + The compiled loss function, which may also include gradients and hessian if requested. + f_hessp: Function | None + The compiled hessian-vector product function, or None if compute_hessp is False. + """ + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + loss = pm.pytensorf.rewrite_pregrad(loss) + f_hessp = None + + # In the simplest case, we only compile the loss function. Return it as a list to keep the return type consistent + # with the case where we also compute gradients, hessians, or hessian-vector products. + if not (compute_grad or compute_hess or compute_hessp): + f_loss = pm.compile(inputs, loss, **compile_kwargs) + return [f_loss] + + # Otherwise there are three cases. If the user only wants the loss function and gradients, we compile a single + # fused function and return it. If the user also wants the hessian, the fused function will return the loss, + # gradients and hessian. If the user wants gradients and hess_p, we return a fused function that returns the loss + # and gradients, and a separate function for the hessian-vector product. + + if compute_hessp: + # Handle this first, since it can be compiled alone. 
+ p = pt.tensor("p", shape=inputs[0].type.shape) + hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p) + f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs) + + outputs = [loss] + + if compute_grad: + grads = pytensor.gradient.grad(loss, inputs) + grad = pt.concatenate([grad.ravel() for grad in grads]) + outputs.append(grad) + + if compute_hess: + hess = pytensor.gradient.jacobian(grad, inputs)[0] + outputs.append(hess) + + f_fused = pm.compile(inputs, outputs, **compile_kwargs) + + return [f_fused, f_hessp] + + +def scipy_optimize_funcs_from_loss( + loss: TensorVariable, + inputs: list[TensorVariable], + initial_point_dict: dict[str, np.ndarray | float | int], + use_grad: bool, + use_hess: bool, + use_hessp: bool, + gradient_backend: GradientBackend = "pytensor", + compile_kwargs: dict | None = None, +) -> tuple[Callable, ...]: + """ + Compile loss functions for use with scipy.optimize.minimize. + + Parameters + ---------- + loss: TensorVariable + The loss function to compile. + inputs: list[TensorVariable] + The input variables to the loss function. + initial_point_dict: dict[str, np.ndarray | float | int] + Dictionary mapping variable names to initial values. Used to determine the shapes of the input variables. + use_grad: bool + Whether to compile a function that computes the gradients of the loss function. + use_hess: bool + Whether to compile a function that computes the Hessian of the loss function. + use_hessp: bool + Whether to compile a function that computes the Hessian-vector product of the loss function. + gradient_backend: str, default "pytensor" + Which backend to use to compute gradients. Must be one of "jax" or "pytensor" + compile_kwargs: + Additional keyword arguments to pass to the ``pm.compile`` function. + + Returns + ------- + f_fused: Callable + The compiled loss function, which may also include gradients and hessian if requested. + f_hessp: Callable | None + The compiled hessian-vector product function, or None if use_hessp is False. + """ + + compile_kwargs = {} if compile_kwargs is None else compile_kwargs + + if use_hess and not use_grad: + raise ValueError("Cannot compute hessian without also computing the gradient") + + if gradient_backend not in VALID_BACKENDS: + raise ValueError( + f"Invalid gradient backend: {gradient_backend}. Must be one of {VALID_BACKENDS}" + ) + + use_jax_gradients = (gradient_backend == "jax") and use_grad + if use_jax_gradients and not find_spec("jax"): + raise ImportError("JAX must be installed to use JAX gradients") + + mode = compile_kwargs.get("mode", None) + if mode is None and use_jax_gradients: + compile_kwargs["mode"] = "JAX" + elif mode != "JAX" and use_jax_gradients: + raise ValueError( + 'jax gradients can only be used when ``compile_kwargs["mode"]`` is set to "JAX"' + ) + + if not isinstance(inputs, list): + inputs = [inputs] + + [loss], flat_input = join_nonshared_inputs( + point=initial_point_dict, outputs=[loss], inputs=inputs + ) + + # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When + # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them + # away. 
+ if use_jax_gradients: + from pymc.sampling.jax import _replace_shared_variables + + [loss] = _replace_shared_variables([loss]) + + compute_grad = use_grad and not use_jax_gradients + compute_hess = use_hess and not use_jax_gradients + compute_hessp = use_hessp and not use_jax_gradients + + funcs = _compile_functions_for_scipy_optimize( + loss=loss, + inputs=[flat_input], + compute_grad=compute_grad, + compute_hess=compute_hess, + compute_hessp=compute_hessp, + compile_kwargs=compile_kwargs, + ) + + # Depending on the requested functions, f_fused will either be the loss function, the loss function with gradients, + # or the loss function with gradients and hessian. + f_fused = funcs.pop(0) + f_hessp = funcs.pop(0) if compute_hessp else None + + if use_jax_gradients: + f_fused, f_hessp = _compile_grad_and_hess_to_jax(f_fused, use_hess, use_hessp) + + return f_fused, f_hessp diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index cddc175ba..774541bc4 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -63,7 +63,7 @@ # TODO: change to typing.Self after Python versions greater than 3.10 from typing_extensions import Self -from pymc_extras.inference.laplace import add_data_to_inferencedata +from pymc_extras.inference.laplace_approx.idata import add_data_to_inference_data from pymc_extras.inference.pathfinder.importance_sampling import ( importance_sampling as _importance_sampling, ) @@ -1759,6 +1759,6 @@ def fit_pathfinder( importance_sampling=importance_sampling, ) - idata = add_data_to_inferencedata(idata, progressbar, model, compile_kwargs) + idata = add_data_to_inference_data(idata, progressbar, model, compile_kwargs) return idata diff --git a/pyproject.toml b/pyproject.toml index 0be357d96..c90ff1c4d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ dynamic = ["version"] # specify the version in the __init__.py file dependencies = [ "pymc>=5.21.1", "scikit-learn", - "better-optimize>=0.1.2", + "better-optimize>=0.1.4", "pydantic>=2.0.0", ] diff --git a/tests/inference/__init__.py b/tests/inference/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/laplace_approx/__init__.py b/tests/inference/laplace_approx/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/inference/laplace_approx/test_find_map.py b/tests/inference/laplace_approx/test_find_map.py new file mode 100644 index 000000000..bf0cb292e --- /dev/null +++ b/tests/inference/laplace_approx/test_find_map.py @@ -0,0 +1,322 @@ +import numpy as np +import pymc as pm +import pytensor.tensor as pt +import pytest + +from pymc_extras.inference.laplace_approx.find_map import ( + find_MAP, + get_nearest_psd, + set_optimizer_function_defaults, +) +from pymc_extras.inference.laplace_approx.scipy_interface import ( + GradientBackend, + scipy_optimize_funcs_from_loss, +) + + +@pytest.fixture(scope="session") +def rng(): + seed = sum(map(ord, "test_fit_map")) + return np.random.default_rng(seed) + + +def test_get_nearest_psd_returns_psd(rng): + # Matrix with negative eigenvalues + A = np.array([[2, -3], [-3, 2]]) + psd = get_nearest_psd(A) + + # Should be symmetric + np.testing.assert_allclose(psd, psd.T) + + # All eigenvalues should be >= 0 + eigvals = np.linalg.eigvalsh(psd) + assert np.all(eigvals >= -1e-12), "All eigenvalues should be non-negative" + + +def test_get_nearest_psd_given_psd_input(rng): + L = rng.normal(size=(2, 2)) + A = L @ L.T 
+ psd = get_nearest_psd(A) + + # Given PSD input, should return the same matrix + assert np.allclose(psd, A) + + +def test_set_optimizer_function_defaults_warns_and_prefers_hessp(caplog): + # "trust-ncg" uses_grad=True, uses_hess=True, uses_hessp=True + method = "trust-ncg" + with caplog.at_level("WARNING"): + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, True, True, True) + + message = caplog.messages[0] + assert message.startswith('Both "use_hess" and "use_hessp" are set to True') + + assert use_grad + assert not use_hess + assert use_hessp + + +def test_set_optimizer_function_defaults_infers_hess_and_hessp(): + # "trust-ncg" uses_grad=True, uses_hess=True, uses_hessp=True + method = "trust-ncg" + + # If only use_hessp is set, use_hess should be False but use_grad should be inferred as True + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, None, None, True) + assert use_grad + assert not use_hess + assert use_hessp + + # Only use_hess is set + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, None, True, None) + assert use_hess + assert not use_hessp + + +def test_set_optimizer_function_defaults_defaults(): + # "trust-ncg" uses_grad=True, uses_hess=True, uses_hessp=True + method = "trust-ncg" + use_grad, use_hess, use_hessp = set_optimizer_function_defaults(method, None, None, None) + assert use_grad + assert not use_hess + assert use_hessp + + +@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) +def test_jax_functions_from_graph(gradient_backend: GradientBackend): + pytest.importorskip("jax") + + x = pt.tensor("x", shape=(2,)) + + def compute_z(x): + z1 = x[0] ** 2 + 2 + z2 = x[0] * x[1] + 3 + return z1, z2 + + z = pt.stack(compute_z(x)) + f_fused, f_hessp = scipy_optimize_funcs_from_loss( + loss=z.sum(), + inputs=[x], + initial_point_dict={"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=True, + use_hessp=True, + gradient_backend=gradient_backend, + compile_kwargs=dict(mode="JAX"), + ) + + x_val = np.array([1.0, 2.0]) + expected_z = sum(compute_z(x_val)) + + z_jax, grad_val, hess_val = f_fused(x_val) + np.testing.assert_allclose(z_jax, expected_z) + np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]])) + + hess_val = np.array(hess_val) + np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]])) + + hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0]))) + np.testing.assert_allclose(hessp_val.squeeze(), np.array([2, 1])) + + +@pytest.mark.parametrize( + "method, use_grad, use_hess, use_hessp", + [ + ( + "Newton-CG", + True, + True, + False, + ), + ("Newton-CG", True, False, True), + ("BFGS", True, False, False), + ("L-BFGS-B", True, False, False), + ], +) +@pytest.mark.parametrize( + "backend, gradient_backend", + [("jax", "jax"), ("jax", "pytensor")], + ids=str, +) +def test_find_MAP( + method, use_grad, use_hess, use_hessp, backend, gradient_backend: GradientBackend, rng +): + pytest.importorskip("jax") + + with pm.Model() as m: + mu = pm.Normal("mu") + sigma = pm.Exponential("sigma", 1) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=10)) + + idata = find_MAP( + method=method, + use_grad=use_grad, + use_hess=use_hess, + use_hessp=use_hessp, + progressbar=False, + gradient_backend=gradient_backend, + compile_kwargs={"mode": backend.upper()}, + maxiter=5, + ) + + assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") + assert hasattr(idata, "fit") + assert 
hasattr(idata, "optimizer_result") + assert hasattr(idata, "observed_data") + + posterior = idata.posterior.squeeze(["chain", "draw"]) + assert "mu" in posterior and "sigma" in posterior + assert posterior["mu"].shape == () + assert posterior["sigma"].shape == () + + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () + + +@pytest.mark.parametrize( + "backend, gradient_backend", + [("jax", "jax")], + ids=str, +) +def test_map_shared_variables(backend, gradient_backend: GradientBackend): + pytest.importorskip("jax") + + with pm.Model() as m: + data = pm.Data("data", np.random.normal(loc=3, scale=1.5, size=10)) + mu = pm.Normal("mu") + sigma = pm.Exponential("sigma", 1) + y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data) + + idata = find_MAP( + method="L-BFGS-B", + use_grad=True, + use_hess=False, + use_hessp=False, + progressbar=False, + gradient_backend=gradient_backend, + compile_kwargs={"mode": backend.upper()}, + ) + + assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") + assert hasattr(idata, "fit") + assert hasattr(idata, "optimizer_result") + assert hasattr(idata, "observed_data") + assert hasattr(idata, "constant_data") + + posterior = idata.posterior.squeeze(["chain", "draw"]) + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + + assert "mu" in posterior and "sigma" in posterior + assert posterior["mu"].shape == () + assert posterior["sigma"].shape == () + + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () + + +@pytest.mark.parametrize( + "method, use_grad, use_hess, use_hessp", + [ + ("Newton-CG", True, True, False), + ("Newton-CG", True, False, True), + ], +) +@pytest.mark.parametrize( + "backend, gradient_backend", + [("jax", "pytensor")], + ids=str, +) +def test_find_MAP_basinhopping( + method, use_grad, use_hess, use_hessp, backend, gradient_backend, rng +): + pytest.importorskip("jax") + + with pm.Model() as m: + mu = pm.Normal("mu") + sigma = pm.Exponential("sigma", 1) + pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=10)) + + idata = find_MAP( + method="basinhopping", + use_grad=use_grad, + use_hess=use_hess, + use_hessp=use_hessp, + progressbar=False, + gradient_backend=gradient_backend, + compile_kwargs={"mode": backend.upper()}, + minimizer_kwargs=dict(method=method), + niter=1, + ) + + assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") + + posterior = idata.posterior.squeeze(["chain", "draw"]) + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + assert "mu" in posterior + assert posterior["mu"].shape == () + + assert "sigma_log__" in unconstrained_posterior + assert unconstrained_posterior["sigma_log__"].shape == () + + +def test_find_MAP_with_coords(): + with pm.Model(coords={"group": [1, 2, 3, 4, 5]}) as m: + mu_loc = pm.Normal("mu_loc", 0, 1) + mu_scale = pm.HalfNormal("mu_scale", 1) + + mu = pm.Normal("mu", mu_loc, mu_scale, dims=["group"]) + sigma = pm.HalfNormal("sigma", 1, dims=["group"]) + + obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=np.random.normal(size=(10, 5))) + + idata = find_MAP(progressbar=False, method="L-BFGS-B") + + assert hasattr(idata, "posterior") + assert hasattr(idata, "unconstrained_posterior") + assert hasattr(idata, "fit") + + posterior = 
idata.posterior.squeeze(["chain", "draw"]) + unconstrained_posterior = idata.unconstrained_posterior.squeeze(["chain", "draw"]) + + assert ( + "mu_loc" in posterior + and "mu_scale" in posterior + and "mu" in posterior + and "sigma" in posterior + ) + assert "mu_scale_log__" in unconstrained_posterior and "sigma_log__" in unconstrained_posterior + + assert posterior["mu_loc"].shape == () + assert posterior["mu_scale"].shape == () + assert posterior["mu"].shape == (5,) + assert posterior["sigma"].shape == (5,) + + assert unconstrained_posterior["mu_scale_log__"].shape == () + assert unconstrained_posterior["sigma_log__"].shape == (5,) + + +def test_map_nonscalar_rv_without_dims(): + with pm.Model(coords={"test": ["A", "B", "C"]}) as model: + x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"]) + x = pm.Normal("x", mu=x_loc, sigma=1, shape=(2, 3)) + y = pm.Normal("y", mu=x, sigma=1, observed=np.random.randn(10, 2, 3)) + + idata = find_MAP(method="L-BFGS-B", progressbar=False) + + assert idata.posterior["x"].shape == (1, 1, 2, 3) + assert all(f"x_dim_{i}" in idata.posterior.coords for i in range(2)) + + assert idata.fit.rows.values.tolist() == [ + "x_loc[A]", + "x_loc[B]", + "x_loc[C]", + "x[0,0]", + "x[0,1]", + "x[0,2]", + "x[1,0]", + "x[1,1]", + "x[1,2]", + ] diff --git a/tests/inference/laplace_approx/test_idata.py b/tests/inference/laplace_approx/test_idata.py new file mode 100644 index 000000000..8a2cd4444 --- /dev/null +++ b/tests/inference/laplace_approx/test_idata.py @@ -0,0 +1,297 @@ +from contextlib import contextmanager + +import arviz as az +import numpy as np +import pymc as pm +import pytest +import xarray as xr + +from pymc.blocking import RaveledVars +from scipy.optimize import OptimizeResult +from scipy.sparse.linalg import LinearOperator + +from pymc_extras.inference.laplace_approx.idata import ( + add_data_to_inference_data, + add_fit_to_inference_data, + optimizer_result_to_dataset, +) + + +@contextmanager +def no_op(): + yield + + +@pytest.fixture +def rng(): + return np.random.default_rng() + + +@pytest.fixture +def simple_model(rng): + with pm.Model() as model: + x = pm.Data("data", rng.normal(size=(10,))) + mu = pm.Normal("mu", 0, 1) + sigma = pm.HalfNormal("sigma", 1) + obs = pm.Normal("obs", mu + x, sigma, observed=rng.normal(size=(10,))) + + mu_val = np.array([0.5, 1.0]) + H_inv = np.eye(2) + + point_map_info = (("mu", (), 1, "float64"), ("sigma_log__", (), 1, "float64")) + test_point = RaveledVars(mu_val, point_map_info) + + return model, mu_val, H_inv, test_point + + +@pytest.fixture +def hierarchical_model(rng): + with pm.Model(coords={"group": [1, 2, 3, 4, 5]}) as model: + mu_loc = pm.Normal("mu_loc", 0, 1) + mu_scale = pm.HalfNormal("mu_scale", 1) + mu = pm.Normal("mu", mu_loc, mu_scale, dims="group") + sigma = pm.HalfNormal("sigma", 1) + obs = pm.Normal("obs", mu=mu, sigma=sigma, observed=rng.normal(size=(5, 10))) + + mu_val = rng.normal(size=(8,)) + H_inv = np.eye(8) + + point_map_info = ( + ("mu_loc", (), 1, "float64"), + ("mu_scale_log__", (), 1, "float64"), + ("mu", (5,), 5, "float64"), + ("sigma_log__", (), 1, "float64"), + ) + + test_point = RaveledVars(mu_val, point_map_info) + + return model, mu_val, H_inv, test_point + + +class TestFittoInferenceData: + def check_idata(self, idata, var_names, n_vars): + assert "fit" in idata.groups() + + fit = idata.fit + assert "mean_vector" in fit + assert "covariance_matrix" in fit + assert fit["mean_vector"].shape[0] == n_vars + assert fit["covariance_matrix"].shape == (n_vars, n_vars) + + assert 
list(fit.coords.keys()) == ["rows", "columns"] + assert fit.coords["rows"].values.tolist() == var_names + assert fit.coords["columns"].values.tolist() == var_names + + @pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) + def test_add_fit_to_inferencedata(self, use_context, simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + idata = az.from_dict( + posterior={"mu": rng.normal(size=()), "sigma_log__": rng.normal(size=())} + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model_arg) + + self.check_idata(idata2, ["mu", "sigma_log__"], 2) + + @pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) + def test_add_fit_with_coords_to_inferencedata(self, use_context, hierarchical_model, rng): + model, mu_val, H_inv, test_point = hierarchical_model + idata = az.from_dict( + posterior={ + "mu_loc": rng.normal(size=()), + "mu_scale_log__": rng.normal(size=()), + "mu": rng.normal(size=(5,)), + "sigma_log__": rng.normal(size=()), + } + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + idata2 = add_fit_to_inference_data(idata, test_point, H_inv, model=model_arg) + + self.check_idata( + idata2, + [ + "mu_loc", + "mu_scale_log__", + "mu[1]", + "mu[2]", + "mu[3]", + "mu[4]", + "mu[5]", + "sigma_log__", + ], + 8, + ) + + +@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) +def test_add_data_to_inferencedata(use_context, simple_model, rng): + model, *_ = simple_model + + idata = az.from_dict( + posterior={"mu": rng.standard_normal((1, 1)), "sigma_log__": rng.standard_normal((1, 1))} + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + idata2 = add_data_to_inference_data(idata, model=model_arg) + + assert "observed_data" in idata2.groups() + assert "constant_data" in idata2.groups() + assert "obs" in idata2.observed_data + + +@pytest.mark.parametrize("use_context", [False, True], ids=["model_arg", "model_context"]) +def test_optimizer_result_to_dataset_basic(use_context, simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + result = OptimizeResult( + x=np.array([1.0, 2.0]), + fun=0.5, + success=True, + message="Optimization succeeded", + jac=np.array([0.1, 0.2]), + nit=5, + nfev=10, + njev=3, + status=0, + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + with context: + ds = optimizer_result_to_dataset(result, method="BFGS", model=model_arg, mu=test_point) + + assert isinstance(ds, xr.Dataset) + assert all( + key in ds + for key in [ + "x", + "fun", + "success", + "message", + "jac", + "nit", + "nfev", + "njev", + "status", + "method", + ] + ) + + assert list(ds["x"].coords.keys()) == ["variables"] + assert ds["x"].coords["variables"].values.tolist() == ["mu", "sigma_log__"] + + assert list(ds["jac"].coords.keys()) == ["variables"] + assert ds["jac"].coords["variables"].values.tolist() == ["mu", "sigma_log__"] + + +@pytest.mark.parametrize( + "optimizer_method, use_context, model_name", + [("BFGS", True, "hierarchical_model"), ("L-BFGS-B", False, "simple_model")], +) +def test_optimizer_result_to_dataset_hess_inv_types( + optimizer_method, use_context, model_name, rng, request +): + def get_hess_inv_and_expected_names(method): + model, 
mu_val, H_inv, test_point = request.getfixturevalue(model_name) + n = mu_val.shape[0] + + if method == "BFGS": + hess_inv = np.eye(n) + expected_names = [ + "mu_loc", + "mu_scale_log__", + "mu[1]", + "mu[2]", + "mu[3]", + "mu[4]", + "mu[5]", + "sigma_log__", + ] + result = OptimizeResult( + x=np.zeros((n,)), + hess_inv=hess_inv, + ) + elif method == "L-BFGS-B": + + def linop_func(x): + return np.array([2 * xi for xi in x]) + + linop = LinearOperator((n, n), matvec=linop_func) + hess_inv = 2 * np.eye(n) + expected_names = ["mu", "sigma_log__"] + result = OptimizeResult( + x=np.ones(n), + hess_inv=linop, + ) + else: + raise ValueError("Unknown optimizer_method") + + return model, test_point, hess_inv, expected_names, result + + model, test_point, hess_inv, expected_names, result = get_hess_inv_and_expected_names( + optimizer_method + ) + + context = model if use_context else no_op() + model_arg = model if not use_context else None + + with context: + ds = optimizer_result_to_dataset( + result, method=optimizer_method, mu=test_point, model=model_arg + ) + + assert "hess_inv" in ds + assert ds["hess_inv"].shape == (len(expected_names), len(expected_names)) + assert list(ds["hess_inv"].coords.keys()) == ["variables", "variables_aux"] + assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names + assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names + np.testing.assert_allclose(ds["hess_inv"].values, hess_inv) + + +def test_optimizer_result_to_dataset_extra_fields(simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + + result = OptimizeResult( + x=np.array([1.0, 2.0]), + custom_stat=np.array([42, 43]), + ) + + with model: + ds = optimizer_result_to_dataset(result, method="BFGS", mu=test_point) + + assert "custom_stat" in ds + assert ds["custom_stat"].shape == (2,) + assert list(ds["custom_stat"].coords.keys()) == ["custom_stat_dim_0"] + assert ds["custom_stat"].coords["custom_stat_dim_0"].values.tolist() == [0, 1] + + +def test_optimizer_result_to_dataset_hess_inv_basinhopping(simple_model, rng): + model, mu_val, H_inv, test_point = simple_model + n = mu_val.shape[0] + hess_inv_inner = np.eye(n) * 3.0 + + # Basinhopping returns an OptimizeResult with a nested OptimizeResult + result = OptimizeResult( + x=np.ones(n), + lowest_optimization_result=OptimizeResult(x=np.ones(n), hess_inv=hess_inv_inner), + ) + + with model: + ds = optimizer_result_to_dataset(result, method="basinhopping", mu=test_point) + + assert "hess_inv" in ds + assert ds["hess_inv"].shape == (n, n) + np.testing.assert_allclose(ds["hess_inv"].values, hess_inv_inner) + expected_names = ["mu", "sigma_log__"] + assert ds["hess_inv"].coords["variables"].values.tolist() == expected_names + assert ds["hess_inv"].coords["variables_aux"].values.tolist() == expected_names diff --git a/tests/test_laplace.py b/tests/inference/laplace_approx/test_laplace.py similarity index 58% rename from tests/test_laplace.py rename to tests/inference/laplace_approx/test_laplace.py index 72ff3e937..be5665d07 100644 --- a/tests/test_laplace.py +++ b/tests/inference/laplace_approx/test_laplace.py @@ -19,12 +19,10 @@ import pymc_extras as pmx -from pymc_extras.inference.find_map import GradientBackend, find_MAP -from pymc_extras.inference.laplace import ( +from pymc_extras.inference.laplace_approx.find_map import GradientBackend +from pymc_extras.inference.laplace_approx.laplace import ( fit_laplace, - fit_mvn_at_MAP, get_conditional_gaussian_approximation, - sample_laplace_posterior, ) @@ -42,7 
+40,7 @@ def rng(): "mode, gradient_backend", [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], ) -def test_laplace(mode, gradient_backend: GradientBackend): +def test_fit_laplace_basic(mode, gradient_backend: GradientBackend): # Example originates from Bayesian Data Analyses, 3rd Edition # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, # Aki Vehtari, and Donald Rubin. @@ -53,8 +51,8 @@ def test_laplace(mode, gradient_backend: GradientBackend): draws = 100000 with pm.Model() as m: - mu = pm.Uniform("mu", -10000, 10000) - logsigma = pm.Uniform("logsigma", 1, 100) + mu = pm.Flat("mu") + logsigma = pm.Flat("logsigma") yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) vars = [mu, logsigma] @@ -67,6 +65,7 @@ def test_laplace(mode, gradient_backend: GradientBackend): chains=1, compile_kwargs={"mode": mode}, gradient_backend=gradient_backend, + optimizer_kwargs=dict(tol=1e-20), ) assert idata.posterior["mu"].shape == (1, draws) @@ -78,59 +77,13 @@ def test_laplace(mode, gradient_backend: GradientBackend): bda_map = [y.mean(), np.log(y.std())] bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]]) - np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) - np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) + np.testing.assert_allclose(idata.posterior["mu"].mean(), bda_map[0], atol=1) + np.testing.assert_allclose(idata.posterior["logsigma"].mean(), bda_map[1], rtol=1e-3) + np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, rtol=1e-3, atol=1e-3) -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_laplace_only_fit(mode, gradient_backend: GradientBackend): - # Example originates from Bayesian Data Analyses, 3rd Edition - # By Andrew Gelman, John Carlin, Hal Stern, David Dunson, - # Aki Vehtari, and Donald Rubin. - # See section. 
4.1 - - y = np.array([2642, 3503, 4358], dtype=np.float64) - n = y.size - with pm.Model() as m: - logsigma = pm.Uniform("logsigma", 1, 100) - mu = pm.Uniform("mu", -10000, 10000) - yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) - vars = [mu, logsigma] - - idata = pmx.fit( - method="laplace", - optimize_method="BFGS", - progressbar=True, - gradient_backend=gradient_backend, - compile_kwargs={"mode": mode}, - optimizer_kwargs=dict(maxiter=100_000, gtol=1e-100), - random_seed=173300, - ) - - assert idata.fit["mean_vector"].shape == (len(vars),) - assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars)) - - bda_map = [np.log(y.std()), y.mean()] - bda_cov = np.array([[1 / (2 * n), 0], [0, y.var() / n]]) - - np.testing.assert_allclose(idata.fit["mean_vector"].values, bda_map) - np.testing.assert_allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4) - - -@pytest.mark.parametrize( - "transform_samples", - [True, False], - ids=["transformed", "untransformed"], -) -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: GradientBackend): +def test_fit_laplace_coords(rng): coords = {"city": ["A", "B", "C"], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as model: mu = pm.Normal("mu", mu=3, sigma=0.5, dims=["city"]) @@ -143,49 +96,32 @@ def test_fit_laplace_coords(rng, transform_samples, mode, gradient_backend: Grad dims=["obs_idx", "city"], ) - optimized_point = find_MAP( - method="trust-ncg", - use_grad=True, - use_hessp=True, - progressbar=False, - compile_kwargs=dict(mode=mode), - gradient_backend=gradient_backend, - ) - - for value in optimized_point.values(): - assert value.shape == (3,) - - mu, H_inv = fit_mvn_at_MAP( - optimized_point=optimized_point, - model=model, - transform_samples=transform_samples, - ) - - idata = sample_laplace_posterior( - mu=mu, H_inv=H_inv, model=model, transform_samples=transform_samples + idata = pmx.fit( + method="laplace", + optimize_method="trust-ncg", + chains=1, + draws=1000, + optimizer_kwargs=dict(tol=1e-20), ) - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2, 3), 3), atol=0.5) np.testing.assert_allclose( - np.mean(idata.posterior.sigma, axis=1), np.full((2, 3), 1.5), atol=0.3 + idata.posterior.mu.mean(dim=["chain", "draw"]).values, np.full((3,), 3), atol=0.5 + ) + np.testing.assert_allclose( + idata.posterior.sigma.mean(dim=["chain", "draw"]).values, np.full((3,), 1.5), atol=0.3 ) - suffix = "_log__" if transform_samples else "" assert idata.fit.rows.values.tolist() == [ "mu[A]", "mu[B]", "mu[C]", - f"sigma{suffix}[A]", - f"sigma{suffix}[B]", - f"sigma{suffix}[C]", + "sigma_log__[A]", + "sigma_log__[B]", + "sigma_log__[C]", ] -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng): +def test_fit_laplace_ragged_coords(rng): coords = {"city": ["A", "B", "C"], "feature": [0, 1], "obs_idx": np.arange(100)} with pm.Model(coords=coords) as ragged_dim_model: X = pm.Data("X", np.ones((100, 2)), dims=["obs_idx", "feature"]) @@ -210,10 +146,12 @@ def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng) progressbar=False, use_grad=True, use_hessp=True, - gradient_backend=gradient_backend, - compile_kwargs={"mode": mode}, ) + # These 
should have been dropped when the laplace idata was created + assert "laplace_approximation" not in list(idata.posterior.data_vars.keys()) + assert "unpacked_var_names" not in list(idata.posterior.coords.keys()) + assert idata["posterior"].beta.shape[-2:] == (3, 2) assert idata["posterior"].sigma.shape[-1:] == (3,) @@ -223,50 +161,71 @@ def test_fit_laplace_ragged_coords(mode, gradient_backend: GradientBackend, rng) assert (idata["posterior"].beta.sel(feature=1).to_numpy() > 0).all() -@pytest.mark.parametrize( - "fit_in_unconstrained_space", - [True, False], - ids=["transformed", "untransformed"], -) -@pytest.mark.parametrize( - "mode, gradient_backend", - [(None, "pytensor"), ("NUMBA", "pytensor"), ("JAX", "jax"), ("JAX", "pytensor")], -) -def test_fit_laplace(fit_in_unconstrained_space, mode, gradient_backend: GradientBackend): - with pm.Model() as simp_model: - mu = pm.Normal("mu", mu=3, sigma=0.5) - sigma = pm.Exponential("sigma", 1) - obs = pm.Normal( - "obs", - mu=mu, - sigma=sigma, - observed=np.random.default_rng().normal(loc=3, scale=1.5, size=(10000,)), - ) +def test_model_with_nonstandard_dimensionality(rng): + y_obs = np.concatenate( + [rng.normal(-1, 2, size=150), rng.normal(3, 1, size=350), rng.normal(5, 4, size=50)] + ) - idata = fit_laplace( - optimize_method="trust-ncg", - use_grad=True, - use_hessp=True, - fit_in_unconstrained_space=fit_in_unconstrained_space, - optimizer_kwargs=dict(maxiter=100_000, tol=1e-100), - compile_kwargs={"mode": mode}, - gradient_backend=gradient_backend, - ) + with pm.Model(coords={"obs_idx": range(y_obs.size), "class": ["A", "B", "C"]}) as model: + y = pm.Data("y", y_obs, dims=["obs_idx"]) + + mu = pm.Normal("mu", mu=1, sigma=3, dims=["class"]) + sigma = pm.HalfNormal("sigma", sigma=3, dims=["class"]) - np.testing.assert_allclose(np.mean(idata.posterior.mu, axis=1), np.full((2,), 3), atol=0.1) - np.testing.assert_allclose( - np.mean(idata.posterior.sigma, axis=1), np.full((2,), 1.5), atol=0.1 + w = pm.Dirichlet( + "w", + a=np.ones( + 3, + ), + dims=["class"], + ) + class_idx = pm.Categorical("class_idx", p=w, dims=["obs_idx"]) + y_hat = pm.Normal( + "obs", mu=mu[class_idx], sigma=sigma[class_idx], observed=y, dims=["obs_idx"] ) - if fit_in_unconstrained_space: - assert idata.fit.rows.values.tolist() == ["mu", "sigma_log__"] - np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 0.4]), atol=0.1) - else: - assert idata.fit.rows.values.tolist() == ["mu", "sigma"] - np.testing.assert_allclose(idata.fit.mean_vector.values, np.array([3.0, 1.5]), atol=0.1) + with pmx.marginalize(model, [class_idx]): + idata = pmx.fit_laplace(progressbar=False) + + # The dirichlet value variable has a funky shape; check that it got a default + assert "w_simplex___dim_0" in list(idata.unconstrained_posterior.w_simplex__.coords.keys()) + assert "class" not in list(idata.unconstrained_posterior.w_simplex__.coords.keys()) + assert len(idata.unconstrained_posterior.coords["w_simplex___dim_0"]) == 2 + + # On the other hand, check that the actual w has the correct dims + assert "class" in list(idata.posterior.w.coords.keys()) + # The log transform is 1-to-1, so it should have the same dims as the original rv + assert "class" in list(idata.unconstrained_posterior.sigma_log__.coords.keys()) -def test_laplace_scalar(): + +def test_laplace_nonscalar_rv_without_dims(): + with pm.Model(coords={"test": ["A", "B", "C"]}) as model: + x_loc = pm.Normal("x_loc", mu=0, sigma=1, dims=["test"]) + x = pm.Normal("x", mu=x_loc, sigma=1, shape=(2, 3)) + y = pm.Normal("y", 
mu=x, sigma=1, observed=np.random.randn(10, 2, 3)) + + idata = pmx.fit_laplace(progressbar=False) + + assert idata.posterior["x"].shape == (2, 500, 2, 3) + assert all(f"x_dim_{i}" in idata.posterior.coords for i in range(2)) + assert idata.fit.rows.values.tolist() == [ + "x_loc[A]", + "x_loc[B]", + "x_loc[C]", + "x[0,0]", + "x[0,1]", + "x[0,2]", + "x[1,0]", + "x[1,1]", + "x[1,2]", + ] + + +# Test these three optimizers because they are either special cases for H_inv (BFGS, L-BFGS-B) or are +# gradient free and require re-compilation of hessp (powell). +@pytest.mark.parametrize("optimizer_method", ["BFGS", "L-BFGS-B", "powell"]) +def test_laplace_scalar_basinhopping(optimizer_method): # Example model from Statistical Rethinking data = np.array([0, 0, 0, 1, 1, 1, 1, 1, 1]) @@ -274,12 +233,18 @@ def test_laplace_scalar(): p = pm.Uniform("p", 0, 1) w = pm.Binomial("w", n=len(data), p=p, observed=data.sum()) - idata_laplace = pmx.fit_laplace(progressbar=False) + idata_laplace = pmx.fit_laplace( + optimize_method="basinhopping", + optimizer_kwargs={"minimizer_kwargs": {"method": optimizer_method}, "niter": 1}, + progressbar=False, + ) assert idata_laplace.fit.mean_vector.shape == (1,) assert idata_laplace.fit.covariance_matrix.shape == (1, 1) - np.testing.assert_allclose(idata_laplace.fit.mean_vector.values.item(), data.mean(), atol=0.1) + np.testing.assert_allclose( + idata_laplace.posterior.p.mean(dim=["chain", "draw"]), data.mean(), atol=0.1 + ) def test_get_conditional_gaussian_approximation(): diff --git a/tests/inference/laplace_approx/test_scipy_interface.py b/tests/inference/laplace_approx/test_scipy_interface.py new file mode 100644 index 000000000..4e8b56a08 --- /dev/null +++ b/tests/inference/laplace_approx/test_scipy_interface.py @@ -0,0 +1,118 @@ +import numpy as np +import pytest + +from pytensor import tensor as pt + +from pymc_extras.inference.laplace_approx import scipy_interface + + +@pytest.fixture +def simple_loss_and_inputs(): + x = pt.vector("x") + loss = pt.sum(x**2) + return loss, [x] + + +def test_compile_functions_for_scipy_optimize_loss_only(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=False, compute_hess=False, compute_hessp=False + ) + assert len(funcs) == 1 + f_loss = funcs[0] + x_val = np.array([1.0, 2.0, 3.0]) + result = f_loss(x_val) + assert np.isclose(result, np.sum(x_val**2)) + + +def test_compile_functions_for_scipy_optimize_with_grad(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=True, compute_hess=False, compute_hessp=False + ) + f_fused = funcs[0] + x_val = np.array([1.0, 2.0, 3.0]) + loss_val, grad_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + + +def test_compile_functions_for_scipy_optimize_with_hess(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=True, compute_hess=True, compute_hessp=False + ) + f_fused = funcs[0] + x_val = np.array([1.0, 2.0]) + loss_val, grad_val, hess_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + assert np.allclose(hess_val, 2 * np.eye(len(x_val))) + + +def test_compile_functions_for_scipy_optimize_with_hessp(simple_loss_and_inputs): + loss, inputs = 
simple_loss_and_inputs + funcs = scipy_interface._compile_functions_for_scipy_optimize( + loss, inputs, compute_grad=True, compute_hess=False, compute_hessp=True + ) + f_fused, f_hessp = funcs + x_val = np.array([1.0, 2.0]) + p_val = np.array([1.0, 0.0]) + + loss_val, grad_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + + hessp_val = f_hessp(x_val, p_val) + assert np.allclose(hessp_val, 2 * p_val) + + +def test_scipy_optimize_funcs_from_loss_invalid_backend(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + with pytest.raises(ValueError, match="Invalid gradient backend"): + scipy_interface.scipy_optimize_funcs_from_loss( + loss, + inputs, + {"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=False, + use_hessp=False, + gradient_backend="not_a_backend", + ) + + +def test_scipy_optimize_funcs_from_loss_hess_without_grad(simple_loss_and_inputs): + loss, inputs = simple_loss_and_inputs + with pytest.raises( + ValueError, match="Cannot compute hessian without also computing the gradient" + ): + scipy_interface.scipy_optimize_funcs_from_loss( + loss, + inputs, + {"x": np.array([1.0, 2.0])}, + use_grad=False, + use_hess=True, + use_hessp=False, + ) + + +@pytest.mark.parametrize("backend", ["pytensor", "jax"], ids=str) +def test_scipy_optimize_funcs_from_loss_backend(backend, simple_loss_and_inputs): + if backend == "jax": + pytest.importorskip("jax", reason="JAX is not installed") + + loss, inputs = simple_loss_and_inputs + f_fused, f_hessp = scipy_interface.scipy_optimize_funcs_from_loss( + loss, + inputs, + {"x": np.array([1.0, 2.0])}, + use_grad=True, + use_hess=False, + use_hessp=False, + gradient_backend=backend, + ) + x_val = np.array([1.0, 2.0]) + loss_val, grad_val = f_fused(x_val) + assert np.isclose(loss_val, np.sum(x_val**2)) + assert np.allclose(grad_val, 2 * x_val) + assert f_hessp is None diff --git a/tests/test_find_map.py b/tests/test_find_map.py deleted file mode 100644 index f5aa549c7..000000000 --- a/tests/test_find_map.py +++ /dev/null @@ -1,158 +0,0 @@ -import numpy as np -import pymc as pm -import pytensor -import pytensor.tensor as pt -import pytest - -from pymc_extras.inference.find_map import ( - GradientBackend, - find_MAP, - scipy_optimize_funcs_from_loss, -) - -pytest.importorskip("jax") - - -@pytest.fixture(scope="session") -def rng(): - seed = sum(map(ord, "test_fit_map")) - return np.random.default_rng(seed) - - -@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) -def test_jax_functions_from_graph(gradient_backend: GradientBackend): - x = pt.tensor("x", shape=(2,)) - - def compute_z(x): - z1 = x[0] ** 2 + 2 - z2 = x[0] * x[1] + 3 - return z1, z2 - - z = pt.stack(compute_z(x)) - f_loss, f_hess, f_hessp = scipy_optimize_funcs_from_loss( - loss=z.sum(), - inputs=[x], - initial_point_dict={"x": np.array([1.0, 2.0])}, - use_grad=True, - use_hess=True, - use_hessp=True, - gradient_backend=gradient_backend, - compile_kwargs=dict(mode="JAX"), - ) - - x_val = np.array([1.0, 2.0]) - expected_z = sum(compute_z(x_val)) - - z_jax, grad_val = f_loss(x_val) - np.testing.assert_allclose(z_jax, expected_z) - np.testing.assert_allclose(grad_val.squeeze(), np.array([2 * x_val[0] + x_val[1], x_val[0]])) - - hess_val = np.array(f_hess(x_val)) - np.testing.assert_allclose(hess_val.squeeze(), np.array([[2, 1], [1, 0]])) - - hessp_val = np.array(f_hessp(x_val, np.array([1.0, 0.0]))) - np.testing.assert_allclose(hessp_val.squeeze(), np.array([2, 1])) - - -@pytest.mark.parametrize( - 
"method, use_grad, use_hess, use_hessp", - [ - ("nelder-mead", False, False, False), - ("powell", False, False, False), - ("CG", True, False, False), - ("BFGS", True, False, False), - ("L-BFGS-B", True, False, False), - ("TNC", True, False, False), - ("SLSQP", True, False, False), - ("dogleg", True, True, False), - ("Newton-CG", True, True, False), - ("Newton-CG", True, False, True), - ("trust-ncg", True, True, False), - ("trust-ncg", True, False, True), - ("trust-exact", True, True, False), - ("trust-krylov", True, True, False), - ("trust-krylov", True, False, True), - ("trust-constr", True, True, False), - ], -) -@pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str) -def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng): - extra_kwargs = {} - if method == "dogleg": - # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point - # where this is true - extra_kwargs = {"initvals": {"mu": 2, "sigma_log__": 1}} - - with pm.Model() as m: - mu = pm.Normal("mu") - sigma = pm.Exponential("sigma", 1) - pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) - - optimized_point = find_MAP( - method=method, - **extra_kwargs, - use_grad=use_grad, - use_hess=use_hess, - use_hessp=use_hessp, - progressbar=False, - gradient_backend=gradient_backend, - compile_kwargs={"mode": "JAX"}, - ) - mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] - - assert np.isclose(mu_hat, 3, atol=0.5) - assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) - - -def test_JAX_map_shared_variables(): - with pm.Model() as m: - data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data") - mu = pm.Normal("mu") - sigma = pm.Exponential("sigma", 1) - y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data) - - optimized_point = find_MAP( - method="L-BFGS-B", - use_grad=True, - use_hess=False, - use_hessp=False, - progressbar=False, - gradient_backend="jax", - compile_kwargs={"mode": "JAX"}, - ) - mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] - - assert np.isclose(mu_hat, 3, atol=0.5) - assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5) - - -@pytest.mark.parametrize( - "method, use_grad, use_hess, use_hessp", - [ - ("nelder-mead", False, False, False), - ("L-BFGS-B", True, False, False), - ("trust-exact", True, True, False), - ("trust-ncg", True, False, True), - ], -) -def test_find_MAP_basinhopping(method, use_grad, use_hess, use_hessp, rng): - with pm.Model() as m: - mu = pm.Normal("mu") - sigma = pm.Exponential("sigma", 1) - pm.Normal("y_hat", mu=mu, sigma=sigma, observed=rng.normal(loc=3, scale=1.5, size=100)) - - optimized_point = find_MAP( - method="basinhopping", - use_grad=use_grad, - use_hess=use_hess, - use_hessp=use_hessp, - progressbar=False, - gradient_backend="pytensor", - compile_kwargs={"mode": "JAX"}, - minimizer_kwargs=dict(method=method), - ) - - mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"] - - assert np.isclose(mu_hat, 3, atol=0.5) - assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)