Skip to content

Commit 183f247

Browse files
Implement simple wrapper Op for scipy.signal.convolve2d
1 parent 5335a68 commit 183f247

File tree

2 files changed

+144
-2
lines changed

2 files changed

+144
-2
lines changed

pytensor/tensor/signal/conv.py

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import TYPE_CHECKING, Literal, cast
22

33
from numpy import convolve as numpy_convolve
4+
from scipy.signal import convolve2d as scipy_convolve2d
45

56
from pytensor.graph import Apply, Op
67
from pytensor.scalar.basic import upcast
78
from pytensor.tensor.basic import as_tensor_variable, join, zeros
89
from pytensor.tensor.blockwise import Blockwise
910
from pytensor.tensor.math import maximum, minimum
10-
from pytensor.tensor.type import vector
11+
from pytensor.tensor.type import matrix, vector
1112
from pytensor.tensor.variable import TensorVariable
1213

1314

@@ -131,3 +132,116 @@ def convolve1d(
131132
mode = "valid"
132133

133134
return cast(TensorVariable, Blockwise(Convolve1d(mode=mode))(in1, in2))
135+
136+
137+
class Convolve2D(Op):
138+
__props__ = ("mode", "boundary", "fillvalue")
139+
gufunc_signature = "(n,m),(k,l)->(o,p)"
140+
141+
def __init__(
142+
self,
143+
mode: Literal["full", "valid", "same"] = "full",
144+
boundary: Literal["fill", "wrap", "symm"] = "fill",
145+
fillvalue: float | int = 0,
146+
):
147+
if mode not in ("full", "valid", "same"):
148+
raise ValueError(f"Invalid mode: {mode}")
149+
if boundary not in ("fill", "wrap", "symm"):
150+
raise ValueError(f"Invalid boundary: {boundary}")
151+
152+
self.mode = mode
153+
self.boundary = boundary
154+
self.fillvalue = fillvalue
155+
156+
def make_node(self, in1, in2):
157+
in1, in2 = map(as_tensor_variable, (in1, in2))
158+
159+
assert in1.ndim == 2
160+
assert in2.ndim == 2
161+
162+
dtype = upcast(in1.dtype, in2.dtype)
163+
164+
n, m = in1.type.shape
165+
k, l = in2.type.shape
166+
167+
if any(x is None for x in (n, m, k, l)):
168+
out_shape = (None, None)
169+
elif self.mode == "full":
170+
out_shape = (n + k - 1, m + l - 1)
171+
elif self.mode == "valid":
172+
out_shape = (n - k + 1, m - l + 1)
173+
else: # mode == "same"
174+
out_shape = (n, m)
175+
176+
out = matrix(dtype=dtype, shape=out_shape)
177+
return Apply(self, [in1, in2], [out])
178+
179+
def perform(self, node, inputs, outputs):
180+
in1, in2 = inputs
181+
outputs[0][0] = scipy_convolve2d(
182+
in1, in2, mode=self.mode, boundary=self.boundary, fillvalue=self.fillvalue
183+
)
184+
185+
def infer_shape(self, fgraph, node, shapes):
186+
in1_shape, in2_shape = shapes
187+
n, m = in1_shape
188+
k, l = in2_shape
189+
190+
if self.mode == "full":
191+
shape = (n + k - 1, m + l - 1)
192+
elif self.mode == "valid":
193+
shape = (
194+
maximum(n, k) - minimum(n, k) + 1,
195+
maximum(m, l) - minimum(m, l) + 1,
196+
)
197+
else: # self.mode == 'same':
198+
shape = (n, m)
199+
200+
return [shape]
201+
202+
def L_op(self, inputs, outputs, output_grads):
203+
raise NotImplementedError
204+
205+
206+
def convolve2d(
207+
in1: "TensorLike",
208+
in2: "TensorLike",
209+
mode: Literal["full", "valid", "same"] = "full",
210+
boundary: Literal["fill", "wrap", "symm"] = "fill",
211+
fillvalue: float | int = 0,
212+
) -> TensorVariable:
213+
"""Convolve two two-dimensional arrays.
214+
215+
Convolve in1 and in2, with the output size determined by the mode argument.
216+
217+
Parameters
218+
----------
219+
in1 : (..., N, M) tensor_like
220+
First input.
221+
in2 : (..., K, L) tensor_like
222+
Second input.
223+
mode : {'full', 'valid', 'same'}, optional
224+
A string indicating the size of the output:
225+
- 'full': The output is the full discrete linear convolution of the inputs, with shape (..., N+K-1, M+L-1).
226+
- 'valid': The output consists only of elements that do not rely on zero-padding, with shape (..., max(N, K) - min(N, K) + 1, max(M, L) - min(M, L) + 1).
227+
- 'same': The output is the same size as in1, centered with respect to the 'full' output.
228+
boundary : {'fill', 'wrap', 'symm'}, optional
229+
A string indicating how to handle boundaries:
230+
- 'fill': Pads the input arrays with fillvalue.
231+
- 'wrap': Circularly wraps the input arrays.
232+
- 'symm': Symmetrically reflects the input arrays.
233+
fillvalue : float or int, optional
234+
The value to use for padding when boundary is 'fill'. Default is 0.
235+
Returns
236+
-------
237+
out: tensor_variable
238+
The discrete linear convolution of in1 with in2.
239+
240+
"""
241+
in1 = as_tensor_variable(in1)
242+
in2 = as_tensor_variable(in2)
243+
244+
blockwise_convolve = Blockwise(
245+
Convolve2D(mode=mode, boundary=boundary, fillvalue=fillvalue)
246+
)
247+
return cast(TensorVariable, blockwise_convolve(in1, in2))

tests/tensor/signal/test_conv.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
import numpy as np
44
import pytest
55
from scipy.signal import convolve as scipy_convolve
6+
from scipy.signal import convolve2d as scipy_convolve2d
67

78
from pytensor import config, function, grad
89
from pytensor.graph.basic import ancestors, io_toposort
910
from pytensor.graph.rewriting import rewrite_graph
1011
from pytensor.tensor import matrix, vector
1112
from pytensor.tensor.blockwise import Blockwise
12-
from pytensor.tensor.signal.conv import Convolve1d, convolve1d
13+
from pytensor.tensor.signal.conv import Convolve1d, convolve1d, convolve2d
1314
from tests import unittest_tools as utt
1415

1516

@@ -109,3 +110,30 @@ def test_convolve1d_valid_grad_rewrite(static_shape):
109110
if isinstance(node.op, Convolve1d)
110111
]
111112
assert conv_op.mode == ("valid" if static_shape else "full")
113+
114+
115+
@pytest.mark.parametrize(
116+
"kernel_shape", [(3, 3), (5, 3), (5, 8)], ids=lambda x: f"kernel_shape={x}"
117+
)
118+
@pytest.mark.parametrize(
119+
"data_shape", [(3, 3), (5, 5), (8, 8)], ids=lambda x: f"data_shape={x}"
120+
)
121+
@pytest.mark.parametrize("mode", ["full", "valid", "same"])
122+
@pytest.mark.parametrize("boundary", ["fill", "wrap", "symm"])
123+
def test_convolve2d(kernel_shape, data_shape, mode, boundary):
124+
data = matrix("data")
125+
kernel = matrix("kernel")
126+
op = partial(convolve2d, mode=mode, boundary=boundary, fillvalue=0)
127+
128+
rng = np.random.default_rng((26, kernel_shape, data_shape, sum(map(ord, mode))))
129+
data_val = rng.normal(size=data_shape).astype(data.dtype)
130+
kernel_val = rng.normal(size=kernel_shape).astype(kernel.dtype)
131+
132+
fn = function([data, kernel], op(data, kernel))
133+
np.testing.assert_allclose(
134+
fn(data_val, kernel_val),
135+
scipy_convolve2d(
136+
data_val, kernel_val, mode=mode, boundary=boundary, fillvalue=0
137+
),
138+
rtol=1e-6 if config.floatX == "float32" else 1e-15,
139+
)

0 commit comments

Comments
 (0)