Skip to content

Commit b958c0f

Browse files
fix labels management in operators (mathLab#524)
* fix bug in laplace labels * fix label management and add test
1 parent ef29f0a commit b958c0f

File tree

2 files changed

+93
-23
lines changed

2 files changed

+93
-23
lines changed

pina/operator.py

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,14 @@ def grad(output_, input_, components=None, d=None):
2727
computed.
2828
:param LabelTensor input_: The input tensor with respect to which the
2929
gradient is computed.
30-
:param list[str] components: The names of the output variables for which to
30+
:param components: The names of the output variables for which to
3131
compute the gradient. It must be a subset of the output labels.
3232
If ``None``, all output variables are considered. Default is ``None``.
33-
:param list[str] d: The names of the input variables with respect to which
33+
:type components: str | list[str]
34+
:param d: The names of the input variables with respect to which
3435
the gradient is computed. It must be a subset of the input labels.
3536
If ``None``, all input variables are considered. Default is ``None``.
37+
:type d: str | list[str]
3638
:raises TypeError: If the input tensor is not a LabelTensor.
3739
:raises RuntimeError: If the output is a scalar field and the components
3840
are not equal to the output labels.
@@ -50,9 +52,10 @@ def grad_scalar_output(output_, input_, d):
5052
computed. It must be a column tensor.
5153
:param LabelTensor input_: The input tensor with respect to which the
5254
gradient is computed.
53-
:param list[str] d: The names of the input variables with respect to
55+
:param d: The names of the input variables with respect to
5456
which the gradient is computed. It must be a subset of the input
5557
labels. If ``None``, all input variables are considered.
58+
:type d: str | list[str]
5659
:raises RuntimeError: If a vectorial function is passed.
5760
:raises RuntimeError: If missing derivative labels.
5861
:return: The computed gradient tensor.
@@ -89,6 +92,12 @@ def grad_scalar_output(output_, input_, d):
8992
if components is None:
9093
components = output_.labels
9194

95+
if not isinstance(components, list):
96+
components = [components]
97+
98+
if not isinstance(d, list):
99+
d = [d]
100+
92101
if output_.shape[1] == 1: # scalar output ################################
93102

94103
if components != output_.labels:
@@ -120,12 +129,14 @@ def div(output_, input_, components=None, d=None):
120129
computed.
121130
:param LabelTensor input_: The input tensor with respect to which the
122131
divergence is computed.
123-
:param list[str] components: The names of the output variables for which to
132+
:param components: The names of the output variables for which to
124133
compute the divergence. It must be a subset of the output labels.
125134
If ``None``, all output variables are considered. Default is ``None``.
126-
:param list[str] d: The names of the input variables with respect to which
135+
:type components: str | list[str]
136+
:param d: The names of the input variables with respect to which
127137
the divergence is computed. It must be a subset of the input labels.
128138
If ``None``, all input variables are considered. Default is ``None``.
139+
:type d: str | list[str]
129140
:raises TypeError: If the input tensor is not a LabelTensor.
130141
:raises ValueError: If the output is a scalar field.
131142
:raises ValueError: If the number of components is not equal to the number
@@ -142,6 +153,12 @@ def div(output_, input_, components=None, d=None):
142153
if components is None:
143154
components = output_.labels
144155

156+
if not isinstance(components, list):
157+
components = [components]
158+
159+
if not isinstance(d, list):
160+
d = [d]
161+
145162
if output_.shape[1] < 2 or len(components) < 2:
146163
raise ValueError("div supported only for vector fields")
147164

@@ -170,12 +187,14 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
170187
computed.
171188
:param LabelTensor input_: The input tensor with respect to which the
172189
laplacian is computed.
173-
:param list[str] components: The names of the output variables for which to
190+
:param components: The names of the output variables for which to
174191
compute the laplacian. It must be a subset of the output labels.
175192
If ``None``, all output variables are considered. Default is ``None``.
176-
:param list[str] d: The names of the input variables with respect to which
193+
:type components: str | list[str]
194+
:param d: The names of the input variables with respect to which
177195
the laplacian is computed. It must be a subset of the input labels.
178196
If ``None``, all input variables are considered. Default is ``None``.
197+
:type d: str | list[str]
179198
:param str method: The method used to compute the Laplacian. Default is
180199
``std``.
181200
:raises NotImplementedError: If ``std=divgrad``.
@@ -191,12 +210,14 @@ def scalar_laplace(output_, input_, components, d):
191210
computed. It must be a column tensor.
192211
:param LabelTensor input_: The input tensor with respect to which the
193212
laplacian is computed.
194-
:param list[str] components: The names of the output variables for which
213+
:param components: The names of the output variables for which
195214
to compute the laplacian. It must be a subset of the output labels.
196215
If ``None``, all output variables are considered.
197-
:param list[str] d: The names of the input variables with respect to
216+
:type components: str | list[str]
217+
:param d: The names of the input variables with respect to
198218
which the laplacian is computed. It must be a subset of the input
199219
labels. If ``None``, all input variables are considered.
220+
:type d: str | list[str]
200221
:return: The computed laplacian tensor.
201222
:rtype: LabelTensor
202223
"""
@@ -216,22 +237,24 @@ def scalar_laplace(output_, input_, components, d):
216237
if components is None:
217238
components = output_.labels
218239

240+
if not isinstance(components, list):
241+
components = [components]
242+
243+
if not isinstance(d, list):
244+
d = [d]
245+
219246
if method == "divgrad":
220247
raise NotImplementedError("divgrad not implemented as method")
221248

222249
if method == "std":
223-
if len(components) == 1:
224-
result = scalar_laplace(output_, input_, components, d)
225-
labels = [f"dd{components[0]}"]
226-
227-
else:
228-
result = torch.empty(
229-
input_.shape[0], len(components), device=output_.device
230-
)
231-
labels = [None] * len(components)
232-
for idx, c in enumerate(components):
233-
result[:, idx] = scalar_laplace(output_, input_, c, d).flatten()
234-
labels[idx] = f"dd{c}"
250+
251+
result = torch.empty(
252+
input_.shape[0], len(components), device=output_.device
253+
)
254+
labels = [None] * len(components)
255+
for idx, c in enumerate(components):
256+
result[:, idx] = scalar_laplace(output_, input_, [c], d).flatten()
257+
labels[idx] = f"dd{c}"
235258

236259
result = result.as_subclass(LabelTensor)
237260
result.labels = labels
@@ -251,12 +274,14 @@ def advection(output_, input_, velocity_field, components=None, d=None):
251274
is computed.
252275
:param str velocity_field: The name of the output variable used as velocity
253276
field. It must be chosen among the output labels.
254-
:param list[str] components: The names of the output variables for which
277+
:param components: The names of the output variables for which
255278
to compute the advection. It must be a subset of the output labels.
256279
If ``None``, all output variables are considered. Default is ``None``.
257-
:param list[str] d: The names of the input variables with respect to which
280+
:type components: str | list[str]
281+
:param d: The names of the input variables with respect to which
258282
the advection is computed. It must be a subset of the input labels.
259283
If ``None``, all input variables are considered. Default is ``None``.
284+
:type d: str | list[str]
260285
:return: The computed advection tensor.
261286
:rtype: LabelTensor
262287
"""
@@ -266,6 +291,12 @@ def advection(output_, input_, velocity_field, components=None, d=None):
266291
if components is None:
267292
components = output_.labels
268293

294+
if not isinstance(components, list):
295+
components = [components]
296+
297+
if not isinstance(d, list):
298+
d = [d]
299+
269300
tmp = (
270301
grad(output_, input_, components, d)
271302
.reshape(-1, len(components), len(d))

tests/test_operator.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,3 +164,42 @@ def test_laplacian_vector_output2():
164164

165165
assert torch.allclose(lap_f.extract("ddu"), lap_u)
166166
assert torch.allclose(lap_f.extract("ddv"), lap_v)
167+
168+
169+
def test_label_format():
170+
# Testing the format of `components` or `d` in case of single str of length
171+
# greater than 1; e.g.: "aaa".
172+
# This test is conducted only for gradient and laplacian, since div is not
173+
# implemented for single components.
174+
inp.labels = ["xx", "yy", "zz"]
175+
tensor_v = LabelTensor(func_vector(inp), ["aa", "bbb", "c"])
176+
comp = tensor_v.labels[0]
177+
single_d = inp.labels[0]
178+
179+
# Single component as string + list of d
180+
grad_tensor_v = grad(tensor_v, inp, components=comp, d=None)
181+
assert grad_tensor_v.labels == [f"d{comp}d{i}" for i in inp.labels]
182+
183+
lap_tensor_v = laplacian(tensor_v, inp, components=comp, d=None)
184+
assert lap_tensor_v.labels == [f"dd{comp}"]
185+
186+
# Single component as list + list of d
187+
grad_tensor_v = grad(tensor_v, inp, components=[comp], d=None)
188+
assert grad_tensor_v.labels == [f"d{comp}d{i}" for i in inp.labels]
189+
190+
lap_tensor_v = laplacian(tensor_v, inp, components=[comp], d=None)
191+
assert lap_tensor_v.labels == [f"dd{comp}"]
192+
193+
# List of components + single d as string
194+
grad_tensor_v = grad(tensor_v, inp, components=None, d=single_d)
195+
assert grad_tensor_v.labels == [f"d{i}d{single_d}" for i in tensor_v.labels]
196+
197+
lap_tensor_v = laplacian(tensor_v, inp, components=None, d=single_d)
198+
assert lap_tensor_v.labels == [f"dd{i}" for i in tensor_v.labels]
199+
200+
# List of components + single d as list
201+
grad_tensor_v = grad(tensor_v, inp, components=None, d=[single_d])
202+
assert grad_tensor_v.labels == [f"d{i}d{single_d}" for i in tensor_v.labels]
203+
204+
lap_tensor_v = laplacian(tensor_v, inp, components=None, d=[single_d])
205+
assert lap_tensor_v.labels == [f"dd{i}" for i in tensor_v.labels]

0 commit comments

Comments
 (0)