Skip to content

Commit a6d6c11

Browse files
Update numba QR dispatch
1 parent 973feb8 commit a6d6c11

File tree

6 files changed

+1139
-92
lines changed

6 files changed

+1139
-92
lines changed

pytensor/link/numba/dispatch/linalg/_LAPACK.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,6 @@ def numba_xgetrs(cls, dtype):
283283
284284
Called by scipy.linalg.lu_solve
285285
"""
286-
...
287286
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
288287
functype = ctypes.CFUNCTYPE(
289288
None,
@@ -457,3 +456,90 @@ def numba_xgtcon(cls, dtype):
457456
_ptr_int, # INFO
458457
)
459458
return functype(lapack_ptr)
459+
460+
@classmethod
461+
def numba_xgeqrf(cls, dtype):
462+
"""
463+
Compute the QR factorization of a general M-by-N matrix A.
464+
465+
Used in QR decomposition (no pivoting).
466+
"""
467+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
468+
functype = ctypes.CFUNCTYPE(
469+
None,
470+
_ptr_int, # M
471+
_ptr_int, # N
472+
float_pointer, # A
473+
_ptr_int, # LDA
474+
float_pointer, # TAU
475+
float_pointer, # WORK
476+
_ptr_int, # LWORK
477+
_ptr_int, # INFO
478+
)
479+
return functype(lapack_ptr)
480+
481+
@classmethod
482+
def numba_xgeqp3(cls, dtype):
483+
"""
484+
Compute the QR factorization with column pivoting of a general M-by-N matrix A.
485+
486+
Used in QR decomposition with pivoting.
487+
"""
488+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
489+
functype = ctypes.CFUNCTYPE(
490+
None,
491+
_ptr_int, # M
492+
_ptr_int, # N
493+
float_pointer, # A
494+
_ptr_int, # LDA
495+
_ptr_int, # JPVT
496+
float_pointer, # TAU
497+
float_pointer, # WORK
498+
_ptr_int, # LWORK
499+
_ptr_int, # INFO
500+
)
501+
return functype(lapack_ptr)
502+
503+
@classmethod
504+
def numba_xorgqr(cls, dtype):
505+
"""
506+
Generate the orthogonal matrix Q from a QR factorization (real types).
507+
508+
Used in QR decomposition to form Q.
509+
"""
510+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
511+
functype = ctypes.CFUNCTYPE(
512+
None,
513+
_ptr_int, # M
514+
_ptr_int, # N
515+
_ptr_int, # K
516+
float_pointer, # A
517+
_ptr_int, # LDA
518+
float_pointer, # TAU
519+
float_pointer, # WORK
520+
_ptr_int, # LWORK
521+
_ptr_int, # INFO
522+
)
523+
return functype(lapack_ptr)
524+
525+
@classmethod
526+
def numba_xungqr(cls, dtype):
527+
"""
528+
Generate the unitary matrix Q from a QR factorization (complex types).
529+
530+
Used in QR decomposition to form Q for complex types.
531+
"""
532+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
533+
functype = ctypes.CFUNCTYPE(
534+
None,
535+
_ptr_int, # M
536+
_ptr_int, # N
537+
_ptr_int, # K
538+
float_pointer, # A
539+
_ptr_int, # LDA
540+
float_pointer, # TAU
541+
float_pointer, # WORK
542+
_ptr_int, # LWORK
543+
_ptr_int, # INFO
544+
)
545+
return functype(lapack_ptr)

0 commit comments

Comments
 (0)