Skip to content

Commit 265bd4a

Browse files
committed
move to 2d matrix repr
1 parent e2af638 commit 265bd4a

File tree

8 files changed

+13227
-13260
lines changed

8 files changed

+13227
-13260
lines changed

experimental/segmenter/src/lstm_bies.rs

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
use crate::grapheme::GraphemeClusterSegmenter;
66
use crate::lstm_error::Error;
7-
use crate::math_helper::{self, MatrixBorrowedMut, MatrixOwned, MatrixZero};
7+
use crate::math_helper::{MatrixBorrowedMut, MatrixOwned, MatrixZero};
88
use crate::provider::{LstmDataV1, LstmDataV1Marker, RuleBreakDataV1};
99
use alloc::string::String;
1010
use alloc::string::ToString;
@@ -16,12 +16,12 @@ use zerovec::ule::AsULE;
1616
pub struct Lstm<'l> {
1717
data: &'l LstmDataV1<'l>,
1818
mat1: MatrixZero<'l, 2>,
19-
mat2: MatrixZero<'l, 3>,
20-
mat3: MatrixZero<'l, 3>,
21-
mat4: MatrixZero<'l, 2>,
22-
mat5: MatrixZero<'l, 3>,
23-
mat6: MatrixZero<'l, 3>,
24-
mat7: MatrixZero<'l, 2>,
19+
mat2: MatrixZero<'l, 2>,
20+
mat3: MatrixZero<'l, 2>,
21+
mat4: MatrixZero<'l, 1>,
22+
mat5: MatrixZero<'l, 2>,
23+
mat6: MatrixZero<'l, 2>,
24+
mat7: MatrixZero<'l, 1>,
2525
mat8: MatrixZero<'l, 3>,
2626
mat9: MatrixZero<'l, 1>,
2727
grapheme: Option<&'l RuleBreakDataV1<'l>>,
@@ -49,22 +49,22 @@ impl<'l> Lstm<'l> {
4949
}
5050

5151
let mat1 = data.get().mat1.as_matrix_zero::<2>()?;
52-
let mat2 = data.get().mat2.as_matrix_zero::<3>()?;
53-
let mat3 = data.get().mat3.as_matrix_zero::<3>()?;
54-
let mat4 = data.get().mat4.as_matrix_zero::<2>()?;
55-
let mat5 = data.get().mat5.as_matrix_zero::<3>()?;
56-
let mat6 = data.get().mat6.as_matrix_zero::<3>()?;
57-
let mat7 = data.get().mat7.as_matrix_zero::<2>()?;
52+
let mat2 = data.get().mat2.as_matrix_zero::<2>()?;
53+
let mat3 = data.get().mat3.as_matrix_zero::<2>()?;
54+
let mat4 = data.get().mat4.as_matrix_zero::<1>()?;
55+
let mat5 = data.get().mat5.as_matrix_zero::<2>()?;
56+
let mat6 = data.get().mat6.as_matrix_zero::<2>()?;
57+
let mat7 = data.get().mat7.as_matrix_zero::<1>()?;
5858
let mat8 = data.get().mat8.as_matrix_zero::<3>()?;
5959
let mat9 = data.get().mat9.as_matrix_zero::<1>()?;
6060
let embedd_dim = mat1.dim().1;
61-
let hunits = mat3.dim().0;
62-
if mat2.dim() != (hunits, 4, embedd_dim)
63-
|| mat3.dim() != (hunits, 4, hunits)
64-
|| mat4.dim() != (hunits, 4)
65-
|| mat5.dim() != (hunits, 4, embedd_dim)
66-
|| mat6.dim() != (hunits, 4, hunits)
67-
|| mat7.dim() != (hunits, 4)
61+
let hunits = mat3.dim().1;
62+
if mat2.dim() != (4 * hunits, embedd_dim)
63+
|| mat3.dim() != (4 * hunits, hunits)
64+
|| mat4.dim() != (4 * hunits)
65+
|| mat5.dim() != (4 * hunits, embedd_dim)
66+
|| mat6.dim() != (4 * hunits, hunits)
67+
|| mat7.dim() != (4 * hunits)
6868
|| mat8.dim() != (2, 4, hunits)
6969
|| mat9.dim() != (4)
7070
{
@@ -145,40 +145,40 @@ impl<'l> Lstm<'l> {
145145
x_t: MatrixZero<'a, 1>,
146146
mut h_tm1: MatrixBorrowedMut<'a, 1>,
147147
mut c_tm1: MatrixBorrowedMut<'a, 1>,
148-
warr: MatrixZero<'a, 3>,
149-
uarr: MatrixZero<'a, 3>,
150-
barr: MatrixZero<'a, 2>,
148+
warr: MatrixZero<'a, 2>,
149+
uarr: MatrixZero<'a, 2>,
150+
barr: MatrixZero<'a, 1>,
151151
hunits: usize,
152152
) -> Option<()> {
153153
#[cfg(debug_assertions)]
154154
{
155155
let embedd_dim = x_t.dim();
156156
h_tm1.as_borrowed().debug_assert_dims([hunits]);
157157
c_tm1.as_borrowed().debug_assert_dims([hunits]);
158-
warr.debug_assert_dims([hunits, 4, embedd_dim]);
159-
uarr.debug_assert_dims([hunits, 4, hunits]);
160-
barr.debug_assert_dims([hunits, 4]);
158+
warr.debug_assert_dims([4 * hunits, embedd_dim]);
159+
uarr.debug_assert_dims([4 * hunits, hunits]);
160+
barr.debug_assert_dims([4 * hunits]);
161161
}
162162

163163
let mut s_t = barr.to_owned();
164164

165-
s_t.as_mut().add_dot_3d_2(x_t, warr);
166-
s_t.as_mut().add_dot_3d_1(h_tm1.as_borrowed(), uarr);
167-
168-
for i in 0..hunits {
169-
let [s0, s1, s2, s3] = s_t
170-
.as_borrowed()
171-
.submatrix::<1>(i)
172-
.and_then(|s| s.read_4())?;
173-
let p = math_helper::sigmoid(s0);
174-
let f = math_helper::sigmoid(s1);
175-
let c = math_helper::tanh(s2);
176-
let o = math_helper::sigmoid(s3);
177-
let c_old = c_tm1.as_borrowed().as_slice().get(i)?;
178-
let c_new = p * c + f * c_old;
179-
*c_tm1.as_mut_slice().get_mut(i)? = c_new;
180-
*h_tm1.as_mut_slice().get_mut(i)? = o * math_helper::tanh(c_new);
181-
}
165+
s_t.as_mut().add_dot_2d_1(x_t, warr);
166+
s_t.as_mut().add_dot_2d(h_tm1.as_borrowed(), uarr);
167+
168+
s_t.as_mut().sigmoid(0..2 * hunits);
169+
s_t.as_mut().tanh(2 * hunits..3 * hunits);
170+
s_t.as_mut().sigmoid(3 * hunits..4 * hunits);
171+
172+
let sb = s_t.as_borrowed();
173+
174+
c_tm1.convolve(
175+
sb.view(0..hunits),
176+
sb.view(2 * hunits..3 * hunits),
177+
sb.view(hunits..2 * hunits),
178+
);
179+
180+
h_tm1.mul_tanh(sb.view(3*hunits..4*hunits), c_tm1.as_borrowed());
181+
182182
Some(())
183183
}
184184

experimental/segmenter/src/math_helper.rs

Lines changed: 50 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,6 @@ impl<'a, const D: usize> MatrixBorrowed<'a, D> {
8484
debug_assert_eq!(expected_len, self.data.len());
8585
}
8686

87-
pub fn as_slice(&self) -> &'a [f32] {
88-
self.data
89-
}
90-
91-
/// See [`MatrixOwned::submatrix`].
92-
#[inline]
93-
pub fn submatrix<const M: usize>(&self, index: usize) -> Option<MatrixBorrowed<'a, M>> {
94-
// This assertion is based on const generics; it should always succeed and be elided.
95-
assert_eq!(M, D - 1);
96-
let (range, dims) = self.submatrix_range(index);
97-
let data = &self.data.get(range)?;
98-
Some(MatrixBorrowed { data, dims })
99-
}
100-
10187
#[inline]
10288
fn submatrix_range<const M: usize>(&self, index: usize) -> (Range<usize>, [usize; M]) {
10389
// This assertion is based on const generics; it should always succeed and be elided.
@@ -210,6 +196,14 @@ impl<'a> MatrixBorrowed<'a, 1> {
210196
debug_assert_eq!(self.dims, other.dims);
211197
unrolled_dot_1(self.data, other.data)
212198
}
199+
200+
pub fn view(&'a self, r: Range<usize>) -> MatrixBorrowed<'a, 1> {
201+
MatrixBorrowed {
202+
#[allow(clippy::indexing_slicing)]
203+
data: &self.data[r],
204+
dims: self.dims,
205+
}
206+
}
213207
}
214208

215209
impl<'a> MatrixBorrowedMut<'a, 1> {
@@ -245,87 +239,74 @@ impl<'a> MatrixBorrowedMut<'a, 1> {
245239
}
246240
}
247241
}
248-
}
249242

250-
impl<'a> MatrixBorrowedMut<'a, 2> {
251243
/// Calculate the dot product of a and b, adding the result to self.
252244
///
253-
/// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
254-
pub fn add_dot_3d_1(&mut self, a: MatrixBorrowed<1>, b: MatrixZero<3>) {
245+
/// Note: For better dot product efficiency, if `b` is MxN, then `a` should be N;
246+
/// this is the opposite of standard practice.
247+
pub fn add_dot_2d_1(&mut self, a: MatrixZero<1>, b: MatrixZero<2>) {
255248
let m = a.dim();
256-
let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
249+
let n = self.as_borrowed().dim();
257250
debug_assert_eq!(
258251
m,
259-
b.dim().2,
252+
b.dim().1,
260253
"dims: {:?}/{:?}/{:?}",
261254
self.as_borrowed().dim(),
262255
a.dim(),
263256
b.dim()
264257
);
265258
debug_assert_eq!(
266259
n,
267-
b.dim().0 * b.dim().1,
260+
b.dim().0,
268261
"dims: {:?}/{:?}/{:?}",
269262
self.as_borrowed().dim(),
270263
a.dim(),
271264
b.dim()
272265
);
273-
// Note: The following two loops are equivalent, but the second has more opportunity for
274-
// vectorization since it allows the vectorization to span submatrices.
275-
// for i in 0..b.dim().0 {
276-
// self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i));
277-
// }
278-
let lhs = a.as_slice();
279266
for i in 0..n {
280-
if let (Some(dest), Some(rhs)) = (
281-
self.as_mut_slice().get_mut(i),
282-
b.as_slice().get_subslice(i * m..(i + 1) * m),
283-
) {
284-
*dest += unrolled_dot_1(lhs, rhs);
267+
if let (Some(dest), Some(b_sub)) = (self.as_mut_slice().get_mut(i), b.submatrix::<1>(i))
268+
{
269+
*dest += unrolled_dot_2(a.as_slice(), b_sub.data);
285270
} else {
286271
debug_assert!(false, "unreachable: dims checked above");
287272
}
288273
}
289274
}
290275

291-
/// Calculate the dot product of a and b, adding the result to self.
292-
///
293-
/// Self should be _MxN_; `a`, _O_; and `b`, _MxNxO_.
294-
pub fn add_dot_3d_2(&mut self, a: MatrixZero<1>, b: MatrixZero<3>) {
295-
let m = a.dim();
296-
let n = self.as_borrowed().dim().0 * self.as_borrowed().dim().1;
297-
debug_assert_eq!(
298-
m,
299-
b.dim().2,
300-
"dims: {:?}/{:?}/{:?}",
301-
self.as_borrowed().dim(),
302-
a.dim(),
303-
b.dim()
304-
);
305-
debug_assert_eq!(
306-
n,
307-
b.dim().0 * b.dim().1,
308-
"dims: {:?}/{:?}/{:?}",
309-
self.as_borrowed().dim(),
310-
a.dim(),
311-
b.dim()
312-
);
313-
// Note: The following two loops are equivalent, but the second has more opportunity for
314-
// vectorization since it allows the vectorization to span submatrices.
315-
// for i in 0..b.dim().0 {
316-
// self.submatrix_mut::<1>(i).add_dot_2d(a, b.submatrix(i));
317-
// }
318-
let lhs = a.as_slice();
319-
for i in 0..n {
320-
if let (Some(dest), Some(rhs)) = (
321-
self.as_mut_slice().get_mut(i),
322-
b.as_slice().get_subslice(i * m..(i + 1) * m),
323-
) {
324-
*dest += unrolled_dot_2(lhs, rhs);
325-
} else {
326-
debug_assert!(false, "unreachable: dims checked above");
327-
}
276+
pub fn sigmoid(&mut self, r: Range<usize>) -> Option<()> {
277+
let slice = self.data.get_mut(r)?;
278+
for i in &mut *slice {
279+
*i = sigmoid(*i);
280+
}
281+
Some(())
282+
}
283+
284+
pub fn tanh(&mut self, r: Range<usize>) -> Option<()> {
285+
let slice = self.data.get_mut(r)?;
286+
for i in &mut *slice {
287+
*i = tanh(*i);
288+
}
289+
Some(())
290+
}
291+
292+
pub fn convolve(
293+
&mut self,
294+
i: MatrixBorrowed<'_, 1>,
295+
c: MatrixBorrowed<'_, 1>,
296+
f: MatrixBorrowed<'_, 1>,
297+
) -> Option<()> {
298+
for idx in 0..self.data.len() {
299+
*self.data.get_mut(idx)? =
300+
i.data.get(idx)? * c.data.get(idx)? + self.data.get(idx)? * f.data.get(idx)?
328301
}
302+
Some(())
303+
}
304+
305+
pub fn mul_tanh(&mut self, o: MatrixBorrowed<'_, 1>, c: MatrixBorrowed<'_, 1>) -> Option<()> {
306+
for idx in 0..self.data.len() {
307+
*self.data.get_mut(idx)? = o.data.get(idx)? * tanh(*c.data.get(idx)?);
308+
}
309+
Some(())
329310
}
330311
}
331312

provider/datagen/src/transform/segmenter/lstm.rs

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ struct RawLstmData {
5959
impl RawLstmData {
6060
pub fn try_convert(&self) -> Result<LstmDataV1<'static>, DataError> {
6161
let mat1 = self.mat1.to_ndarray2()?;
62-
let mat2 = self.mat2.to_ndarray2()?;
63-
let mat3 = self.mat3.to_ndarray2()?;
62+
let mut mat2 = self.mat2.to_ndarray2()?;
63+
let mut mat3 = self.mat3.to_ndarray2()?;
6464
let mat4 = self.mat4.to_ndarray1()?;
65-
let mat5 = self.mat5.to_ndarray2()?;
66-
let mat6 = self.mat6.to_ndarray2()?;
65+
let mut mat5 = self.mat5.to_ndarray2()?;
66+
let mut mat6 = self.mat6.to_ndarray2()?;
6767
let mat7 = self.mat7.to_ndarray1()?;
6868
let mat8 = self.mat8.to_ndarray2()?;
6969
let mat9 = self.mat9.to_ndarray1()?;
@@ -81,19 +81,11 @@ impl RawLstmData {
8181
return Err(DIMENSION_MISMATCH_ERROR);
8282
}
8383
// Unwraps okay: dimensions checked above
84-
let mut mat2 = mat2.into_shape((embedd_dim, 4, hunits)).unwrap();
85-
let mut mat3 = mat3.into_shape((hunits, 4, hunits)).unwrap();
86-
let mut mat4 = mat4.into_shape((4, hunits)).unwrap();
87-
let mut mat5 = mat5.into_shape((embedd_dim, 4, hunits)).unwrap();
88-
let mut mat6 = mat6.into_shape((hunits, 4, hunits)).unwrap();
89-
let mut mat7 = mat7.into_shape((4, hunits)).unwrap();
9084
let mut mat8 = mat8.into_shape((2, hunits, 4)).unwrap();
91-
mat2.swap_axes(0, 2);
92-
mat3.swap_axes(0, 2);
93-
mat4.swap_axes(0, 1);
94-
mat5.swap_axes(0, 2);
95-
mat6.swap_axes(0, 2);
96-
mat7.swap_axes(0, 1);
85+
mat2.swap_axes(0, 1);
86+
mat3.swap_axes(0, 1);
87+
mat5.swap_axes(0, 1);
88+
mat6.swap_axes(0, 1);
9789
mat8.swap_axes(1, 2);
9890
let mat2 = mat2.as_standard_layout().into_owned();
9991
let mat3 = mat3.as_standard_layout().into_owned();

provider/repodata/data/json/fingerprints.csv

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)