Commit 72ba0dc

Rename BandedDot to BandedGEMV and move to blas.py
1 parent 976422f commit 72ba0dc

9 files changed: +346 −174 lines changed

pytensor/link/numba/dispatch/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -14,6 +14,6 @@
 import pytensor.link.numba.dispatch.sparse
 import pytensor.link.numba.dispatch.subtensor
 import pytensor.link.numba.dispatch.tensor_basic
-
+import pytensor.link.numba.dispatch.blas

 # isort: on
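
This one-line change matters because registering the new dispatch is a side effect of importing blas.py: the @numba_funcify.register(BandedGEMV) decorator in that module runs at import time. A minimal stand-in sketch of the mechanism, using functools.singledispatch in place of PyTensor's dispatcher (all names here are illustrative, not PyTensor's actual internals):

from functools import singledispatch

class BandedGEMV:  # stand-in for the real Op
    pass

@singledispatch
def funcify(op, **kwargs):
    raise NotImplementedError(f"no numba implementation for {op}")

@funcify.register(BandedGEMV)  # runs when the defining module is imported
def _(op, **kwargs):
    return "compiled banded_gemv"

print(funcify(BandedGEMV()))  # -> "compiled banded_gemv"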

pytensor/link/numba/dispatch/blas.py

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+from pytensor.link.numba.dispatch import numba_funcify
+from pytensor.link.numba.dispatch.basic import numba_njit
+from pytensor.link.numba.dispatch.linalg.dot.banded import _gbmv
+from pytensor.link.numba.dispatch.slinalg import _COMPLEX_DTYPE_NOT_SUPPORTED_MSG
+from pytensor.tensor.blas import BandedGEMV
+from pytensor.tensor.type import complex_dtypes
+
+
+@numba_funcify.register(BandedGEMV)
+def numba_funcify_BandedGEMV(op, node, **kwargs):
+    kl = op.lower_diags
+    ku = op.upper_diags
+    overwrite_y = op.overwrite_y
+    trans = int(op.transpose)
+    dtype = node.inputs[0].dtype
+
+    if dtype in complex_dtypes:
+        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
+
+    @numba_njit(cache=False)
+    def banded_gemv(A, x, y, alpha, beta):
+        return _gbmv(
+            A=A,
+            x=x,
+            kl=kl,
+            ku=ku,
+            y=y,
+            alpha=alpha,
+            beta=beta,
+            overwrite_y=overwrite_y,
+            trans=trans,
+        )
+
+    return banded_gemv
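
For orientation, the compiled banded_gemv(A, x, y, alpha, beta) implements the standard BLAS GBMV update y <- alpha * op(A) @ x + beta * y, where op(A) is A or A.T depending on the Op's transpose flag. A dense NumPy reference of that contract (a sketch that ignores band storage, strides, and overwrite_y):

import numpy as np

def banded_gemv_reference(A, x, y, alpha, beta, transpose=False):
    # Dense equivalent of the GBMV update the dispatch compiles.
    op_A = A.T if transpose else A
    return alpha * (op_A @ x) + beta * y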

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 99 additions & 24 deletions

@@ -12,7 +12,11 @@
     _get_underlying_float,
     val_to_int_ptr,
 )
-from pytensor.link.numba.dispatch.linalg.utils import _check_scipy_linalg_matrix
+from pytensor.link.numba.dispatch.linalg.utils import (
+    _check_scipy_linalg_matrix,
+    _copy_to_fortran_order_even_if_1d,
+    _trans_char_to_int,
+)


 @numba_njit(inline="always")
@@ -32,69 +36,140 @@ def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
     return A_banded


-def _dot_banded(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> Any:
+def _gbmv(
+    alpha: np.ndarray,
+    A: np.ndarray,
+    x: np.ndarray,
+    kl: int,
+    ku: int,
+    beta: np.ndarray | None = None,
+    y: np.ndarray | None = None,
+    overwrite_y: bool = False,
+    trans: int = 1,
+) -> Any:
     """
     Thin wrapper around gbmv. This code will only be called if njit is disabled globally
     (e.g. during testing)
     """
-    fn = linalg.get_blas_funcs("gbmv", (A, x))
+    (fn,) = linalg.get_blas_funcs(("gbmv",), (A, x))
     m, n = A.shape
     A_banded = A_to_banded(A, kl=kl, ku=ku)

-    return fn(m=m, n=n, kl=kl, ku=ku, alpha=1, a=A_banded, x=x)
-
-
-@overload(_dot_banded)
-def dot_banded_impl(
-    A: np.ndarray, x: np.ndarray, kl: int, ku: int
-) -> Callable[[np.ndarray, np.ndarray, int, int], np.ndarray]:
+    incx = x.strides[0] // x.itemsize
+    incy = y.strides[0] // y.itemsize if y is not None else 1
+
+    offx = 0 if incx >= 0 else -x.size + 1
+    offy = 0 if incy >= 0 else -y.size + 1
+
+    return fn(
+        m=m,
+        n=n,
+        kl=kl,
+        ku=ku,
+        a=A_banded,
+        alpha=alpha,
+        x=x,
+        incx=incx,
+        offx=offx,
+        beta=beta,
+        y=y,
+        overwrite_y=overwrite_y,
+        incy=incy,
+        offy=offy,
+        trans=trans,
+    )
+
+
+@overload(_gbmv)
+def gbmv_impl(
+    alpha: np.ndarray,
+    A: np.ndarray,
+    x: np.ndarray,
+    kl: int,
+    ku: int,
+    beta: np.ndarray | None = None,
+    y: np.ndarray | None = None,
+    overwrite_y: bool = False,
+    trans: int = 1,
+) -> Callable[
+    [
+        np.ndarray,
+        np.ndarray,
+        np.ndarray,
+        int,
+        int,
+        np.ndarray | None,
+        np.ndarray | None,
+        bool,
+        int,
+    ],
+    np.ndarray,
+]:
     ensure_lapack()
     ensure_blas()
     _check_scipy_linalg_matrix(A, "dot_banded")
     dtype = A.dtype
     w_type = _get_underlying_float(dtype)
     numba_gbmv = _BLAS().numba_xgbmv(dtype)

-    def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
+    def impl(
+        alpha: np.ndarray,
+        A: np.ndarray,
+        x: np.ndarray,
+        kl: int,
+        ku: int,
+        beta: np.ndarray | None = None,
+        y: np.ndarray | None = None,
+        overwrite_y: bool = False,
+        trans: int = 1,
+    ) -> np.ndarray:
         m, n = A.shape

         A_banded = A_to_banded(A, kl=kl, ku=ku)
-        stride = x.strides[0] // x.itemsize
+        x_stride = x.strides[0] // x.itemsize
+
+        if beta is None:
+            beta = np.zeros((), dtype=dtype)

-        TRANS = val_to_int_ptr(ord("N"))
+        if y is None:
+            y_copy = np.empty(shape=(m,), dtype=dtype)
+        elif overwrite_y and y.flags.f_contiguous:
+            y_copy = y
+        else:
+            y_copy = _copy_to_fortran_order_even_if_1d(y)
+
+        y_stride = y_copy.strides[0] // y_copy.itemsize
+
+        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
         M = val_to_int_ptr(m)
         N = val_to_int_ptr(n)
         LDA = val_to_int_ptr(A_banded.shape[0])

         KL = val_to_int_ptr(kl)
         KU = val_to_int_ptr(ku)

-        ALPHA = np.array(1.0, dtype=dtype)
-
-        INCX = val_to_int_ptr(stride)
-        BETA = np.array(0.0, dtype=dtype)
-        Y = np.empty(m, dtype=dtype)
-        INCY = val_to_int_ptr(1)
+        INCX = val_to_int_ptr(x_stride)
+        INCY = val_to_int_ptr(y_stride)

         numba_gbmv(
             TRANS,
             M,
             N,
             KL,
             KU,
-            ALPHA.view(w_type).ctypes,
+            alpha.view(w_type).ctypes,
             A_banded.view(w_type).ctypes,
             LDA,
             # x.view().ctypes is creating a pointer to the beginning of the memory where the array is. When we have
             # a negative stride, we need to trick BLAS by pointing to the last element of the array.
             # The [-1:] slice is a workaround to make sure x remains an array (otherwise it has no .ctypes)
-            (x if stride >= 0 else x[-1:]).view(w_type).ctypes,
+            (x if x_stride >= 0 else x[-1:]).view(w_type).ctypes,
             INCX,
-            BETA.view(w_type).ctypes,
-            Y.view(w_type).ctypes,
+            beta.view(w_type).ctypes,
+            y_copy.view(w_type).ctypes,
             INCY,
         )

-        return Y
+        return y_copy

     return impl
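
The heart of this file is the band storage layout that A_to_banded produces and gbmv consumes: element A[i, j] of the dense matrix is stored at row ku + i - j, column j of a (kl + ku + 1, n) array. A self-contained SciPy sketch that rebuilds that layout and checks it against the dense product (to_banded_storage is an illustrative re-implementation, not the module's A_to_banded):

import numpy as np
from scipy import linalg

def to_banded_storage(A, kl, ku):
    # LAPACK band storage: A[i, j] lands at [ku + i - j, j].
    m, n = A.shape
    A_banded = np.zeros((kl + ku + 1, n), dtype=A.dtype)
    for j in range(n):
        for i in range(max(0, j - ku), min(m, j + kl + 1)):
            A_banded[ku + i - j, j] = A[i, j]
    return A_banded

rng = np.random.default_rng(0)
kl, ku = 2, 1
A = np.triu(np.tril(rng.normal(size=(5, 5)), ku), -kl)  # matrix with kl=2, ku=1
x = rng.normal(size=5)

(gbmv,) = linalg.get_blas_funcs(("gbmv",), (A, x))
y = gbmv(m=5, n=5, kl=kl, ku=ku, alpha=1.0, a=to_banded_storage(A, kl, ku), x=x)
np.testing.assert_allclose(y, A @ x)  # band-storage GBMV matches the dense product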

pytensor/link/numba/dispatch/slinalg.py

Lines changed: 0 additions & 18 deletions

@@ -11,7 +11,6 @@
     _pivot_to_permutation,
 )
 from pytensor.link.numba.dispatch.linalg.decomposition.lu_factor import _lu_factor
-from pytensor.link.numba.dispatch.linalg.dot.banded import _dot_banded
 from pytensor.link.numba.dispatch.linalg.solve.cholesky import _cho_solve
 from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
 from pytensor.link.numba.dispatch.linalg.solve.posdef import _solve_psd
@@ -20,7 +19,6 @@
 from pytensor.link.numba.dispatch.linalg.solve.tridiagonal import _solve_tridiagonal
 from pytensor.tensor.slinalg import (
     LU,
-    BandedDot,
     BlockDiagonal,
     Cholesky,
     CholeskySolve,
@@ -313,19 +311,3 @@ def cho_solve(c, b):
     )

     return cho_solve
-
-
-@numba_funcify.register(BandedDot)
-def numba_funcify_BandedDot(op, node, **kwargs):
-    kl = op.lower_diags
-    ku = op.upper_diags
-    dtype = node.inputs[0].dtype
-
-    if dtype in complex_dtypes:
-        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
-
-    @numba_njit(cache=False)
-    def banded_dot(A, x):
-        return _dot_banded(A, x, kl=kl, ku=ku)
-
-    return banded_dot
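
Both the removed BandedDot dispatch and its BandedGEMV replacement lean on the same numba idiom: a plain Python function (_dot_banded before, _gbmv now) that only executes when the JIT is disabled, paired with an @overload that supplies the compiled implementation. A minimal sketch of that idiom (the names _scale and scale_impl are invented for illustration):

import numpy as np
from numba import njit
from numba.extending import overload

def _scale(x, a):
    # Pure-Python fallback, hit only when the JIT is disabled
    # (e.g. NUMBA_DISABLE_JIT=1 during testing).
    return x * a

@overload(_scale)
def scale_impl(x, a):
    # Compiled implementation that numba substitutes inside @njit code.
    def impl(x, a):
        return x * a
    return impl

@njit
def double(x):
    return _scale(x, 2.0)

print(double(np.arange(3.0)))  # [0. 2. 4.]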
