4
4
5
5
use crate :: grapheme:: GraphemeClusterSegmenter ;
6
6
use crate :: lstm_error:: Error ;
7
- use crate :: math_helper:: { self , MatrixBorrowed , MatrixBorrowedMut , MatrixOwned } ;
8
- use crate :: provider:: { LstmDataV1Marker , RuleBreakDataV1 } ;
7
+ use crate :: math_helper:: { self , MatrixBorrowedMut , MatrixOwned , MatrixZero } ;
8
+ use crate :: provider:: { LstmDataV1 , LstmDataV1Marker , RuleBreakDataV1 } ;
9
9
use alloc:: string:: String ;
10
10
use alloc:: string:: ToString ;
11
11
use alloc:: vec:: Vec ;
@@ -14,16 +14,16 @@ use icu_provider::DataPayload;
14
14
use zerovec:: ule:: AsULE ;
15
15
16
16
/// State used to run the LSTM model; every reference it holds shares one lifetime parameter.
///
/// Fields: `data` (the model data), `mat1` through `mat9` (matrices of rank
/// 2, 3, 3, 2, 3, 3, 2, 3 and 1 respectively), `grapheme` (an optional reference to
/// rule-break data) and `hunits` (a `usize`).
// NOTE(review): this span is a rendered diff, not plain source. Bare numbers are line-number
// gutter residue. `-` lines show the previous field types (a `DataPayload` reference and
// `MatrixOwned` matrices); `+` lines show their replacements (an `LstmDataV1` reference and
// `MatrixZero` matrices that carry the struct lifetime).
pub struct Lstm < ' l > {
17
- data : & ' l DataPayload < LstmDataV1Marker > ,
18
- mat1 : MatrixOwned < 2 > ,
19
- mat2 : MatrixOwned < 3 > ,
20
- mat3 : MatrixOwned < 3 > ,
21
- mat4 : MatrixOwned < 2 > ,
22
- mat5 : MatrixOwned < 3 > ,
23
- mat6 : MatrixOwned < 3 > ,
24
- mat7 : MatrixOwned < 2 > ,
25
- mat8 : MatrixOwned < 3 > ,
26
- mat9 : MatrixOwned < 1 > ,
17
+ data : & ' l LstmDataV1 < ' l > ,
18
+ mat1 : MatrixZero < ' l , 2 > ,
19
+ mat2 : MatrixZero < ' l , 3 > ,
20
+ mat3 : MatrixZero < ' l , 3 > ,
21
+ mat4 : MatrixZero < ' l , 2 > ,
22
+ mat5 : MatrixZero < ' l , 3 > ,
23
+ mat6 : MatrixZero < ' l , 3 > ,
24
+ mat7 : MatrixZero < ' l , 2 > ,
25
+ mat8 : MatrixZero < ' l , 3 > ,
26
+ mat9 : MatrixZero < ' l , 1 > ,
27
27
grapheme : Option < & ' l RuleBreakDataV1 < ' l > > ,
28
28
hunits : usize ,
29
29
}
@@ -48,35 +48,31 @@ impl<'l> Lstm<'l> {
48
48
return Err ( Error :: Syntax ) ;
49
49
}
50
50
51
- // Note: We are currently copying the ZeroVecs into allocated matrices.
52
- // The ICU4X style guide discourages this. We do it here because:
53
- // 1. The data need to be aligned in order to be vectorized.
54
- // 2. The LSTM is highly performance-sensitive.
55
- let mat1 = data. get ( ) . mat1 . alloc_matrix :: < 2 > ( ) ?;
56
- let mat2 = data. get ( ) . mat2 . alloc_matrix :: < 3 > ( ) ?;
57
- let mat3 = data. get ( ) . mat3 . alloc_matrix :: < 3 > ( ) ?;
58
- let mat4 = data. get ( ) . mat4 . alloc_matrix :: < 2 > ( ) ?;
59
- let mat5 = data. get ( ) . mat5 . alloc_matrix :: < 3 > ( ) ?;
60
- let mat6 = data. get ( ) . mat6 . alloc_matrix :: < 3 > ( ) ?;
61
- let mat7 = data. get ( ) . mat7 . alloc_matrix :: < 2 > ( ) ?;
62
- let mat8 = data. get ( ) . mat8 . alloc_matrix :: < 3 > ( ) ?;
63
- let mat9 = data. get ( ) . mat9 . alloc_matrix :: < 1 > ( ) ?;
64
- let embedd_dim = mat1. as_borrowed ( ) . dim ( ) . 1 ;
65
- let hunits = mat3. as_borrowed ( ) . dim ( ) . 0 ;
66
- if mat2. as_borrowed ( ) . dim ( ) != ( hunits, 4 , embedd_dim)
67
- || mat3. as_borrowed ( ) . dim ( ) != ( hunits, 4 , hunits)
68
- || mat4. as_borrowed ( ) . dim ( ) != ( hunits, 4 )
69
- || mat5. as_borrowed ( ) . dim ( ) != ( hunits, 4 , embedd_dim)
70
- || mat6. as_borrowed ( ) . dim ( ) != ( hunits, 4 , hunits)
71
- || mat7. as_borrowed ( ) . dim ( ) != ( hunits, 4 )
72
- || mat8. as_borrowed ( ) . dim ( ) != ( 2 , 4 , hunits)
73
- || mat9. as_borrowed ( ) . dim ( ) != ( 4 )
51
+ let mat1 = data. get ( ) . mat1 . as_matrix_zero :: < 2 > ( ) ?;
52
+ let mat2 = data. get ( ) . mat2 . as_matrix_zero :: < 3 > ( ) ?;
53
+ let mat3 = data. get ( ) . mat3 . as_matrix_zero :: < 3 > ( ) ?;
54
+ let mat4 = data. get ( ) . mat4 . as_matrix_zero :: < 2 > ( ) ?;
55
+ let mat5 = data. get ( ) . mat5 . as_matrix_zero :: < 3 > ( ) ?;
56
+ let mat6 = data. get ( ) . mat6 . as_matrix_zero :: < 3 > ( ) ?;
57
+ let mat7 = data. get ( ) . mat7 . as_matrix_zero :: < 2 > ( ) ?;
58
+ let mat8 = data. get ( ) . mat8 . as_matrix_zero :: < 3 > ( ) ?;
59
+ let mat9 = data. get ( ) . mat9 . as_matrix_zero :: < 1 > ( ) ?;
60
+ let embedd_dim = mat1. dim ( ) . 1 ;
61
+ let hunits = mat3. dim ( ) . 0 ;
62
+ if mat2. dim ( ) != ( hunits, 4 , embedd_dim)
63
+ || mat3. dim ( ) != ( hunits, 4 , hunits)
64
+ || mat4. dim ( ) != ( hunits, 4 )
65
+ || mat5. dim ( ) != ( hunits, 4 , embedd_dim)
66
+ || mat6. dim ( ) != ( hunits, 4 , hunits)
67
+ || mat7. dim ( ) != ( hunits, 4 )
68
+ || mat8. dim ( ) != ( 2 , 4 , hunits)
69
+ || mat9. dim ( ) != ( 4 )
74
70
{
75
71
return Err ( Error :: DimensionMismatch ) ;
76
72
}
77
73
78
74
Ok ( Self {
79
- data,
75
+ data : data . get ( ) ,
80
76
mat1,
81
77
mat2,
82
78
mat3,
@@ -98,7 +94,7 @@ impl<'l> Lstm<'l> {
98
94
/// `get_model_name` returns the name of the LSTM model.
///
/// The returned `&str` is borrowed from the `model` field of the model data held in `self.data`.
// NOTE(review): this span is a rendered diff. The `-` line reads `model` through `.get()` on
// the stored payload; the `+` line reads the field directly from the stored data reference.
// Bare numbers are line-number gutter residue.
99
95
#[ allow( dead_code) ]
100
96
pub fn get_model_name ( & self ) -> & str {
101
- & self . data . get ( ) . model
97
+ & self . data . model
102
98
}
103
99
104
100
#[ cfg( test) ]
@@ -133,11 +129,11 @@ impl<'l> Lstm<'l> {
133
129
134
130
/// `return_id` returns the id corresponding to a code point or a grapheme cluster based on the model dictionary.
///
/// Looks `g` up in the `dic` field of the model data. On a hit, the stored value is converted
/// with `i16::from_unaligned`. On a miss, the dictionary length cast to `i16` is returned.
// NOTE(review): `len() as i16` is an unchecked narrowing cast; presumably the dictionary is
// always small enough to fit in an `i16` -- confirm against the data provider.
// NOTE(review): this span is a rendered diff. `-` lines go through `.get()` on the stored
// payload; `+` lines access `dic` directly on the stored data reference. Bare numbers are
// line-number gutter residue.
135
131
fn return_id ( & self , g : & str ) -> i16 {
136
- let id = self . data . get ( ) . dic . get ( g) ;
132
+ let id = self . data . dic . get ( g) ;
137
133
if let Some ( id) = id {
138
134
i16:: from_unaligned ( * id)
139
135
} else {
140
- self . data . get ( ) . dic . len ( ) as i16
136
+ self . data . dic . len ( ) as i16
141
137
}
142
138
}
143
139
@@ -146,12 +142,12 @@ impl<'l> Lstm<'l> {
146
142
#[ must_use] // return value is GIGO path
147
143
fn compute_hc < ' a > (
148
144
& self ,
149
- x_t : MatrixBorrowed < ' a , 1 > ,
145
+ x_t : MatrixZero < ' a , 1 > ,
150
146
mut h_tm1 : MatrixBorrowedMut < ' a , 1 > ,
151
147
mut c_tm1 : MatrixBorrowedMut < ' a , 1 > ,
152
- warr : MatrixBorrowed < ' a , 3 > ,
153
- uarr : MatrixBorrowed < ' a , 3 > ,
154
- barr : MatrixBorrowed < ' a , 2 > ,
148
+ warr : MatrixZero < ' a , 3 > ,
149
+ uarr : MatrixZero < ' a , 3 > ,
150
+ barr : MatrixZero < ' a , 2 > ,
155
151
hunits : usize ,
156
152
) -> Option < ( ) > {
157
153
#[ cfg( debug_assertions) ]
@@ -166,8 +162,8 @@ impl<'l> Lstm<'l> {
166
162
167
163
let mut s_t = barr. to_owned ( ) ;
168
164
169
- s_t. as_mut ( ) . add_dot_3d ( x_t, warr) ;
170
- s_t. as_mut ( ) . add_dot_3d ( h_tm1. as_borrowed ( ) , uarr) ;
165
+ s_t. as_mut ( ) . add_dot_3d_2 ( x_t, warr) ;
166
+ s_t. as_mut ( ) . add_dot_3d_1 ( h_tm1. as_borrowed ( ) , uarr) ;
171
167
172
168
for i in 0 ..hunits {
173
169
let [ s0, s1, s2, s3] = s_t
@@ -238,9 +234,9 @@ impl<'l> Lstm<'l> {
238
234
x_t,
239
235
all_h_fw. submatrix_mut ( i) ?,
240
236
c_fw. as_mut ( ) ,
241
- self . mat2 . as_borrowed ( ) ,
242
- self . mat3 . as_borrowed ( ) ,
243
- self . mat4 . as_borrowed ( ) ,
237
+ self . mat2 ,
238
+ self . mat3 ,
239
+ self . mat4 ,
244
240
hunits,
245
241
) ?;
246
242
}
@@ -259,15 +255,15 @@ impl<'l> Lstm<'l> {
259
255
x_t,
260
256
all_h_bw. submatrix_mut ( input_seq_len - i - 1 ) ?,
261
257
c_bw. as_mut ( ) ,
262
- self . mat5 . as_borrowed ( ) ,
263
- self . mat6 . as_borrowed ( ) ,
264
- self . mat7 . as_borrowed ( ) ,
258
+ self . mat5 ,
259
+ self . mat6 ,
260
+ self . mat7 ,
265
261
self . hunits ,
266
262
) ?;
267
263
}
268
264
269
265
// Combining forward and backward LSTMs using the dense time-distributed layer
270
- let timeb = self . mat9 . as_borrowed ( ) ;
266
+ let timeb = self . mat9 ;
271
267
let mut bies = String :: new ( ) ;
272
268
for i in 0 ..input_seq_len {
273
269
let curr_fw = all_h_fw. submatrix :: < 1 > ( i) ?;
0 commit comments