diff --git a/utils.py b/utils.py index 66c8a7b..be642cb 100644 --- a/utils.py +++ b/utils.py @@ -144,7 +144,7 @@ def _sample_mask(seg, mask_alpha, mask_beta, end = beg + 1 cnt_ngram = 1 while end < seg_len: - if _is_start_piece([seg[beg]]): + if _is_start_piece([seg[end]]): cnt_ngram += 1 if cnt_ngram > n: break