Skip to content

Commit d1ea0cc

Browse files
committed
_generate_problem change input type NumpyLike -> TensorLike
1 parent e4b8b3f commit d1ea0cc

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

autoemulate/experimental/sensitivity_analysis.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def __init__(
5858
self.x = x
5959
elif x is not None:
6060
self.x, _ = self._convert_to_numpy(x)
61-
problem = self._generate_problem(self.x)
61+
problem = self._generate_problem(x)
6262
else:
6363
msg = "Either problem or x must be provided."
6464
raise ValueError(msg)
@@ -95,13 +95,13 @@ def _check_problem(problem: dict) -> dict:
9595
return problem
9696

9797
@staticmethod
98-
def _generate_problem(x: NumpyLike) -> dict:
98+
def _generate_problem(x: TensorLike) -> dict:
9999
"""
100100
Generate a problem definition from a design matrix.
101101
102102
Parameters
103103
----------
104-
x : NumpyLike
104+
x : TensorLike
105105
Simulator input parameter values [n_samples, n_parameters].
106106
"""
107107
if x.ndim == 1:
@@ -111,7 +111,9 @@ def _generate_problem(x: NumpyLike) -> dict:
111111
return {
112112
"num_vars": x.shape[1],
113113
"names": [f"X{i + 1}" for i in range(x.shape[1])],
114-
"bounds": [[x[:, i].min(), x[:, i].max()] for i in range(x.shape[1])],
114+
"bounds": [
115+
[x[:, i].min().item(), x[:, i].max().item()] for i in range(x.shape[1])
116+
],
115117
}
116118

117119
def _sample(self, method: str, N: int) -> NumpyLike:

0 commit comments

Comments
 (0)