@@ -724,25 +724,36 @@ def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bo
724
724
np .testing .assert_allclose (b_val_not_contig , b_val )
725
725
726
726
727
- @pytest .mark .parametrize ("stride" , [1 , 2 , - 1 , - 2 ], ids = lambda x : f"stride={ x } " )
728
- def test_banded_dot (stride ):
727
+ def test_banded_dot ():
729
728
rng = np .random .default_rng ()
730
729
730
+ A = pt .tensor ("A" , shape = (10 , 10 ), dtype = config .floatX )
731
731
A_val = _make_banded_A (rng .normal (size = (10 , 10 )), kl = 1 , ku = 1 ).astype (config .floatX )
732
732
733
- x_shape = (10 * abs (stride ),)
734
- x_val = rng .normal (size = x_shape ).astype (config .floatX )
735
- x_val = x_val [::stride ]
736
-
737
- A = pt .tensor ("A" , shape = A_val .shape , dtype = A_val .dtype )
738
- x = pt .tensor ("x" , shape = x_val .shape , dtype = x_val .dtype )
733
+ x = pt .tensor ("x" , shape = (10 ,), dtype = config .floatX )
734
+ x_val = rng .normal (size = (10 ,)).astype (config .floatX )
739
735
740
736
output = banded_dot (A , x , upper_diags = 1 , lower_diags = 1 )
741
737
742
- compare_numba_and_py (
738
+ fn , _ = compare_numba_and_py (
743
739
[A , x ],
744
740
output ,
745
741
test_inputs = [A_val , x_val ],
746
742
numba_mode = numba_inplace_mode ,
747
743
eval_obj_mode = False ,
748
744
)
745
+
746
+ for stride in [2 , - 1 , - 2 ]:
747
+ x_shape = (10 * abs (stride ),)
748
+ x_val = rng .normal (size = x_shape ).astype (config .floatX )
749
+ x_val = x_val [::stride ]
750
+
751
+ nb_output = fn (A_val , x_val )
752
+ expected = A_val @ x_val
753
+
754
+ np .testing .assert_allclose (
755
+ nb_output ,
756
+ expected ,
757
+ strict = True ,
758
+ err_msg = f"Test failed for stride = { stride } " ,
759
+ )
0 commit comments