@@ -54,28 +54,44 @@ where
5454 let min_dim = usize:: min ( n, m) ;
5555 assert ! ( rank <= min_dim) ;
5656
57- for _ in 0 ..10 {
58- // handle full-rank case
59- let out = if rank == min_dim {
60- random ( shape. clone ( ) )
61-
62- // handle partial-rank case
63- } else {
64- // multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
65- // produce `an m × n` array with rank `r`
66- // https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
67- let mut out = Array2 :: zeros ( shape. clone ( ) ) ;
68- let left: Array2 < A > = random ( [ out. nrows ( ) , rank] ) ;
69- let right: Array2 < A > = random ( [ rank, out. ncols ( ) ] ) ;
57+ let mut rng = thread_rng ( ) ;
58+
59+ // handle full-rank case
60+ if rank == min_dim {
61+ let mut out = random ( shape) ;
62+ for _ in 0 ..10 {
63+ // check rank
64+ if let Ok ( out_rank) = out. rank ( ) {
65+ if out_rank == rank {
66+ return out;
67+ }
68+ }
69+
70+ out. mapv_inplace ( |_| A :: rand ( & mut rng) ) ;
71+ }
72+
73+ // handle partial-rank case
74+ //
75+ // multiplying two full-rank arrays with dimensions `m × r` and `r × n` will
76+ // produce `an m × n` array with rank `r`
77+ // https://en.wikipedia.org/wiki/Rank_(linear_algebra)#Properties
78+ } else {
79+ let mut out = Array2 :: zeros ( shape) ;
80+ let mut left: Array2 < A > = random ( [ out. nrows ( ) , rank] ) ;
81+ let mut right: Array2 < A > = random ( [ rank, out. ncols ( ) ] ) ;
82+
83+ for _ in 0 ..10 {
7084 general_mat_mul ( A :: one ( ) , & left, & right, A :: zero ( ) , & mut out) ;
71- out
72- } ;
7385
74- // check rank
75- if let Ok ( out_rank) = out. rank ( ) {
76- if out_rank == rank {
77- return out;
86+ // check rank
87+ if let Ok ( out_rank) = out. rank ( ) {
88+ if out_rank == rank {
89+ return out;
90+ }
7891 }
92+
93+ left. mapv_inplace ( |_| A :: rand ( & mut rng) ) ;
94+ right. mapv_inplace ( |_| A :: rand ( & mut rng) ) ;
7995 }
8096 }
8197
0 commit comments