diff --git a/pretrain.py b/pretrain.py
index fd3eb7f..a5a71bc 100644
--- a/pretrain.py
+++ b/pretrain.py
@@ -143,7 +143,7 @@ def __call__(self, instance):
         # For masked Language Models
         masked_tokens, masked_pos, tokens = _sample_mask(tokens, self.mask_alpha,
                                         self.mask_beta, self.max_gram,
-                                        goal_num_predict=n_pred)
+                                        goal_num_predict=n_pred, vocab_words=self.vocab_words)
 
         masked_weights = [1]*len(masked_tokens)
diff --git a/utils.py b/utils.py
index 66c8a7b..0b3a8a5 100644
--- a/utils.py
+++ b/utils.py
@@ -13,6 +13,7 @@
 import numpy as np
 import torch
+from random import random as rand
 
 def set_seeds(seed):
     "set random seeds"
@@ -105,7 +106,7 @@ def _is_start_piece(piece):
     return False
 
 def _sample_mask(seg, mask_alpha, mask_beta,
-                 max_gram=3, goal_num_predict=85):
+                 max_gram=3, goal_num_predict=85, vocab_words=None):
     # try to n-gram masking SpanBERT(Joshi et al., 2019)
     # 3-gram implementation
     seg_len = len(seg)
@@ -164,11 +165,34 @@ def _sample_mask(seg, mask_alpha, mask_beta,
             mask[i] = True
             num_predict += 1
 
-    tokens, masked_tokens, masked_pos = [], [], []
-    for i in range(seg_len):
-        if mask[i] and (seg[i] != '[CLS]' and seg[i] != '[SEP]'):
-            masked_tokens.append(seg[i])
-            masked_pos.append(i)
-            tokens.append('[MASK]')
-        else:
-            tokens.append(seg[i])
+    # Replace each masked span with [MASK] 80% of the time, with random
+    # words 10% of the time, and keep it unchanged 10% of the time
+    # (decided per span, as in SpanBERT).
+    # Copy seg so the caller's token list is not mutated in place.
+    tokens, masked_tokens, masked_pos = list(seg), [], []
+    i = 0
+    while i < seg_len:
+        if mask[i]:
+            if rand() < 0.8:    # 80%: [MASK]
+                i = set_mask(i, seg_len, mask, tokens, masked_tokens, masked_pos, vocab_words, 0)
+            elif rand() < 0.5:  # 10% (0.2 * 0.5): random word
+                i = set_mask(i, seg_len, mask, tokens, masked_tokens, masked_pos, vocab_words, 1)
+            else:               # 10%: keep original
+                i = set_mask(i, seg_len, mask, tokens, masked_tokens, masked_pos, vocab_words, 2)
+        else:
+            i += 1
+    return masked_tokens, masked_pos, tokens
+
+def set_mask(start, end, mask, tokens, masked_tokens, masked_pos, vocab_words, mode):
+    "apply one replacement mode to the masked span at `start`; return the first index past it"
+    j = start
+    while j < end and mask[j]:
+        if tokens[j] != '[CLS]' and tokens[j] != '[SEP]':
+            masked_tokens.append(tokens[j])
+            masked_pos.append(j)
+            if mode == 0:
+                tokens[j] = '[MASK]'
+            elif mode == 1:
+                tokens[j] = get_random_word(vocab_words)
+            # mode == 2: keep the original token as the model input
+        j += 1
+    return j
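
For reference, a minimal sketch of how the patched `_sample_mask` might be exercised end to end. The segment, the `mask_alpha`/`mask_beta` values, the prediction budget, and the toy vocabulary below are all made up for illustration; the sketch assumes `get_random_word` already exists in `utils.py` (the patch calls it), that the unshown parts of `_sample_mask` behave as in the repository, and that `utils` is importable from the working directory:

```python
import random

from utils import _sample_mask

random.seed(0)  # `rand` in utils.py draws from this generator

# Hypothetical WordPiece-style segment and toy vocabulary (illustrative only).
seg = ['[CLS]', 'the', 'quick', 'brown', 'fox', 'jumps', 'over',
       'the', 'lazy', 'dog', '[SEP]']
vocab = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog']

masked_tokens, masked_pos, tokens = _sample_mask(
    seg, mask_alpha=4, mask_beta=1, max_gram=3,
    goal_num_predict=3, vocab_words=vocab)

# masked_tokens/masked_pos are the prediction targets and their positions;
# tokens is the model input, with each masked span turned into [MASK]s,
# replaced by random words, or left unchanged per the 80/10/10 split.
print(masked_tokens, masked_pos, tokens)
```

Note that the 80/10/10 decision is made once per span rather than once per token, so an entire n-gram is replaced consistently; that is the span-level behavior the SpanBERT comment in `_sample_mask` refers to.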