Skip to content

Commit 341cf45

Browse files
Implement simple wrapper Op for scipy.signal.convolve2d
1 parent 73c0d4d commit 341cf45

File tree

2 files changed

+144
-2
lines changed

2 files changed

+144
-2
lines changed

pytensor/tensor/signal/conv.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import numpy as np
44
from numpy import convolve as numpy_convolve
5+
from scipy.signal import convolve2d as scipy_convolve2d
56

67
from pytensor.gradient import DisconnectedType
78
from pytensor.graph import Apply, Constant
@@ -11,7 +12,7 @@
1112
from pytensor.tensor.basic import as_tensor_variable, join, zeros
1213
from pytensor.tensor.blockwise import Blockwise
1314
from pytensor.tensor.math import maximum, minimum, switch
14-
from pytensor.tensor.type import vector
15+
from pytensor.tensor.type import matrix, vector
1516
from pytensor.tensor.variable import TensorVariable
1617

1718

@@ -211,3 +212,116 @@ def convolve1d(
211212

212213
full_mode = as_scalar(np.bool_(mode == "full"))
213214
return cast(TensorVariable, blockwise_convolve_1d(in1, in2, full_mode))
215+
216+
217+
class Convolve2D(Op):
218+
__props__ = ("mode", "boundary", "fillvalue")
219+
gufunc_signature = "(n,m),(k,l)->(o,p)"
220+
221+
def __init__(
222+
self,
223+
mode: Literal["full", "valid", "same"] = "full",
224+
boundary: Literal["fill", "wrap", "symm"] = "fill",
225+
fillvalue: float | int = 0,
226+
):
227+
if mode not in ("full", "valid", "same"):
228+
raise ValueError(f"Invalid mode: {mode}")
229+
if boundary not in ("fill", "wrap", "symm"):
230+
raise ValueError(f"Invalid boundary: {boundary}")
231+
232+
self.mode = mode
233+
self.boundary = boundary
234+
self.fillvalue = fillvalue
235+
236+
def make_node(self, in1, in2):
237+
in1, in2 = map(as_tensor_variable, (in1, in2))
238+
239+
assert in1.ndim == 2
240+
assert in2.ndim == 2
241+
242+
dtype = upcast(in1.dtype, in2.dtype)
243+
244+
n, m = in1.type.shape
245+
k, l = in2.type.shape
246+
247+
if any(x is None for x in (n, m, k, l)):
248+
out_shape = (None, None)
249+
elif self.mode == "full":
250+
out_shape = (n + k - 1, m + l - 1)
251+
elif self.mode == "valid":
252+
out_shape = (n - k + 1, m - l + 1)
253+
else: # mode == "same"
254+
out_shape = (n, m)
255+
256+
out = matrix(dtype=dtype, shape=out_shape)
257+
return Apply(self, [in1, in2], [out])
258+
259+
def perform(self, node, inputs, outputs):
260+
in1, in2 = inputs
261+
outputs[0][0] = scipy_convolve2d(
262+
in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue
263+
)
264+
265+
def infer_shape(self, fgraph, node, shapes):
266+
in1_shape, in2_shape = shapes
267+
n, m = in1_shape
268+
k, l = in2_shape
269+
270+
if self.mode == "full":
271+
shape = (n + k - 1, m + l - 1)
272+
elif self.mode == "valid":
273+
shape = (
274+
maximum(n, k) - minimum(n, k) + 1,
275+
maximum(m, l) - minimum(m, l) + 1,
276+
)
277+
else: # self.mode == 'same':
278+
shape = (n, m)
279+
280+
return [shape]
281+
282+
def L_op(self, inputs, outputs, output_grads):
283+
raise NotImplementedError
284+
285+
286+
def convolve2d(
287+
in1: "TensorLike",
288+
in2: "TensorLike",
289+
mode: Literal["full", "valid", "same"] = "full",
290+
boundary: Literal["fill", "wrap", "symm"] = "fill",
291+
fillvalue: float | int = 0,
292+
) -> TensorVariable:
293+
"""Convolve two two-dimensional arrays.
294+
295+
Convolve in1 and in2, with the output size determined by the mode argument.
296+
297+
Parameters
298+
----------
299+
in1 : (..., N, M) tensor_like
300+
First input.
301+
in2 : (..., K, L) tensor_like
302+
Second input.
303+
mode : {'full', 'valid', 'same'}, optional
304+
A string indicating the size of the output:
305+
- 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
306+
- 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
307+
- 'same': The output is the same size as in1, centered with respect to the 'full' output.
308+
boundary : {'fill', 'wrap', 'symm'}, optional
309+
A string indicating how to handle boundaries:
310+
- 'fill': Pads the input arrays with fillvalue.
311+
- 'wrap': Circularly wraps the input arrays.
312+
- 'symm': Symmetrically reflects the input arrays.
313+
fillvalue : float or int, optional
314+
The value to use for padding when boundary is 'fill'. Default is 0.
315+
Returns
316+
-------
317+
out: tensor_variable
318+
The discrete linear convolution of in1 with in2.
319+
320+
"""
321+
in1 = as_tensor_variable(in1)
322+
in2 = as_tensor_variable(in2)
323+
324+
blockwise_convolve = Blockwise(
325+
Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
326+
)
327+
return cast(TensorVariable, blockwise_convolve(in1, in2))

tests/tensor/signal/test_conv.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import numpy as np
44
import pytest
55
from scipy.signal import convolve as scipy_convolve
6+
from scipy.signal import convolve2d as scipy_convolve2d
67

78
from pytensor import config, function, grad
89
from pytensor.graph.basic import ancestors, io_toposort
910
from pytensor.graph.rewriting import rewrite_graph
1011
from pytensor.tensor import matrix, tensor, vector
1112
from pytensor.tensor.blockwise import Blockwise
12-
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
13+
from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
1314
from tests import unittest_tools as utt
1415

1516

@@ -137,3 +138,30 @@ def convolve1d_grad_benchmarker(convolve_mode, mode, benchmark):
137138
@pytest.mark.parametrize("convolve_mode", ["full", "valid"])
138139
def test_convolve1d_grad_benchmark_c(convolve_mode, benchmark):
139140
convolve1d_grad_benchmarker(convolve_mode, "FAST_RUN", benchmark)
141+
142+
143+
@pytest.mark.parametrize(
144+
"kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
145+
)
146+
@pytest.mark.parametrize(
147+
"data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
148+
)
149+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
150+
@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"])
151+
def test_convolve2d(kernel_shape, data_shape, mode, boundary):
152+
data = matrix("data")
153+
kernel = matrix("kernel")
154+
op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0)
155+
156+
rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
157+
data_val = rng.normal(size=data_shape).astype(data.dtype)
158+
kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)
159+
160+
fn = function([data, kernel], op(data, kernel))
161+
np.testing.assert_allclose(
162+
fn(data_val, kernel_val),
163+
scipy_convolve2d(
164+
data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
165+
),
166+
rtol=1e-6 if config.floatX == "float32" else 1e-15,
167+
)

0 commit comments

Comments
 (0)