@@ -27,12 +27,14 @@ def grad(output_, input_, components=None, d=None):
27
27
computed.
28
28
:param LabelTensor input_: The input tensor with respect to which the
29
29
gradient is computed.
30
- :param list[str] components: The names of the output variables for which to
30
+ :param components: The names of the output variables for which to
31
31
compute the gradient. It must be a subset of the output labels.
32
32
If ``None``, all output variables are considered. Default is ``None``.
33
- :param list[str] d: The names of the input variables with respect to which
33
+ :type components: str | list[str]
34
+ :param d: The names of the input variables with respect to which
34
35
the gradient is computed. It must be a subset of the input labels.
35
36
If ``None``, all input variables are considered. Default is ``None``.
37
+ :type d: str | list[str]
36
38
:raises TypeError: If the input tensor is not a LabelTensor.
37
39
:raises RuntimeError: If the output is a scalar field and the components
38
40
are not equal to the output labels.
@@ -50,9 +52,10 @@ def grad_scalar_output(output_, input_, d):
50
52
computed. It must be a column tensor.
51
53
:param LabelTensor input_: The input tensor with respect to which the
52
54
gradient is computed.
53
- :param list[str] d: The names of the input variables with respect to
55
+ :param d: The names of the input variables with respect to
54
56
which the gradient is computed. It must be a subset of the input
55
57
labels. If ``None``, all input variables are considered.
58
+ :type d: str | list[str]
56
59
:raises RuntimeError: If a vectorial function is passed.
57
60
:raises RuntimeError: If missing derivative labels.
58
61
:return: The computed gradient tensor.
@@ -89,6 +92,12 @@ def grad_scalar_output(output_, input_, d):
89
92
if components is None :
90
93
components = output_ .labels
91
94
95
+ if not isinstance (components , list ):
96
+ components = [components ]
97
+
98
+ if not isinstance (d , list ):
99
+ d = [d ]
100
+
92
101
if output_ .shape [1 ] == 1 : # scalar output ################################
93
102
94
103
if components != output_ .labels :
@@ -120,12 +129,14 @@ def div(output_, input_, components=None, d=None):
120
129
computed.
121
130
:param LabelTensor input_: The input tensor with respect to which the
122
131
divergence is computed.
123
- :param list[str] components: The names of the output variables for which to
132
+ :param components: The names of the output variables for which to
124
133
compute the divergence. It must be a subset of the output labels.
125
134
If ``None``, all output variables are considered. Default is ``None``.
126
- :param list[str] d: The names of the input variables with respect to which
135
+ :type components: str | list[str]
136
+ :param d: The names of the input variables with respect to which
127
137
the divergence is computed. It must be a subset of the input labels.
128
138
If ``None``, all input variables are considered. Default is ``None``.
139
+ :type d: str | list[str]
129
140
:raises TypeError: If the input tensor is not a LabelTensor.
130
141
:raises ValueError: If the output is a scalar field.
131
142
:raises ValueError: If the number of components is not equal to the number
@@ -142,6 +153,12 @@ def div(output_, input_, components=None, d=None):
142
153
if components is None :
143
154
components = output_ .labels
144
155
156
+ if not isinstance (components , list ):
157
+ components = [components ]
158
+
159
+ if not isinstance (d , list ):
160
+ d = [d ]
161
+
145
162
if output_ .shape [1 ] < 2 or len (components ) < 2 :
146
163
raise ValueError ("div supported only for vector fields" )
147
164
@@ -170,12 +187,14 @@ def laplacian(output_, input_, components=None, d=None, method="std"):
170
187
computed.
171
188
:param LabelTensor input_: The input tensor with respect to which the
172
189
laplacian is computed.
173
- :param list[str] components: The names of the output variables for which to
190
+ :param components: The names of the output variables for which to
174
191
compute the laplacian. It must be a subset of the output labels.
175
192
If ``None``, all output variables are considered. Default is ``None``.
176
- :param list[str] d: The names of the input variables with respect to which
193
+ :type components: str | list[str]
194
+ :param d: The names of the input variables with respect to which
177
195
the laplacian is computed. It must be a subset of the input labels.
178
196
If ``None``, all input variables are considered. Default is ``None``.
197
+ :type d: str | list[str]
179
198
:param str method: The method used to compute the Laplacian. Default is
180
199
``std``.
181
200
:raises NotImplementedError: If ``std=divgrad``.
@@ -191,12 +210,14 @@ def scalar_laplace(output_, input_, components, d):
191
210
computed. It must be a column tensor.
192
211
:param LabelTensor input_: The input tensor with respect to which the
193
212
laplacian is computed.
194
- :param list[str] components: The names of the output variables for which
213
+ :param components: The names of the output variables for which
195
214
to compute the laplacian. It must be a subset of the output labels.
196
215
If ``None``, all output variables are considered.
197
- :param list[str] d: The names of the input variables with respect to
216
+ :type components: str | list[str]
217
+ :param d: The names of the input variables with respect to
198
218
which the laplacian is computed. It must be a subset of the input
199
219
labels. If ``None``, all input variables are considered.
220
+ :type d: str | list[str]
200
221
:return: The computed laplacian tensor.
201
222
:rtype: LabelTensor
202
223
"""
@@ -216,22 +237,24 @@ def scalar_laplace(output_, input_, components, d):
216
237
if components is None :
217
238
components = output_ .labels
218
239
240
+ if not isinstance (components , list ):
241
+ components = [components ]
242
+
243
+ if not isinstance (d , list ):
244
+ d = [d ]
245
+
219
246
if method == "divgrad" :
220
247
raise NotImplementedError ("divgrad not implemented as method" )
221
248
222
249
if method == "std" :
223
- if len (components ) == 1 :
224
- result = scalar_laplace (output_ , input_ , components , d )
225
- labels = [f"dd{ components [0 ]} " ]
226
-
227
- else :
228
- result = torch .empty (
229
- input_ .shape [0 ], len (components ), device = output_ .device
230
- )
231
- labels = [None ] * len (components )
232
- for idx , c in enumerate (components ):
233
- result [:, idx ] = scalar_laplace (output_ , input_ , c , d ).flatten ()
234
- labels [idx ] = f"dd{ c } "
250
+
251
+ result = torch .empty (
252
+ input_ .shape [0 ], len (components ), device = output_ .device
253
+ )
254
+ labels = [None ] * len (components )
255
+ for idx , c in enumerate (components ):
256
+ result [:, idx ] = scalar_laplace (output_ , input_ , [c ], d ).flatten ()
257
+ labels [idx ] = f"dd{ c } "
235
258
236
259
result = result .as_subclass (LabelTensor )
237
260
result .labels = labels
@@ -251,12 +274,14 @@ def advection(output_, input_, velocity_field, components=None, d=None):
251
274
is computed.
252
275
:param str velocity_field: The name of the output variable used as velocity
253
276
field. It must be chosen among the output labels.
254
- :param list[str] components: The names of the output variables for which
277
+ :param components: The names of the output variables for which
255
278
to compute the advection. It must be a subset of the output labels.
256
279
If ``None``, all output variables are considered. Default is ``None``.
257
- :param list[str] d: The names of the input variables with respect to which
280
+ :type components: str | list[str]
281
+ :param d: The names of the input variables with respect to which
258
282
the advection is computed. It must be a subset of the input labels.
259
283
If ``None``, all input variables are considered. Default is ``None``.
284
+ :type d: str | list[str]
260
285
:return: The computed advection tensor.
261
286
:rtype: LabelTensor
262
287
"""
@@ -266,6 +291,12 @@ def advection(output_, input_, velocity_field, components=None, d=None):
266
291
if components is None :
267
292
components = output_ .labels
268
293
294
+ if not isinstance (components , list ):
295
+ components = [components ]
296
+
297
+ if not isinstance (d , list ):
298
+ d = [d ]
299
+
269
300
tmp = (
270
301
grad (output_ , input_ , components , d )
271
302
.reshape (- 1 , len (components ), len (d ))
0 commit comments