Fix bug in gradient of Blockwise'd Scan #1482

Merged · 3 commits · Jun 16, 2025
99 changes: 42 additions & 57 deletions pytensor/tensor/blockwise.py
@@ -344,81 +344,66 @@ def connection_pattern(self, node):
 
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
         # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
-                tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))

[PR author review comment on the removed line above: "This line was the problematic one: shape=(None,) * len(sig)" (see the sketch after this diff).]

-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
             core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                tensor(
+                    dtype=inp.type.dtype,
+                    shape=inp.type.shape[batch_ndim:],
+                )
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!

[PR author review comment on the removed FIXME block above (old lines 371 to 374): "This is also fixed"; the dummy core outputs are now created from the core inputs themselves, so they actually depend on them.]

-            core_outputs = core_node.outputs
-
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
-                )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
-
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
-
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
-
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+            core_outputs = self._create_dummy_core_node(core_inputs).outputs
+
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
+                )
+            ]
+
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
+
+        # Vectorize core gradients to original inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
+
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
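To make the problem with the removed line concrete, here is a minimal sketch (illustrative only, not part of the PR; `x` and `batch_ndim` are assumed to come from a Blockwise with one batch dimension and core signature `(a)->(a)`). The old construction built the dummy core inputs with fully unknown shapes, so a core op whose `make_node` or `L_op` relies on the static core shape (as `Scan` does) could fail when `Blockwise.L_op` called it. The new construction keeps the known core shape by slicing off the batch dimensions.

```python
from pytensor.tensor import tensor

# Batched input of a Blockwise'd op: one batch dimension, core shape (2,).
x = tensor("x", shape=(5, 2), dtype="float64")
batch_ndim = 1
core_ndim = x.type.ndim - batch_ndim

# Old construction: the static core shape was discarded.
old_core_input = tensor(dtype=x.type.dtype, shape=(None,) * core_ndim)
print(old_core_input.type.shape)  # (None,) - core shape information lost

# New construction: the known core shape is preserved.
new_core_input = tensor(dtype=x.type.dtype, shape=x.type.shape[batch_ndim:])
print(new_core_input.type.shape)  # (2,)
```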
66 changes: 57 additions & 9 deletions tests/tensor/test_blockwise.py
@@ -6,11 +6,11 @@
 import scipy.linalg
 
 import pytensor
-from pytensor import In, config, function
+from pytensor import In, config, function, scan
 from pytensor.compile import get_default_mode, get_mode
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph, vectorize_node
 from pytensor.raise_op import assert_op
 from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
@@ -162,13 +162,13 @@ def perform(self, *args, **kwargs):
         raise NotImplementedError("Test Op should not be present in final graph")
 
 
-test_op = MyTestOp()
+my_test_op = MyTestOp()
 
 
 def test_vectorize_node_default_signature():
     vec = tensor(shape=(None,))
     mat = tensor(shape=(5, None))
-    node = test_op.make_node(vec, mat)
+    node = my_test_op.make_node(vec, mat)
 
     vect_node = vectorize_node(node, mat, mat)
     assert isinstance(vect_node.op, Blockwise) and isinstance(
@@ -179,9 +179,9 @@ def test_vectorize_node_default_signature():
     with pytest.raises(
         ValueError, match="Signature not provided nor found in core_op MyTestOp"
     ):
-        Blockwise(test_op)
+        Blockwise(my_test_op)
 
-    vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
+    vect_node = Blockwise(my_test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
     assert vect_node.outputs[0].type.shape == (
         5,
         None,
@@ -198,7 +198,7 @@ def test_blockwise_shape():
     inp_test = np.zeros((5, 4, 3), dtype=config.floatX)
 
     # Shape can be inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (n, m)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (n, m)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
 
@@ -210,7 +210,7 @@
     assert tuple(shape_fn(inp_test)) == (5, 3, 4)
 
     # Shape can only be partially inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (m, k)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
 
@@ -233,7 +233,7 @@
     inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX)
     inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX)
 
-    op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
     outs = op(inp1, inp2)
     assert outs[0].type.shape == (7, 5, None, None)
     assert outs[1].type.shape == (7, 5, None, None)
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
         np.ones(12, dtype=config.floatX),
         strict=True,
     )
+
+
+def test_blockwise_grad_core_type():
+    class StrictCoreTypeOp(Op):
+        def make_node(self, x):
+            assert x.type.shape[-1] == 2
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = inputs[0] + 1
+
+        def L_op(self, inputs, outputs, output_grads):
+            [x] = inputs
+            assert x.type.shape == (2,)
+            return [x.zeros_like()]
+
+    strict_core_type_op = StrictCoreTypeOp()
+    block_strict_core_type_op = Blockwise(strict_core_type_op, signature="(a)->(a)")
+
+    x = tensor("x", shape=(5, 2), dtype="float64")
+    y = block_strict_core_type_op(x)
+    assert y.type.shape == (5, 2)
+
+    grad_y = grad(y.sum(), x)
+    assert grad_y.type.shape == (5, 2)
+    np.testing.assert_allclose(
+        grad_y.eval({x: np.ones((5, 2))}),
+        np.zeros((5, 2)),
+    )
+
+
+def test_scan_gradient_core_type():
+    n_steps = 3
+    seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
+    out, _ = scan(
+        lambda s: s,
+        sequences=[seq],
+        n_steps=n_steps,
+    )
+
+    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
+    vec_out = vectorize_graph(out, replace={seq: vec_seq})
+    grad_sit_sot0 = grad(vec_out.sum(), vec_seq)
+
+    np.testing.assert_allclose(
+        grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
+        np.ones((4, n_steps, 1)),
+    )
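For context, the Scan test above can also be read as a standalone reproduction of the case this PR fixes (a sketch with the same assumed shapes and names as the test): `vectorize_graph` turns the `Scan` into a Blockwise'd `Scan`, so the gradient of the vectorized output goes through `Blockwise.L_op`, which now hands `Scan` core inputs that keep their static core shapes.

```python
import numpy as np
from pytensor import scan
from pytensor.gradient import grad
from pytensor.graph.replace import vectorize_graph
from pytensor.tensor import tensor

n_steps = 3
seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
out, _ = scan(lambda s: s, sequences=[seq], n_steps=n_steps)  # identity scan

# Adding a leading batch dimension wraps the Scan in a Blockwise.
vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
vec_out = vectorize_graph(out, replace={seq: vec_seq})

# This gradient is the case fixed here: previously the core Scan was handed
# dummy inputs with all-unknown static shapes.
g = grad(vec_out.sum(), vec_seq)
print(g.eval({vec_seq: np.ones((4, n_steps, 1))}).shape)  # (4, 3, 1)
```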