Skip to content

Commit b4faf29

Browse files
authored
Merge pull request #539 from alan-turing-institute/refactor_sa
Refactor sensitivity analysis
2 parents 7ae1d2a + 77f83b0 commit b4faf29

File tree

4 files changed

+435
-20
lines changed

4 files changed

+435
-20
lines changed

autoemulate/compare.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -986,8 +986,6 @@ def sensitivity_analysis(
986986
problem=None,
987987
N=1024,
988988
conf_level=0.95,
989-
as_df=True,
990-
**plot_kwargs,
991989
):
992990
"""Perform Sobol sensitivity analysis on a fitted emulator.
993991
@@ -1024,23 +1022,15 @@ def sensitivity_analysis(
10241022
conf_level : float, optional
10251023
Confidence level (between 0 and 1) for calculating confidence intervals of the
10261024
sensitivity indices. Default is 0.95 (95% confidence).
1027-
as_df : bool, optional
1028-
If True, returns results as a long-format pandas DataFrame with columns for
1029-
parameters, sensitivity indices, and confidence intervals. If False, returns
1030-
the raw SALib results dictionary. Default is True.
10311025
10321026
Returns
10331027
-------
1034-
pandas.DataFrame or dict
1035-
If as_df=True (default), returns a DataFrame with columns:
1036-
1028+
pandas.DataFrame
10371029
- 'parameter': Input parameter name
10381030
- 'output': Output variable name
10391031
- 'S1', 'S2', 'ST': First, second, and total order sensitivity indices
10401032
- 'S1_conf', 'S2_conf', 'ST_conf': Confidence intervals for each index
10411033
1042-
If as_df=False, returns the raw SALib results dictionary.
1043-
10441034
Notes
10451035
-----
10461036
The analysis requires N * (2D + 2) model evaluations, where D is the number of input
Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,258 @@
1+
import pandas as pd
2+
from SALib.analyze.morris import analyze as morris_analyze
3+
from SALib.analyze.sobol import analyze as sobol_analyze
4+
from SALib.sample.morris import sample as morris_sample
5+
from SALib.sample.sobol import sample as sobol_sample
6+
7+
from autoemulate.experimental.data.utils import ConversionMixin
8+
from autoemulate.experimental.emulators.base import Emulator
9+
from autoemulate.experimental.types import DistributionLike, NumpyLike, TensorLike
10+
11+
# NOTE: we still use these functions from main
12+
# should we just move them to experimental as well?
13+
from autoemulate.sensitivity_analysis import (
14+
_morris_results_to_df,
15+
_plot_morris_analysis,
16+
_plot_sobol_analysis,
17+
_sobol_results_to_df,
18+
)
19+
20+
21+
class SensitivityAnalysis(ConversionMixin):
22+
"""
23+
Global sensitivity analysis.
24+
"""
25+
26+
def __init__(
27+
self,
28+
emulator: Emulator,
29+
x: TensorLike | None = None,
30+
problem: dict | None = None,
31+
):
32+
"""
33+
Parameters
34+
----------
35+
emulator : Emulator
36+
Fitted emulator.
37+
x : InputLike | None
38+
Simulator input parameter values.
39+
problem : dict | None
40+
The problem definition dictionary. If None, the problem is generated
41+
from x using minimum and maximum values of the features as bounds.
42+
The dictionary should contain:
43+
- 'num_vars': Number of input variables (int)
44+
- 'names': List of variable names (list of str)
45+
- 'bounds': List of [min, max] bounds for each variable (list of lists)
46+
- 'output_names': Optional list of output names (list of str)
47+
48+
Example::
49+
problem = {
50+
"num_vars": 2,
51+
"names": ["x1", "x2"],
52+
"bounds": [[0, 1], [0, 1]],
53+
"output_names": ["y1", "y2"], # optional
54+
}
55+
"""
56+
if problem is not None:
57+
problem = self._check_problem(problem)
58+
elif x is not None:
59+
problem = self._generate_problem(x)
60+
else:
61+
msg = "Either problem or x must be provided."
62+
raise ValueError(msg)
63+
64+
self.emulator = emulator
65+
self.problem = problem
66+
67+
@staticmethod
68+
def _check_problem(problem: dict) -> dict:
69+
"""
70+
Check that the problem definition is valid.
71+
"""
72+
if not isinstance(problem, dict):
73+
msg = "problem must be a dictionary."
74+
raise ValueError(msg)
75+
76+
if "num_vars" not in problem:
77+
msg = "problem must contain 'num_vars'."
78+
raise ValueError(msg)
79+
if "names" not in problem:
80+
msg = "problem must contain 'names'."
81+
raise ValueError(msg)
82+
if "bounds" not in problem:
83+
msg = "problem must contain 'bounds'."
84+
raise ValueError(msg)
85+
86+
if len(problem["names"]) != problem["num_vars"]:
87+
msg = "Length of 'names' must match 'num_vars'."
88+
raise ValueError(msg)
89+
if len(problem["bounds"]) != problem["num_vars"]:
90+
msg = "Length of 'bounds' must match 'num_vars'."
91+
raise ValueError(msg)
92+
93+
return problem
94+
95+
@staticmethod
96+
def _generate_problem(x: TensorLike) -> dict:
97+
"""
98+
Generate a problem definition from a design matrix.
99+
100+
Parameters
101+
----------
102+
x : TensorLike
103+
Simulator input parameter values [n_samples, n_parameters].
104+
"""
105+
if x.ndim == 1:
106+
msg = "x must be a 2D array."
107+
raise ValueError(msg)
108+
109+
return {
110+
"num_vars": x.shape[1],
111+
"names": [f"X{i + 1}" for i in range(x.shape[1])],
112+
"bounds": [
113+
[x[:, i].min().item(), x[:, i].max().item()] for i in range(x.shape[1])
114+
],
115+
}
116+
117+
def _sample(self, method: str, N: int) -> NumpyLike:
118+
if method == "sobol":
119+
# Saltelli sampling
120+
return sobol_sample(self.problem, N)
121+
if method == "morris":
122+
# vanilla Morris (1991) sampling
123+
return morris_sample(self.problem, N)
124+
msg = f"Unknown method: {method}. Must be 'sobol' or 'morris'."
125+
raise ValueError(msg)
126+
127+
def _predict(self, param_samples: NumpyLike) -> NumpyLike:
128+
"""
129+
Make predictions with emulator for N input samples.
130+
"""
131+
132+
param_tensor = self._convert_to_tensors(param_samples)
133+
assert isinstance(param_tensor, TensorLike)
134+
y_pred = self.emulator.predict(param_tensor)
135+
136+
# handle types, convert to numpy
137+
if isinstance(y_pred, TensorLike):
138+
y_pred_np, _ = self._convert_to_numpy(y_pred)
139+
elif isinstance(y_pred, DistributionLike):
140+
y_pred_np, _ = self._convert_to_numpy(y_pred.mean)
141+
else:
142+
msg = "Emulator has to return Tensor or Distribution"
143+
raise ValueError(msg)
144+
145+
return y_pred_np
146+
147+
def _get_output_names(self, num_outputs: int) -> list[str]:
148+
"""
149+
Get the output names from the problem definition or generate default names.
150+
"""
151+
# check if output_names is given
152+
if "output_names" not in self.problem:
153+
output_names = [f"y{i + 1}" for i in range(num_outputs)]
154+
elif isinstance(self.problem["output_names"], list):
155+
output_names = self.problem["output_names"]
156+
else:
157+
msg = "'output_names' must be a list of strings."
158+
raise ValueError(msg)
159+
160+
return output_names
161+
162+
def run(
163+
self,
164+
method: str = "sobol",
165+
n_samples: int = 1024,
166+
conf_level: float = 0.95,
167+
) -> pd.DataFrame:
168+
"""
169+
Perform global sensitivity analysis on a fitted emulator.
170+
171+
Parameters
172+
----------
173+
method: str
174+
The sensitivity analysis method to perform, one of ["sobol", "morris"].
175+
n_samples : int
176+
Number of samples to generate for the analysis. Higher values give more
177+
accurate results but increase computation time. Default is 1024.
178+
conf_level : float
179+
Confidence level (between 0 and 1) for calculating confidence intervals
180+
of the Sobol sensitivity indices. Default is 0.95 (95% confidence). This
181+
is not used in Morris sensitivity analysis.
182+
183+
Returns
184+
-------
185+
pandas.DataFrame
186+
DataFrame with columns:
187+
- 'parameter': Input parameter name
188+
- 'output': Output variable name
189+
- 'S1', 'S2', 'ST': First, second, and total order sensitivity indices
190+
- 'S1_conf', 'S2_conf', 'ST_conf': Confidence intervals for each index
191+
192+
Notes
193+
-----
194+
The Sobol method requires N * (2D + 2) model evaluations, where D is the number
195+
of input parameters. For example, with N=1024 and 5 parameters, this requires
196+
12,288 evaluations. The Morris method requires far fewer computations.
197+
"""
198+
if method not in ["sobol", "morris"]:
199+
msg = f"Unknown method: {method}. Must be 'sobol' or 'morris'."
200+
raise ValueError(msg)
201+
202+
param_samples = self._sample(method, n_samples)
203+
y = self._predict(param_samples)
204+
output_names = self._get_output_names(y.shape[1])
205+
206+
results = {}
207+
for i, name in enumerate(output_names):
208+
if method == "sobol":
209+
Si = sobol_analyze(self.problem, y[:, i], conf_level=conf_level)
210+
elif method == "morris":
211+
Si = morris_analyze(self.problem, param_samples, y[:, i])
212+
results[name] = Si # type: ignore PGH003
213+
214+
if method == "sobol":
215+
return _sobol_results_to_df(results)
216+
return _morris_results_to_df(results, self.problem)
217+
218+
@staticmethod
219+
def plot_sobol(results, index="S1", n_cols=None, figsize=None):
220+
"""
221+
Plot Sobol sensitivity analysis results.
222+
223+
Parameters:
224+
-----------
225+
results : pd.DataFrame
226+
The results from sobol_results_to_df.
227+
index : str, default "S1"
228+
The type of sensitivity index to plot.
229+
- "S1": first-order indices
230+
- "S2": second-order/interaction indices
231+
- "ST": total-order indices
232+
n_cols : int, optional
233+
The number of columns in the plot. Defaults to 3 if there are 3 or
234+
more outputs, otherwise the number of outputs.
235+
figsize : tuple, optional
236+
Figure size as (width, height) in inches. If None, set automatically.
237+
"""
238+
return _plot_sobol_analysis(results, index, n_cols, figsize)
239+
240+
@staticmethod
241+
def plot_morris(results, param_groups=None, n_cols=None, figsize=None):
242+
"""
243+
Plot Morris analysis results.
244+
245+
Parameters:
246+
-----------
247+
results : pd.DataFrame
248+
The results from sobol_results_to_df.
249+
param_groups : dic[str, list[str]] | None
250+
Optional parameter groupings used to give all the same plot color
251+
of the form ({<group name> : [param1, ...], }).
252+
n_cols : int, optional
253+
The number of columns in the plot. Defaults to 3 if there are 3 or
254+
more outputs, otherwise the number of outputs.
255+
figsize : tuple, optional
256+
Figure size as (width, height) in inches.If None, set calculated.
257+
"""
258+
return _plot_morris_analysis(results, param_groups, n_cols, figsize)

autoemulate/sensitivity_analysis.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Dict
2-
31
import matplotlib.pyplot as plt
42
import numpy as np
53
import pandas as pd
@@ -16,7 +14,8 @@
1614
def _sensitivity_analysis(
1715
model, method="sobol", problem=None, X=None, N=1024, conf_level=0.95, as_df=True
1816
):
19-
"""Perform Sobol sensitivity analysis on a fitted emulator.
17+
"""
18+
Perform global sensitivity analysis on a fitted emulator.
2019
2120
Parameters:
2221
-----------
@@ -34,6 +33,8 @@ def _sensitivity_analysis(
3433
"bounds": [[0, 1], [0, 1]],
3534
}
3635
```
36+
X : array-like, shape (n_samples, n_features)
37+
Simulation input.
3738
N : int, optional
3839
The number of samples to generate (default is 1024).
3940
conf_level : float, optional
@@ -130,18 +131,28 @@ def _generate_problem(X):
130131

131132
def _sobol_analysis(
132133
model, problem=None, X=None, N=1024, conf_level=0.95
133-
) -> Dict[str, ResultDict]:
134+
) -> dict[str, ResultDict]:
134135
"""
135136
Perform Sobol sensitivity analysis on a fitted emulator.
136137
138+
Sobol sensitivity analysis is a variance-based method that decomposes the variance of the model
139+
output into contributions from individual input parameters and their interactions. It calculates:
140+
- First-order indices (S1): Direct contribution of each input parameter
141+
- Second-order indices (S2): Contribution from pairwise interactions between parameters
142+
- Total-order indices (ST): Total contribution of a parameter, including all its interactions
143+
137144
Parameters:
138145
-----------
139146
model : fitted emulator model
140147
The emulator model to analyze.
141148
problem : dict
142149
The problem definition, including 'num_vars', 'names', and 'bounds'.
150+
X : array-like, shape (n_samples, n_features)
151+
Simulation input.
143152
N : int, optional
144-
The number of samples to generate (default is 1000).
153+
The number of samples to generate (default is 1024).
154+
conf_level : float, optional
155+
The confidence level for the confidence intervals (default is 0.95).
145156
146157
Returns:
147158
--------
@@ -177,7 +188,7 @@ def _sobol_analysis(
177188
return results
178189

179190

180-
def _sobol_results_to_df(results: Dict[str, ResultDict]) -> pd.DataFrame:
191+
def _sobol_results_to_df(results: dict[str, ResultDict]) -> pd.DataFrame:
181192
"""
182193
Convert Sobol results to a (long-format) pandas DataFrame.
183194
@@ -274,7 +285,7 @@ def _create_bar_plot(ax, output_data, output_name):
274285

275286
def _plot_sobol_analysis(results, index="S1", n_cols=None, figsize=None):
276287
"""
277-
Plot the sensitivity analysis results.
288+
Plot the sobol sensitivity analysis results.
278289
279290
Parameters:
280291
-----------
@@ -338,10 +349,12 @@ def _plot_sobol_analysis(results, index="S1", n_cols=None, figsize=None):
338349
"""
339350

340351

341-
def _morris_analysis(model, problem=None, X=None, N=1024) -> Dict[str, ResultDict]:
352+
def _morris_analysis(model, problem=None, X=None, N=1024) -> dict[str, ResultDict]:
342353
"""
343354
Perform Morris sensitivity analysis on a fitted emulator.
344355
356+
TODO: can we say more about the method here?
357+
345358
Parameters:
346359
-----------
347360
model : fitted emulator model
@@ -387,7 +400,7 @@ def _morris_analysis(model, problem=None, X=None, N=1024) -> Dict[str, ResultDic
387400

388401

389402
def _morris_results_to_df(
390-
results: Dict[str, ResultDict], problem: dict
403+
results: dict[str, ResultDict], problem: dict
391404
) -> pd.DataFrame:
392405
"""
393406
Convert Morris results to a (long-format) pandas DataFrame.

0 commit comments

Comments
 (0)