Skip to content

Commit e2af638

Browse files
authored
Add MatrixZero and use it in the LSTM (#3210)
1 parent af518ea commit e2af638

File tree

3 files changed

+254
-113
lines changed

3 files changed

+254
-113
lines changed

experimental/segmenter/src/lstm_bies.rs

Lines changed: 48 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
use crate::grapheme::GraphemeClusterSegmenter;
66
use crate::lstm_error::Error;
7-
use crate::math_helper::{self, MatrixBorrowed, MatrixBorrowedMut, MatrixOwned};
8-
use crate::provider::{LstmDataV1Marker, RuleBreakDataV1};
7+
use crate::math_helper::{self, MatrixBorrowedMut, MatrixOwned, MatrixZero};
8+
use crate::provider::{LstmDataV1, LstmDataV1Marker, RuleBreakDataV1};
99
use alloc::string::String;
1010
use alloc::string::ToString;
1111
use alloc::vec::Vec;
@@ -14,16 +14,16 @@ use icu_provider::DataPayload;
1414
use zerovec::ule::AsULE;
1515

1616
pub struct Lstm<'l> {
17-
data: &'l DataPayload<LstmDataV1Marker>,
18-
mat1: MatrixOwned<2>,
19-
mat2: MatrixOwned<3>,
20-
mat3: MatrixOwned<3>,
21-
mat4: MatrixOwned<2>,
22-
mat5: MatrixOwned<3>,
23-
mat6: MatrixOwned<3>,
24-
mat7: MatrixOwned<2>,
25-
mat8: MatrixOwned<3>,
26-
mat9: MatrixOwned<1>,
17+
data: &'l LstmDataV1<'l>,
18+
mat1: MatrixZero<'l, 2>,
19+
mat2: MatrixZero<'l, 3>,
20+
mat3: MatrixZero<'l, 3>,
21+
mat4: MatrixZero<'l, 2>,
22+
mat5: MatrixZero<'l, 3>,
23+
mat6: MatrixZero<'l, 3>,
24+
mat7: MatrixZero<'l, 2>,
25+
mat8: MatrixZero<'l, 3>,
26+
mat9: MatrixZero<'l, 1>,
2727
grapheme: Option<&'l RuleBreakDataV1<'l>>,
2828
hunits: usize,
2929
}
@@ -48,35 +48,31 @@ impl<'l> Lstm<'l> {
4848
return Err(Error::Syntax);
4949
}
5050

51-
// Note: We are currently copying the ZeroVecs into allocated matrices.
52-
// The ICU4X style guide discourages this. We do it here because:
53-
// 1. The data need to be aligned in order to be vectorized.
54-
// 2. The LSTM is highly performance-sensitive.
55-
let mat1 = data.get().mat1.alloc_matrix::<2>()?;
56-
let mat2 = data.get().mat2.alloc_matrix::<3>()?;
57-
let mat3 = data.get().mat3.alloc_matrix::<3>()?;
58-
let mat4 = data.get().mat4.alloc_matrix::<2>()?;
59-
let mat5 = data.get().mat5.alloc_matrix::<3>()?;
60-
let mat6 = data.get().mat6.alloc_matrix::<3>()?;
61-
let mat7 = data.get().mat7.alloc_matrix::<2>()?;
62-
let mat8 = data.get().mat8.alloc_matrix::<3>()?;
63-
let mat9 = data.get().mat9.alloc_matrix::<1>()?;
64-
let embedd_dim = mat1.as_borrowed().dim().1;
65-
let hunits = mat3.as_borrowed().dim().0;
66-
if mat2.as_borrowed().dim() != (hunits, 4, embedd_dim)
67-
|| mat3.as_borrowed().dim() != (hunits, 4, hunits)
68-
|| mat4.as_borrowed().dim() != (hunits, 4)
69-
|| mat5.as_borrowed().dim() != (hunits, 4, embedd_dim)
70-
|| mat6.as_borrowed().dim() != (hunits, 4, hunits)
71-
|| mat7.as_borrowed().dim() != (hunits, 4)
72-
|| mat8.as_borrowed().dim() != (2, 4, hunits)
73-
|| mat9.as_borrowed().dim() != (4)
51+
let mat1 = data.get().mat1.as_matrix_zero::<2>()?;
52+
let mat2 = data.get().mat2.as_matrix_zero::<3>()?;
53+
let mat3 = data.get().mat3.as_matrix_zero::<3>()?;
54+
let mat4 = data.get().mat4.as_matrix_zero::<2>()?;
55+
let mat5 = data.get().mat5.as_matrix_zero::<3>()?;
56+
let mat6 = data.get().mat6.as_matrix_zero::<3>()?;
57+
let mat7 = data.get().mat7.as_matrix_zero::<2>()?;
58+
let mat8 = data.get().mat8.as_matrix_zero::<3>()?;
59+
let mat9 = data.get().mat9.as_matrix_zero::<1>()?;
60+
let embedd_dim = mat1.dim().1;
61+
let hunits = mat3.dim().0;
62+
if mat2.dim() != (hunits, 4, embedd_dim)
63+
|| mat3.dim() != (hunits, 4, hunits)
64+
|| mat4.dim() != (hunits, 4)
65+
|| mat5.dim() != (hunits, 4, embedd_dim)
66+
|| mat6.dim() != (hunits, 4, hunits)
67+
|| mat7.dim() != (hunits, 4)
68+
|| mat8.dim() != (2, 4, hunits)
69+
|| mat9.dim() != (4)
7470
{
7571
return Err(Error::DimensionMismatch);
7672
}
7773

7874
Ok(Self {
79-
data,
75+
data: data.get(),
8076
mat1,
8177
mat2,
8278
mat3,
@@ -98,7 +94,7 @@ impl<'l> Lstm<'l> {
9894
/// `get_model_name` returns the name of the LSTM model.
9995
#[allow(dead_code)]
10096
pub fn get_model_name(&self) -> &str {
101-
&self.data.get().model
97+
&self.data.model
10298
}
10399

104100
#[cfg(test)]
@@ -133,11 +129,11 @@ impl<'l> Lstm<'l> {
133129

134130
/// `_return_id` returns the id corresponding to a code point or a grapheme cluster based on the model dictionary.
135131
fn return_id(&self, g: &str) -> i16 {
136-
let id = self.data.get().dic.get(g);
132+
let id = self.data.dic.get(g);
137133
if let Some(id) = id {
138134
i16::from_unaligned(*id)
139135
} else {
140-
self.data.get().dic.len() as i16
136+
self.data.dic.len() as i16
141137
}
142138
}
143139

@@ -146,12 +142,12 @@ impl<'l> Lstm<'l> {
146142
#[must_use] // return value is GIGO path
147143
fn compute_hc<'a>(
148144
&self,
149-
x_t: MatrixBorrowed<'a, 1>,
145+
x_t: MatrixZero<'a, 1>,
150146
mut h_tm1: MatrixBorrowedMut<'a, 1>,
151147
mut c_tm1: MatrixBorrowedMut<'a, 1>,
152-
warr: MatrixBorrowed<'a, 3>,
153-
uarr: MatrixBorrowed<'a, 3>,
154-
barr: MatrixBorrowed<'a, 2>,
148+
warr: MatrixZero<'a, 3>,
149+
uarr: MatrixZero<'a, 3>,
150+
barr: MatrixZero<'a, 2>,
155151
hunits: usize,
156152
) -> Option<()> {
157153
#[cfg(debug_assertions)]
@@ -166,8 +162,8 @@ impl<'l> Lstm<'l> {
166162

167163
let mut s_t = barr.to_owned();
168164

169-
s_t.as_mut().add_dot_3d(x_t, warr);
170-
s_t.as_mut().add_dot_3d(h_tm1.as_borrowed(), uarr);
165+
s_t.as_mut().add_dot_3d_2(x_t, warr);
166+
s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), uarr);
171167

172168
for i in 0..hunits {
173169
let [s0, s1, s2, s3] = s_t
@@ -238,9 +234,9 @@ impl<'l> Lstm<'l> {
238234
x_t,
239235
all_h_fw.submatrix_mut(i)?,
240236
c_fw.as_mut(),
241-
self.mat2.as_borrowed(),
242-
self.mat3.as_borrowed(),
243-
self.mat4.as_borrowed(),
237+
self.mat2,
238+
self.mat3,
239+
self.mat4,
244240
hunits,
245241
)?;
246242
}
@@ -259,15 +255,15 @@ impl<'l> Lstm<'l> {
259255
x_t,
260256
all_h_bw.submatrix_mut(input_seq_len - i - 1)?,
261257
c_bw.as_mut(),
262-
self.mat5.as_borrowed(),
263-
self.mat6.as_borrowed(),
264-
self.mat7.as_borrowed(),
258+
self.mat5,
259+
self.mat6,
260+
self.mat7,
265261
self.hunits,
266262
)?;
267263
}
268264

269265
// Combining forward and backward LSTMs using the dense time-distributed layer
270-
let timeb = self.mat9.as_borrowed();
266+
let timeb = self.mat9;
271267
let mut bies = String::new();
272268
for i in 0..input_seq_len {
273269
let curr_fw = all_h_fw.submatrix::<1>(i)?;

0 commit comments

Comments
 (0)