Skip to content

Commit 4a7b7a8

Browse files
Update JAX QR dispatch
1 parent 6b61f6c commit 4a7b7a8

File tree

4 files changed

+23
-17
lines changed

4 files changed

+23
-17
lines changed

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
KroneckerProduct,
1010
MatrixInverse,
1111
MatrixPinv,
12-
QRFull,
1312
SLogDet,
1413
)
1514

@@ -67,16 +66,6 @@ def matrix_inverse(x):
6766
return matrix_inverse
6867

6968

70-
@jax_funcify.register(QRFull)
71-
def jax_funcify_QRFull(op, **kwargs):
72-
mode = op.mode
73-
74-
def qr_full(x, mode=mode):
75-
return jnp.linalg.qr(x, mode=mode)
76-
77-
return qr_full
78-
79-
8069
@jax_funcify.register(MatrixPinv)
8170
def jax_funcify_Pinv(op, **kwargs):
8271
def pinv(x):

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.link.jax.dispatch.basic import jax_funcify
66
from pytensor.tensor.slinalg import (
77
LU,
8+
QR,
89
BlockDiagonal,
910
Cholesky,
1011
CholeskySolve,
@@ -168,3 +169,13 @@ def cho_solve(c, b):
168169
)
169170

170171
return cho_solve
172+
173+
174+
@jax_funcify.register(QR)
175+
def jax_funcify_QR(op, **kwargs):
176+
mode = op.mode
177+
178+
def qr(x, mode=mode):
179+
return jax.scipy.linalg.qr(x, mode=mode)
180+
181+
return qr

tests/link/jax/test_nlinalg.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,6 @@ def assert_fn(x, y):
2929
outs = pt_nlinalg.eigh(x)
3030
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
3131

32-
outs = pt_nlinalg.qr(x, mode="full")
33-
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
34-
35-
outs = pt_nlinalg.qr(x, mode="reduced")
36-
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
37-
3832
outs = pt_nlinalg.svd(x)
3933
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
4034

tests/link/jax/test_slinalg.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,18 @@ def test_jax_basic():
103103
],
104104
)
105105

106+
def assert_fn(x, y):
107+
np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
108+
109+
M = rng.normal(size=(3, 3))
110+
X = M.dot(M.T)
111+
112+
outs = pt_slinalg.qr(x, mode="full")
113+
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
114+
115+
outs = pt_slinalg.qr(x, mode="economic")
116+
compare_jax_and_py([x], outs, [X.astype(config.floatX)], assert_fn=assert_fn)
117+
106118

107119
def test_jax_solve():
108120
rng = np.random.default_rng(utt.fetch_seed())

0 commit comments

Comments
 (0)