12
12
_get_underlying_float ,
13
13
val_to_int_ptr ,
14
14
)
15
- from pytensor .link .numba .dispatch .linalg .utils import _check_scipy_linalg_matrix
15
+ from pytensor .link .numba .dispatch .linalg .utils import (
16
+ _check_scipy_linalg_matrix ,
17
+ _copy_to_fortran_order_even_if_1d ,
18
+ _trans_char_to_int ,
19
+ )
16
20
17
21
18
22
@numba_njit (inline = "always" )
@@ -32,69 +36,140 @@ def A_to_banded(A: np.ndarray, kl: int, ku: int) -> np.ndarray:
32
36
return A_banded
33
37
34
38
35
def _gbmv(
    alpha: np.ndarray,
    A: np.ndarray,
    x: np.ndarray,
    kl: int,
    ku: int,
    beta: np.ndarray | None = None,
    y: np.ndarray | None = None,
    overwrite_y: bool = False,
    trans: int = 1,
) -> Any:
    """
    Pure-Python fallback that forwards to SciPy's BLAS ``gbmv`` routine.

    This path is only taken when njit is disabled globally (e.g. during
    testing); otherwise the numba overload registered for this function is
    used instead.

    Parameters
    ----------
    alpha
        Scalar multiplier, forwarded unchanged to ``gbmv``.
    A
        Full 2D matrix; it is converted to banded storage with ``A_to_banded``.
    x
        Input vector. Its element stride is forwarded as ``incx``.
    kl, ku
        Number of sub- and super-diagonals used for the banded conversion.
    beta, y
        Optional scalar and vector forwarded unchanged to ``gbmv``.
    overwrite_y
        Forwarded unchanged to ``gbmv``.
    trans
        Transpose flag forwarded unchanged to ``gbmv``.

    Returns
    -------
    Whatever the SciPy ``gbmv`` function returns.
    """
    (blas_gbmv,) = linalg.get_blas_funcs(("gbmv",), (A, x))
    n_rows, n_cols = A.shape
    banded = A_to_banded(A, kl=kl, ku=ku)

    # Strides are expressed in elements (not bytes), as BLAS expects.
    incx = x.strides[0] // x.itemsize
    if y is None:
        incy = 1
    else:
        incy = y.strides[0] // y.itemsize

    # For negative increments the offset is moved to the other end of the vector.
    # NOTE(review): this yields a non-positive offset; confirm the SciPy wrapper
    # accepts negative ``offx``/``offy`` values.
    if incx >= 0:
        offx = 0
    else:
        offx = -x.size + 1

    # ``incy`` can only be negative when ``y`` was provided, so ``y.size`` is safe here.
    if incy >= 0:
        offy = 0
    else:
        offy = -y.size + 1

    return blas_gbmv(
        m=n_rows,
        n=n_cols,
        kl=kl,
        ku=ku,
        a=banded,
        alpha=alpha,
        x=x,
        incx=incx,
        offx=offx,
        beta=beta,
        y=y,
        overwrite_y=overwrite_y,
        incy=incy,
        offy=offy,
        trans=trans,
    )
81
+
82
+
83
@overload(_gbmv)
def gbmv_impl(
    alpha: np.ndarray,
    A: np.ndarray,
    x: np.ndarray,
    kl: int,
    ku: int,
    beta: np.ndarray | None = None,
    y: np.ndarray | None = None,
    overwrite_y: bool = False,
    trans: int = 1,
) -> Callable[
    [
        np.ndarray,
        np.ndarray,
        np.ndarray,
        int,
        int,
        np.ndarray | None,
        np.ndarray | None,
        bool,
        int,
    ],
    np.ndarray,
]:
    """
    Numba overload of ``_gbmv`` that calls the BLAS ``xgbmv`` routine directly.

    The routine is looked up once, at overload time, from the dtype of ``A``
    via ``_BLAS().numba_xgbmv``. The returned ``impl`` converts ``A`` to banded
    storage, prepares the output buffer and integer pointers, and invokes the
    routine.

    Returns
    -------
    The implementation function with the same signature as ``_gbmv``; it
    returns the vector that the BLAS call wrote its result into.
    """
    ensure_lapack()
    ensure_blas()
    _check_scipy_linalg_matrix(A, "dot_banded")
    mat_dtype = A.dtype
    float_type = _get_underlying_float(mat_dtype)
    xgbmv = _BLAS().numba_xgbmv(mat_dtype)

    def impl(
        alpha: np.ndarray,
        A: np.ndarray,
        x: np.ndarray,
        kl: int,
        ku: int,
        beta: np.ndarray | None = None,
        y: np.ndarray | None = None,
        overwrite_y: bool = False,
        trans: int = 1,
    ) -> np.ndarray:
        n_rows, n_cols = A.shape

        banded = A_to_banded(A, kl=kl, ku=ku)
        # Element (not byte) stride of x, as required by BLAS.
        incx = x.strides[0] // x.itemsize

        # Without a beta, use a 0-d zero of the matrix dtype.
        if beta is None:
            beta = np.zeros((), dtype=mat_dtype)

        # Choose the buffer BLAS writes into: a fresh one when no y is given,
        # y itself when the caller allows overwriting and it is F-contiguous,
        # otherwise a Fortran-ordered copy of y.
        # NOTE(review): the fresh buffer always has length m (rows of A),
        # independent of ``trans`` -- confirm this is intended for non-square A.
        if y is None:
            out = np.empty(shape=(n_rows,), dtype=mat_dtype)
        elif overwrite_y and y.flags.f_contiguous:
            out = y
        else:
            out = _copy_to_fortran_order_even_if_1d(y)

        incy = out.strides[0] // out.itemsize

        trans_ptr = val_to_int_ptr(_trans_char_to_int(trans))
        m_ptr = val_to_int_ptr(n_rows)
        n_ptr = val_to_int_ptr(n_cols)
        lda_ptr = val_to_int_ptr(banded.shape[0])

        kl_ptr = val_to_int_ptr(kl)
        ku_ptr = val_to_int_ptr(ku)

        incx_ptr = val_to_int_ptr(incx)
        incy_ptr = val_to_int_ptr(incy)

        xgbmv(
            trans_ptr,
            m_ptr,
            n_ptr,
            kl_ptr,
            ku_ptr,
            alpha.view(float_type).ctypes,
            banded.view(float_type).ctypes,
            lda_ptr,
            # x.view().ctypes points at the start of the array's memory. With a negative stride BLAS
            # must instead be handed a pointer to the last element of the array.
            # Slicing with [-1:] (rather than indexing) keeps x an array, so it still exposes .ctypes.
            (x if incx >= 0 else x[-1:]).view(float_type).ctypes,
            incx_ptr,
            beta.view(float_type).ctypes,
            out.view(float_type).ctypes,
            incy_ptr,
        )

        return out

    return impl
0 commit comments