# Licensed under The MIT License [see LICENSE for details]

import logging
- from typing import Optional
from dataclasses import dataclass, field
+ from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
+ from apex.normalization import FusedLayerNorm as LayerNorm
from fairseq import utils
- from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
- from fairseq.models.transformer import (
-     DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
- )
- from fairseq.modules import PositionalEmbedding
+ from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
from fairseq.models.squad import SQuADHead
+ from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
+ from fairseq.modules import PositionalEmbedding
from omegaconf import II
- from .machine_translation import MTEncoder as Encoder
+
from torchscale.architecture.config import EncoderConfig
- from apex.normalization import FusedLayerNorm as LayerNorm
+
+ from .machine_translation import MTEncoder as Encoder

DEFAULT_MAX_SOURCE_POSITIONS = 1024

@@ -109,7 +109,7 @@ class BertConfig(FairseqDataclass):
                "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                "--offload-activations are passed."
            )
-         }
+         },
    )
    max_source_positions: int = field(
        default=1024, metadata={"help": "max source positions"}
@@ -118,59 +118,41 @@ class BertConfig(FairseqDataclass):
        default="relu", metadata={"help": "activation function to use for pooler layer"}
    )
    pooler_dropout: float = field(
-         default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
+         default=0.0,
+         metadata={"help": "dropout probability in the masked_lm pooler layers"},
    )
    # options from other parts of the config
    # add_bos_token: bool = II("task.add_bos_token")
    # tokens_per_sample: int = II("task.tokens_per_sample")
    tpu: bool = II("common.tpu")
-     rel_pos_buckets: int = field(
-         default=0, metadata={"help": ""}
-     )
-     max_rel_pos: int = field(
-         default=0, metadata={"help": ""}
-     )
+     rel_pos_buckets: int = field(default=0, metadata={"help": ""})
+     max_rel_pos: int = field(default=0, metadata={"help": ""})
    moe_freq: int = field(
        default=0,
-         metadata={
-             "help": "Frequency at which we insert MoE Transformer layers"
-         },
+         metadata={"help": "Frequency at which we insert MoE Transformer layers"},
    )
    moe_expert_count: int = field(
-         default=0,
-         metadata={
-             "help": "Number of experts in each MoE Layer"
-         }
+         default=0, metadata={"help": "Number of experts in each MoE Layer"}
    )
    moe_gating_use_fp32: bool = field(
        default=False,
-         metadata={
-             "help": "Use FP32 computations in MoE top2 gating function"
-         }
+         metadata={"help": "Use FP32 computations in MoE top2 gating function"},
    )
    moe_second_expert_policy: str = field(
-         default='sampling',
-         metadata={
-             "help": "policy for second expert, options: all/sampling/random"
-         }
+         default="sampling",
+         metadata={"help": "policy for second expert, options: all/sampling/random"},
    )
    moe_normalize_gate_prob_before_dropping: bool = field(
        default=False,
        metadata={
-             "help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
-         }
+             "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
+         },
    )
    moe_expert_ffn_dim: Optional[int] = field(
-         default=None,
-         metadata={
-             "help": "MoE expert FFN dimension"
-         }
+         default=None, metadata={"help": "MoE expert FFN dimension"}
    )
    moe_top1_expert: Optional[bool] = field(
-         default=False,
-         metadata={
-             "help": "Use top1 gate instead of top2"
-         }
+         default=False, metadata={"help": "Use top1 gate instead of top2"}
    )
    moe_eval_capacity_token_fraction: Optional[float] = field(
        default=0.25,
@@ -179,23 +161,29 @@ class BertConfig(FairseqDataclass):
                "Default: 0.25, Fraction of tokens as capacity during validation, "
                "if set to negative, use same as training. range: (0.0, 1.0]."
            )
-         }
+         },
    )
    moe_normalize_expert_grad: Optional[str] = field(
-         default='world_size',
+         default="world_size",
        metadata={
            "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
-         }
+         },
    )
    record_a2a_perf_stats: Optional[bool] = field(
-         default=False, metadata={"help": "records all to all perf stats during distributed training"}
+         default=False,
+         metadata={"help": "records all to all perf stats during distributed training"},
    )
    dummy_a2a: Optional[bool] = field(
-         default=False, metadata={
-             "help": "By passes all to all during distributed training by returning the input buffer as output"}
+         default=False,
+         metadata={
+             "help": "By passes all to all during distributed training by returning the input buffer as output"
+         },
    )
    moe_batch_prioritized_routing: Optional[bool] = field(
-         default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
+         default=False,
+         metadata={
+             "help": "if true orders token by the gate prob before capacity dropping."
+         },
    )
    ddp_rank: int = II("distributed_training.distributed_rank")
    deepnorm: Optional[bool] = field(
@@ -208,7 +196,6 @@ class BertConfig(FairseqDataclass):

@register_model("mlm", dataclass=BertConfig)
class BertModel(BaseFairseqModel):
-
    def __init__(self, args, encoder):
        super().__init__()
        self.args = args
@@ -240,7 +227,11 @@ def build_model(cls, args, task):
        )

        lm_head = cls.build_lm_head(
-             args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
+             args,
+             args.encoder_embed_dim,
+             len(task.dictionary),
+             args.activation_fn,
+             weight=embed_tokens.weight,
        )

        config = EncoderConfig()
@@ -269,15 +260,17 @@ def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
    def output_layer(self, features, masked_tokens=None):
        return self.encoder.output_projection(features, masked_tokens=masked_tokens)

-     def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
+     def register_classification_head(
+         self, name, num_classes=None, inner_dim=None, **kwargs
+     ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
-                     'and inner_dim {} (prev: {})'.format(
+                     "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
@@ -295,55 +288,64 @@ def register_question_answering_head(self, name, num_classes=None):
        )

    def upgrade_state_dict_named(self, state_dict, name):
-         prefix = name + '.' if name != '' else ''
+         prefix = name + "." if name != "" else ""

        # upgrade children modules
        super().upgrade_state_dict_named(state_dict, name)

        # Handle new classification heads present in the state dict.
        current_head_names = (
-             [] if not hasattr(self, 'classification_heads')
+             []
+             if not hasattr(self, "classification_heads")
            else self.classification_heads.keys()
        )
        keys_to_delete = []
        for k in state_dict.keys():
-             if not k.startswith(prefix + 'classification_heads.'):
+             if not k.startswith(prefix + "classification_heads."):
                continue

-             head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
-             num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
-             inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+             head_name = k[len(prefix + "classification_heads.") :].split(".")[0]  # noqa: E203
+             num_classes = state_dict[
+                 prefix + "classification_heads." + head_name + ".out_proj.weight"
+             ].size(0)
+             inner_dim = state_dict[
+                 prefix + "classification_heads." + head_name + ".dense.weight"
+             ].size(0)

-             if getattr(self.args, 'load_checkpoint_heads', False):
+             if getattr(self.args, "load_checkpoint_heads", False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
-                         'deleting classification head ({}) from checkpoint '
-                         'not present in current model: {}'.format(head_name, k)
+                         "deleting classification head ({}) from checkpoint "
+                         "not present in current model: {}".format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
-                     num_classes != self.classification_heads[head_name].out_proj.out_features
-                     or inner_dim != self.classification_heads[head_name].dense.out_features
+                     num_classes
+                     != self.classification_heads[head_name].out_proj.out_features
+                     or inner_dim
+                     != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
-                         'deleting classification head ({}) from checkpoint '
-                         'with different dimensions than current model: {}'.format(head_name, k)
+                         "deleting classification head ({}) from checkpoint "
+                         "with different dimensions than current model: {}".format(
+                             head_name, k
+                         )
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
-         if hasattr(self, 'classification_heads'):
+         if hasattr(self, "classification_heads"):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
-                 if prefix + 'classification_heads.' + k not in state_dict:
-                     logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
-                     state_dict[prefix + 'classification_heads.' + k] = v
+                 if prefix + "classification_heads." + k not in state_dict:
+                     logger.info("Overwriting " + prefix + "classification_heads." + k)
+                     state_dict[prefix + "classification_heads." + k] = v

    def forward(
        self,
@@ -354,7 +356,9 @@ def forward(
        masked_tokens=None,
        **kwargs
    ):
-         encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
+         encoder_out = self.encoder(
+             src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
+         )
        x, extra = encoder_out["encoder_out"], encoder_out
        x = x.transpose(0, 1)

@@ -455,7 +459,7 @@ def base_unilm_architecture(args):
    args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)

    # Model training is not stable without this
-     args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+     args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", True)