Skip to content

Commit 976422f

Browse files
Fix negative strides
1 parent f467322 commit 976422f

File tree

1 file changed

+4
-1
lines changed
  • pytensor/link/numba/dispatch/linalg/dot

1 file changed

+4
-1
lines changed

pytensor/link/numba/dispatch/linalg/dot/banded.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ def impl(A: np.ndarray, x: np.ndarray, kl: int, ku: int) -> np.ndarray:
8585
ALPHA.view(w_type).ctypes,
8686
A_banded.view(w_type).ctypes,
8787
LDA,
88-
x.view(w_type).ctypes,
88+
# x.view().ctypes is creating a pointer to the beginning of the memory where the array is. When we have
89+
# a negative stride, we need to trick BLAS by pointing to the last element of the array.
90+
# The [-1:] slice is a workaround to make sure x remains an array (otherwise it has no .ctypes)
91+
(x if stride >= 0 else x[-1:]).view(w_type).ctypes,
8992
INCX,
9093
BETA.view(w_type).ctypes,
9194
Y.view(w_type).ctypes,

0 commit comments

Comments
 (0)