Equivalent dots not merged #1537

Open

@ricardoV94

Description

In the example below we end up computing 4 dots, whereas only 3 are needed

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')
f = (x.T @ A @ x), A @ x, A.T @ x

fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()
# dot [id A] 2
#  ├─ dot [id B] 1
#  │  ├─ Transpose{axes=[1, 0]} [id C] 'x.T' 0
#  │  │  └─ x [id D]
#  │  └─ A [id E]
#  └─ x [id D]
# dot [id F] 3
#  ├─ A [id E]
#  └─ x [id D]
# dot [id G] 5
#  ├─ Transpose{axes=[1, 0]} [id H] 'A.T' 4
#  │  └─ A [id E]
#  └─ x [id D]

We could use associativity to rewrite (x.T @ A) @ x as x.T @ (A @ x); the inner dot is then equivalent to the second output, so the two could be merged. A hand-written sketch of this target graph is shown below.
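
Applying the associativity rewrite by hand illustrates the graph such a rewrite would produce (this is only an illustrative sketch, not an existing rewrite):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')
Ax = A @ x                    # shared intermediate used by two outputs
f = (x.T @ Ax), Ax, A.T @ x   # associativity applied by hand

fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()  # only 3 dots: x.T @ Ax, Ax, and A.T @ x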

Alternatively, we could use the transpose rule to rewrite the third output A.T @ x as (x.T @ A).T; the dot x.T @ A is the innermost dot in the first output, so they could be merged (with an extra transpose, which is just a cheap view anyway). A hand-written sketch of this variant follows.
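
The transpose-rule variant written out by hand (again just a sketch of the target graph):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')
xTA = x.T @ A                  # shared intermediate
f = (xTA @ x), A @ x, xTA.T    # third output is only a transpose view of xTA

fn = pytensor.function([A, x], f, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()  # 3 dots plus a cheap Transpose, instead of 4 dots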

With associativity we may want to be careful, since the multiplication order can have a big impact on performance. If we know the static shapes we can optimize the order like einsum does (see also #961), but if we are already computing a dot anyway then reusing it can't possibly be worse. A rough cost comparison is sketched below.
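
For a rough sense of why order matters, compare the scalar-multiplication counts of the two associations of a chain A @ B @ v with assumed shapes A (n, n), B (n, k), v (k, 1) (illustrative numbers, not taken from this issue):

n, k = 1000, 1000
left = n * n * k + n * k   # (A @ B) @ v: big matmul first, then matrix-vector
right = n * k + n * n      # A @ (B @ v): two matrix-vector products
print(left, right)         # 1_001_000_000 vs 2_000_000 multiplications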

The example above can easily show up in the gradient graph of a quadratic form, as in the sketch below.
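
A minimal sketch of how the pattern arises from autodiff (assuming the usual pytensor.grad API); the gradient of x.T @ A @ x involves both A @ x and A.T @ x:

import pytensor
import pytensor.tensor as pt
from pytensor.compile.mode import get_default_mode

A = pt.dmatrix('A')
x = pt.col('x')
quad = (x.T @ A @ x).sum()   # scalar quadratic form
g = pytensor.grad(quad, x)   # mathematically (A + A.T) @ x

fn = pytensor.function([A, x], g, mode=get_default_mode().excluding("BlasOpt"))
fn.dprint()  # the gradient graph contains dots with both A and A.T applied to x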
