@@ -48,4 +48,125 @@ def non_causal_mask(lengths):
    return mask


def subsequent_chunk_mask(
    size: int,
    chunk_size: int,
    num_left_chunks: int = -1,
    device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
57+ """Create mask for subsequent steps (size, size) with chunk size,
58+ this is for streaming encoder
59+
60+ Args:
61+ size (int): size of mask
62+ chunk_size (int): size of chunk
63+ num_left_chunks (int): number of left chunks
64+ <0: use full chunk
65+ >=0: use num_left_chunks
66+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
67+
68+ Returns:
69+ torch.Tensor: mask
70+
71+ Examples:
72+ >>> subsequent_chunk_mask(4, 2)
73+ [[1, 1, 0, 0],
74+ [1, 1, 0, 0],
75+ [1, 1, 1, 1],
76+ [1, 1, 1, 1]]
77+ """
    # Each step i may attend to every step inside its own chunk, plus
    # num_left_chunks chunks of left context (all history if < 0).
    ret = torch.zeros(size, size, device=device, dtype=torch.bool)
    for i in range(size):
        if num_left_chunks < 0:
            start = 0
        else:
            start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
        ending = min((i // chunk_size + 1) * chunk_size, size)
        ret[i, start:ending] = True
    return ret


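# A quick sanity check of the chunk semantics (an illustrative sketch, not
# part of the original diff): with chunk_size=4 and num_left_chunks=1,
# step 11 lives in chunk 2, so it can attend to chunks 1 and 2 but not
# chunk 0.
# m = subsequent_chunk_mask(12, chunk_size=4, num_left_chunks=1)
# assert not m[11, :4].any() and m[11, 4:].all()

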
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean mask to an additive attention bias: True becomes
    0.0 and False becomes a large negative value, so that masked positions
    contribute (almost) nothing after softmax.
    """
    assert mask.dtype == torch.bool
    assert dtype in [torch.float32, torch.bfloat16, torch.float16]
    mask = mask.to(dtype)
    # attention mask bias
    # NOTE(Mddct): torch.finfo jit issues
    # chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
    mask = (1.0 - mask) * -1.0e+10
    return mask


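# Illustrative example (my addition, not in the original diff): kept
# positions map to a bias of 0.0 and masked positions to -1.0e10.
# mask_to_bias(torch.tensor([[True, False]]), torch.float32)
# -> 0.0 for the True entry, -1.0e10 for the False entry

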
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True,
                            max_chunk_size: int = 25):
109+ """ Apply optional mask for encoder.
110+
111+ Args:
112+ xs (torch.Tensor): padded input, (B, L, D), L for max length
113+ mask (torch.Tensor): mask for xs, (B, 1, L)
114+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
115+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
116+ training.
117+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
118+ 0: default for training, use random dynamic chunk.
119+ <0: for decoding, use full chunk.
120+ >0: for decoding, use fixed chunk size as set.
121+ static_chunk_size (int): chunk size for static chunk training/decoding
122+ if it's greater than 0, if use_dynamic_chunk is true,
123+ this parameter will be ignored
124+ num_decoding_left_chunks: number of left chunks, this is for decoding,
125+ the chunk size is decoding_chunk_size.
126+ >=0: use num_decoding_left_chunks
127+ <0: use all left chunks
128+ enable_full_context (bool):
129+ True: chunk size is [1, max_chunk_size] or full context(max_len)
130+ False: chunk size ~ U[1, max_chunk_size]
131+
132+ Returns:
133+ torch.Tensor: chunk mask of the input xs.
134+ """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            # chunk_size is sampled from [1, max_chunk_size], or set to
            # max_len for full context.
            chunk_size = torch.randint(1, max_len, (1, )).item()
            num_left_chunks = -1
            if chunk_size > max_len // 2 and enable_full_context:
                chunk_size = max_len
            else:
                chunk_size = chunk_size % max_chunk_size + 1
                if use_dynamic_left_chunk:
                    max_left_chunks = (max_len - 1) // chunk_size
                    num_left_chunks = torch.randint(0, max_left_chunks,
                                                    (1, )).item()
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = chunk_masks.unsqueeze(0)  # (1, L, L)
        chunk_masks = masks & chunk_masks  # (B, L, L)
    else:
        chunk_masks = masks
    return chunk_masks


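# Illustrative usage (my own sketch, not part of the original diff): static
# chunk training on a hypothetical batch of 2 sequences of length 8; the
# (1, L, L) chunk mask broadcasts against the (B, 1, L) padding mask.
# xs = torch.randn(2, 8, 4)
# masks = torch.ones(2, 1, 8, dtype=torch.bool)
# chunk_masks = add_optional_chunk_mask(
#     xs, masks, use_dynamic_chunk=False, use_dynamic_left_chunk=False,
#     decoding_chunk_size=0, static_chunk_size=2,
#     num_decoding_left_chunks=-1)
# print(chunk_masks.shape)  # torch.Size([2, 8, 8])
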
# print(non_causal_mask(torch.tensor([2, 3, 4], dtype=torch.long)))