Description
In the example below we end up computing 4 dots, whereas only 3 are needed:
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode
A = pt.dmatrix('A')
x = pt.col('x')
f = (x.T @ A @ x), A @ x, A.T @ x
fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()
# dot [id A] 2
# ├─ dot [id B] 1
# │ ├─ Transpose{axes=[1, 0]} [id C] 'x.T' 0
# │ │ └─ x [id D]
# │ └─ A [id E]
# └─ x [id D]
# dot [id F] 3
# ├─ A [id E]
# └─ x [id D]
# dot [id G] 5
# ├─ Transpose{axes=[1, 0]} [id H] 'A.T' 4
# │ └─ A [id E]
# └─ x [id D]
We could use associativity to rewrite (x.T @ A) @ x as x.T @ (A @ x), where the inner dot is equivalent to the second output, so the two could be merged.
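A minimal hand-applied sketch of that merge (not an existing rewrite; I simply build the inner dot once and reuse it):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')

Ax = A @ x                         # shared inner dot
f = (x.T @ Ax), Ax, A.T @ x        # x.T @ (A @ x) instead of (x.T @ A) @ x
fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()                        # only 3 dot nodes remain in the graph
```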
Alternatively, we could use the transpose rule to rewrite the third output A.T @ x as (x.T @ A).T, which is the innermost dot of the first output, so those could be merged as well (with an extra transpose, which is just a cheap view anyway).
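The same idea hand-applied for the transpose variant (again only a sketch of the target graph):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')

xTA = x.T @ A                      # shared dot
f = (xTA @ x), A @ x, xTA.T        # A.T @ x written as (x.T @ A).T
fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()                        # 3 dot nodes plus a cheap transpose view
```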
With associativity we may want to be careful, as the evaluation order can have a large impact on performance. If we know the static shapes we can optimize the ordering like einsum does (see also #961), but if a dot is already being computed anyway then reusing it can't possibly make things worse.
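To make the ordering concern concrete, here is a back-of-the-envelope multiply-add count for a chain A @ B @ v with two square n x n matrices and an n-vector (illustrative numbers only):

```python
n = 1000
flops_left = n * n * n + n * n     # (A @ B) @ v: matrix-matrix product first
flops_right = n * n + n * n        # A @ (B @ v): two matrix-vector products
print(flops_left, flops_right)     # roughly 1e9 vs 2e6 multiply-adds
```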
The example above can easily show up in practice, e.g. in the gradient graph of a quadratic form.
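For concreteness, a sketch of how the pattern arises; I would expect the gradient of a quadratic form, which is mathematically (A + A.T) @ x, to introduce both A @ x and A.T @ x next to the forward dot:

```python
import pytensor
import pytensor.tensor as pt

A = pt.dmatrix('A')
x = pt.col('x')

quad = (x.T @ A @ x)[0, 0]         # scalar quadratic form
g = pt.grad(quad, x)               # gradient is (A + A.T) @ x
pytensor.dprint([quad, g])         # forward and backward dots in one graph
```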