Skip to content

Commit c44361c

Browse files
Better shape inference
1 parent 341cf45 commit c44361c

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pytensor/tensor/signal/conv.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,19 @@ def make_node(self, in1, in2):
244244
n, m = in1.type.shape
245245
k, l = in2.type.shape
246246

247-
if any(x is None for x in (n, m, k, l)):
248-
out_shape = (None, None)
249-
elif self.mode == "full":
250-
out_shape = (n + k - 1, m + l - 1)
247+
if self.mode == "full":
248+
shape_1 = None if (n is None or k is None) else n + k - 1
249+
shape_2 = None if (m is None or l is None) else m + l - 1
250+
251251
elif self.mode == "valid":
252-
out_shape = (n - k + 1, m - l + 1)
252+
shape_1 = None if (n is None or k is None) else max(n, k) - max(n, k) + 1
253+
shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1
254+
253255
else: # mode == "same"
254-
out_shape = (n, m)
256+
shape_1 = n
257+
shape_2 = m
255258

259+
out_shape = (shape_1, shape_2)
256260
out = matrix(dtype=dtype, shape=out_shape)
257261
return Apply(self, [in1, in2], [out])
258262

0 commit comments

Comments
 (0)