diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 29eafef992..4cc59fd0cf 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -344,81 +344,66 @@ def connection_pattern(self, node):
         return [[True for _ in node.outputs] for _ in node.inputs]
 
-    def _bgrad(self, inputs, outputs, ograds):
-        # Grad, with respect to broadcasted versions of inputs
-
-        def as_core(t, core_t):
-            # Inputs could be NullType or DisconnectedType
-            if isinstance(t.type, NullType | DisconnectedType):
-                return t
-            return core_t.type()
+    def L_op(self, inputs, outputs, output_gradients):
+        batch_ndim = self.batch_ndim(outputs[0].owner)
 
+        # Obtain core_op gradients
         with config.change_flags(compute_test_value="off"):
-            safe_inputs = [
-                tensor(dtype=inp.type.dtype, shape=(None,) * len(sig))
-                for inp, sig in zip(inputs, self.inputs_sig, strict=True)
-            ]
-            core_node = self._create_dummy_core_node(safe_inputs)
-
             core_inputs = [
-                as_core(inp, core_inp)
-                for inp, core_inp in zip(inputs, core_node.inputs, strict=True)
-            ]
-            core_ograds = [
-                as_core(ograd, core_ograd)
-                for ograd, core_ograd in zip(ograds, core_node.outputs, strict=True)
+                tensor(
+                    dtype=inp.type.dtype,
+                    shape=inp.type.shape[batch_ndim:],
+                )
+                for inp in inputs
             ]
-            # FIXME: These core_outputs do not depend on core_inputs, not pretty
-            # It's not neccessarily a problem because if they are referenced by the gradient,
-            # they get replaced later in vectorize. But if the Op was to make any decision
-            # by introspecting the dependencies of output on inputs it would fail badly!
-            core_outputs = core_node.outputs
-
-            core_igrads = self.core_op.L_op(core_inputs, core_outputs, core_ograds)
-
-        igrads = vectorize_graph(
-            [core_igrad for core_igrad in core_igrads if core_igrad is not None],
-            replace=dict(
-                zip(
-                    core_inputs + core_outputs + core_ograds,
-                    inputs + outputs + ograds,
-                    strict=True,
+            core_outputs = self._create_dummy_core_node(core_inputs).outputs
+
+            # Define core output_gradients, but keep original disconnected/null output_gradients (if any)
+            core_output_gradients = [
+                output_grad
+                if isinstance(output_grad.type, NullType | DisconnectedType)
+                else core_output.type()
+                for output_grad, core_output in zip(
+                    output_gradients, core_outputs, strict=True
                 )
-            ),
-        )
-
-        igrads_iter = iter(igrads)
-        return [
-            None if core_igrad is None else next(igrads_iter)
-            for core_igrad in core_igrads
-        ]
+            ]
 
-    def L_op(self, inputs, outs, ograds):
-        from pytensor.tensor.math import sum as pt_sum
+            core_input_gradients = self.core_op.L_op(
+                core_inputs, core_outputs, core_output_gradients
+            )
 
-        # Compute grad with respect to broadcasted input
-        rval = self._bgrad(inputs, outs, ograds)
+        # Vectorize core gradients to original inputs
+        input_gradients = list(
+            vectorize_graph(
+                core_input_gradients,
+                replace=dict(
+                    zip(
+                        core_inputs + core_outputs + core_output_gradients,
+                        inputs + outputs + output_gradients,
+                        strict=True,
+                    )
+                ),
+            )
+        )
 
-        # Sum out the broadcasted dimensions
-        batch_ndims = self.batch_ndim(outs[0].owner)
-        batch_shape = outs[0].type.shape[:batch_ndims]
+        # Sum out the broadcasted batch dimensions
+        batch_shape = outputs[0].type.shape[:batch_ndim]
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig, strict=True)):
-            if isinstance(rval[i].type, NullType | DisconnectedType):
+            if isinstance(input_gradients[i].type, NullType | DisconnectedType):
                 continue
 
-            assert inp.type.ndim == batch_ndims + len(sig)
+            assert inp.type.ndim == batch_ndim + len(sig)
 
-            to_sum = [
+            if to_sum := [
                 j
                 for j, (inp_s, out_s) in enumerate(
                     zip(inp.type.shape, batch_shape, strict=False)
                 )
                 if inp_s == 1 and out_s != 1
-            ]
-            if to_sum:
-                rval[i] = pt_sum(rval[i], axis=to_sum, keepdims=True)
+            ]:
+                input_gradients[i] = input_gradients[i].sum(axis=to_sum, keepdims=True)
 
-        return rval
+        return input_gradients
 
     def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index 30af86c038..9d48f310fe 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -6,11 +6,11 @@
 import scipy.linalg
 
 import pytensor
-from pytensor import In, config, function
+from pytensor import In, config, function, scan
 from pytensor.compile import get_default_mode, get_mode
 from pytensor.gradient import grad
 from pytensor.graph import Apply, Op
-from pytensor.graph.replace import vectorize_node
+from pytensor.graph.replace import vectorize_graph, vectorize_node
 from pytensor.raise_op import assert_op
 from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
@@ -162,13 +162,13 @@ def perform(self, *args, **kwargs):
         raise NotImplementedError("Test Op should not be present in final graph")
 
 
-test_op = MyTestOp()
+my_test_op = MyTestOp()
 
 
 def test_vectorize_node_default_signature():
     vec = tensor(shape=(None,))
     mat = tensor(shape=(5, None))
-    node = test_op.make_node(vec, mat)
+    node = my_test_op.make_node(vec, mat)
 
     vect_node = vectorize_node(node, mat, mat)
     assert isinstance(vect_node.op, Blockwise) and isinstance(
@@ -179,9 +179,9 @@ def test_vectorize_node_default_signature():
     with pytest.raises(
         ValueError, match="Signature not provided nor found in core_op MyTestOp"
     ):
-        Blockwise(test_op)
+        Blockwise(my_test_op)
 
-    vect_node = Blockwise(test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
+    vect_node = Blockwise(my_test_op, signature="(m),(n)->(m),(n)").make_node(vec, mat)
     assert vect_node.outputs[0].type.shape == (
         5,
         None,
@@ -198,7 +198,7 @@ def test_blockwise_shape():
     inp_test = np.zeros((5, 4, 3), dtype=config.floatX)
 
     # Shape can be inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (n, m)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (n, m)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
 
@@ -210,7 +210,7 @@ def test_blockwise_shape():
     assert tuple(shape_fn(inp_test)) == (5, 3, 4)
 
     # Shape can only be partially inferred from inputs
-    op = Blockwise(test_op, signature="(m, n) -> (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n) -> (m, k)")
     out = op(inp)
     assert out.type.shape == (5, None, None)
 
@@ -233,7 +233,7 @@ def test_blockwise_shape():
    inp1_test = np.zeros((7, 1, 4, 3), dtype=config.floatX)
     inp2_test = np.zeros((1, 5, 4, 3), dtype=config.floatX)
 
-    op = Blockwise(test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
+    op = Blockwise(my_test_op, signature="(m, n), (m, n) -> (n, m), (m, k)")
     outs = op(inp1, inp2)
     assert outs[0].type.shape == (7, 5, None, None)
     assert outs[1].type.shape == (7, 5, None, None)
@@ -650,3 +650,51 @@ def L_op(self, inputs, outputs, output_gradients):
         np.ones(12, dtype=config.floatX),
         strict=True,
     )
+
+
+def test_blockwise_grad_core_type():
+    class StrictCoreTypeOp(Op):
+        def make_node(self, x):
+            assert x.type.shape[-1] == 2
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, output_storage):
+            output_storage[0][0] = inputs[0] + 1
+
+        def L_op(self, inputs, outputs, output_grads):
+            [x] = inputs
+            assert x.type.shape == (2,)
+            return [x.zeros_like()]
+
+    strict_core_type_op = StrictCoreTypeOp()
+    block_strict_core_type_op = Blockwise(strict_core_type_op, signature="(a)->(a)")
+
+    x = tensor("x", shape=(5, 2), dtype="float64")
+    y = block_strict_core_type_op(x)
+    assert y.type.shape == (5, 2)
+
+    grad_y = grad(y.sum(), x)
+    assert grad_y.type.shape == (5, 2)
+    np.testing.assert_allclose(
+        grad_y.eval({x: np.ones((5, 2))}),
+        np.zeros((5, 2)),
+    )
+
+
+def test_scan_gradient_core_type():
+    n_steps = 3
+    seq = tensor("seq", shape=(n_steps, 1), dtype="float64")
+    out, _ = scan(
+        lambda s: s,
+        sequences=[seq],
+        n_steps=n_steps,
+    )
+
+    vec_seq = tensor("vec_seq", shape=(None, n_steps, 1), dtype="float64")
+    vec_out = vectorize_graph(out, replace={seq: vec_seq})
+    grad_sit_sot0 = grad(vec_out.sum(), vec_seq)
+
+    np.testing.assert_allclose(
+        grad_sit_sot0.eval({vec_seq: np.ones((4, n_steps, 1))}),
+        np.ones((4, n_steps, 1)),
+    )