Skip to content

Commit 4bd259c

Browse files
Adjust numba test
1 parent 30fece4 commit 4bd259c

File tree

1 file changed

+20
-9
lines changed

1 file changed

+20
-9
lines changed

tests/link/numba/test_slinalg.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -724,25 +724,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
724724
np.testing.assert_allclose(b_val_not_contig, b_val)
725725

726726

727-
@pytest.mark.parametrize("stride", [1, 2, -1, -2], ids=lambda x: f"stride={x}")
728-
def test_banded_dot(stride):
727+
def test_banded_dot():
729728
rng = np.random.default_rng()
730729

730+
A = pt.tensor("A", shape=(10, 10), dtype=config.floatX)
731731
A_val = _make_banded_A(rng.normal(size=(10, 10)), kl=1, ku=1).astype(config.floatX)
732732

733-
x_shape = (10 * abs(stride),)
734-
x_val = rng.normal(size=x_shape).astype(config.floatX)
735-
x_val = x_val[::stride]
736-
737-
A = pt.tensor("A", shape=A_val.shape, dtype=A_val.dtype)
738-
x = pt.tensor("x", shape=x_val.shape, dtype=x_val.dtype)
733+
x = pt.tensor("x", shape=(10,), dtype=config.floatX)
734+
x_val = rng.normal(size=(10,)).astype(config.floatX)
739735

740736
output = banded_dot(A, x, upper_diags=1, lower_diags=1)
741737

742-
compare_numba_and_py(
738+
fn, _ = compare_numba_and_py(
743739
[A, x],
744740
output,
745741
test_inputs=[A_val, x_val],
746742
numba_mode=numba_inplace_mode,
747743
eval_obj_mode=False,
748744
)
745+
746+
for stride in [2, -1, -2]:
747+
x_shape = (10 * abs(stride),)
748+
x_val = rng.normal(size=x_shape).astype(config.floatX)
749+
x_val = x_val[::stride]
750+
751+
nb_output = fn(A_val, x_val)
752+
expected = A_val @ x_val
753+
754+
np.testing.assert_allclose(
755+
nb_output,
756+
expected,
757+
strict=True,
758+
err_msg=f"Test failed for stride = {stride}",
759+
)

0 commit comments

Comments
 (0)