Skip to content

Commit 4230f63

Browse files
Better shape inference
1 parent 183f247 commit 4230f63

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

pytensor/tensor/signal/conv.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -164,15 +164,19 @@ def make_node(self, in1, in2):
164164
n, m = in1.type.shape
165165
k, l = in2.type.shape
166166

167-
if any(x is None for x in (n, m, k, l)):
168-
out_shape = (None, None)
169-
elif self.mode == "full":
170-
out_shape = (n + k - 1, m + l - 1)
167+
if self.mode == "full":
168+
shape_1 = None if (n is None or k is None) else n + k - 1
169+
shape_2 = None if (m is None or l is None) else m + l - 1
170+
171171
elif self.mode == "valid":
172-
out_shape = (n - k + 1, m - l + 1)
172+
shape_1 = None if (n is None or k is None) else max(n, k) - max(n, k) + 1
173+
shape_2 = None if (m is None or l is None) else max(m, l) - min(m, l) + 1
174+
173175
else: # mode == "same"
174-
out_shape = (n, m)
176+
shape_1 = n
177+
shape_2 = m
175178

179+
out_shape = (shape_1, shape_2)
176180
out = matrix(dtype=dtype, shape=out_shape)
177181
return Apply(self, [in1, in2], [out])
178182

0 commit comments

Comments
 (0)