Commit 7eca1a5

Code reformatting
1 parent 1354614 commit 7eca1a5

29 files changed, +775 -557 lines changed
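
Taken together, the hunks below apply one consistent set of formatting conventions: double-quoted strings, regrouped and alphabetized imports, one argument per line in calls that exceed the line length, and a trailing comma after the last argument of every wrapped call. As a minimal, self-contained sketch of the resulting dataclass-field style (the class name here is hypothetical; the field itself is taken from the bert.py hunk below):

from dataclasses import dataclass, field


@dataclass
class ExampleConfig:  # hypothetical container, for illustration only
    # Wrapped field() call in the reformatted style: one argument per line,
    # double quotes, and a trailing comma on the final argument.
    pooler_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout probability in the masked_lm pooler layers"},
    )


if __name__ == "__main__":
    print(ExampleConfig().pooler_dropout)  # prints 0.0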

examples/fairseq/generate.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 # flake8: noqa
 import models
 import tasks
-
 from fairseq_cli.generate import cli_main
 
 if __name__ == "__main__":

examples/fairseq/interactive.py

Lines changed: 0 additions & 1 deletion
@@ -4,7 +4,6 @@
 # flake8: noqa
 import models
 import tasks
-
 from fairseq_cli.interactive import cli_main
 
 if __name__ == "__main__":

examples/fairseq/models/bert.py

Lines changed: 74 additions & 70 deletions
@@ -2,24 +2,24 @@
 # Licensed under The MIT License [see LICENSE for details]
 
 import logging
-from typing import Optional
 from dataclasses import dataclass, field
+from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from apex.normalization import FusedLayerNorm as LayerNorm
 from fairseq import utils
-from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
 from fairseq.dataclass import ChoiceEnum, FairseqDataclass
-from fairseq.models.transformer import (
-    DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
-)
-from fairseq.modules import PositionalEmbedding
+from fairseq.models import BaseFairseqModel, register_model, register_model_architecture
 from fairseq.models.squad import SQuADHead
+from fairseq.models.transformer import DEFAULT_MIN_PARAMS_TO_WRAP, Embedding
+from fairseq.modules import PositionalEmbedding
 from omegaconf import II
-from .machine_translation import MTEncoder as Encoder
+
 from torchscale.architecture.config import EncoderConfig
-from apex.normalization import FusedLayerNorm as LayerNorm
+
+from .machine_translation import MTEncoder as Encoder
 
 DEFAULT_MAX_SOURCE_POSITIONS = 1024
 
@@ -109,7 +109,7 @@ class BertConfig(FairseqDataclass):
                 "is set to 0 (i.e., always wrap) when --checkpoint-activations or "
                 "--offload-activations are passed."
             )
-        }
+        },
     )
     max_source_positions: int = field(
         default=1024, metadata={"help": "max source positions"}
@@ -118,59 +118,41 @@ class BertConfig(FairseqDataclass):
         default="relu", metadata={"help": "activation function to use for pooler layer"}
     )
     pooler_dropout: float = field(
-        default=0.0, metadata={"help": "dropout probability in the masked_lm pooler layers"}
+        default=0.0,
+        metadata={"help": "dropout probability in the masked_lm pooler layers"},
     )
     # options from other parts of the config
     # add_bos_token: bool = II("task.add_bos_token")
     # tokens_per_sample: int = II("task.tokens_per_sample")
     tpu: bool = II("common.tpu")
-    rel_pos_buckets: int = field(
-        default=0, metadata={"help": ""}
-    )
-    max_rel_pos: int = field(
-        default=0, metadata={"help": ""}
-    )
+    rel_pos_buckets: int = field(default=0, metadata={"help": ""})
+    max_rel_pos: int = field(default=0, metadata={"help": ""})
     moe_freq: int = field(
         default=0,
-        metadata={
-            "help": "Frequency at which we insert MoE Transformer layers"
-        },
+        metadata={"help": "Frequency at which we insert MoE Transformer layers"},
     )
     moe_expert_count: int = field(
-        default=0,
-        metadata={
-            "help": "Number of experts in each MoE Layer"
-        }
+        default=0, metadata={"help": "Number of experts in each MoE Layer"}
     )
     moe_gating_use_fp32: bool = field(
         default=False,
-        metadata={
-            "help": "Use FP32 computations in MoE top2 gating function"
-        }
+        metadata={"help": "Use FP32 computations in MoE top2 gating function"},
     )
     moe_second_expert_policy: str = field(
-        default='sampling',
-        metadata={
-            "help": "policy for second expert, options: all/sampling/random"
-        }
+        default="sampling",
+        metadata={"help": "policy for second expert, options: all/sampling/random"},
     )
     moe_normalize_gate_prob_before_dropping: bool = field(
         default=False,
         metadata={
-            "help": 'whether to normalize gate probs before or after dropping experts for capacity and randomization'
-        }
+            "help": "whether to normalize gate probs before or after dropping experts for capacity and randomization"
+        },
     )
     moe_expert_ffn_dim: Optional[int] = field(
-        default=None,
-        metadata={
-            "help": "MoE expert FFN dimension"
-        }
+        default=None, metadata={"help": "MoE expert FFN dimension"}
     )
     moe_top1_expert: Optional[bool] = field(
-        default=False,
-        metadata={
-            "help": "Use top1 gate instead of top2"
-        }
+        default=False, metadata={"help": "Use top1 gate instead of top2"}
     )
     moe_eval_capacity_token_fraction: Optional[float] = field(
         default=0.25,
@@ -179,23 +161,29 @@ class BertConfig(FairseqDataclass):
                 "Default: 0.25, Fraction of tokens as capacity during validation, "
                 "if set to negative, use same as training. range: (0.0, 1.0]."
             )
-        }
+        },
     )
     moe_normalize_expert_grad: Optional[str] = field(
-        default='world_size',
+        default="world_size",
         metadata={
             "help": "Divide expert gradients by (1) 'world_size' (2) 'sqrt_world_size'"
-        }
+        },
     )
     record_a2a_perf_stats: Optional[bool] = field(
-        default=False, metadata={"help": "records all to all perf stats during distributed training"}
+        default=False,
+        metadata={"help": "records all to all perf stats during distributed training"},
     )
     dummy_a2a: Optional[bool] = field(
-        default=False, metadata={
-            "help": "By passes all to all during distributed training by returning the input buffer as output"}
+        default=False,
+        metadata={
+            "help": "By passes all to all during distributed training by returning the input buffer as output"
+        },
     )
     moe_batch_prioritized_routing: Optional[bool] = field(
-        default=False, metadata={"help": "if true orders token by the gate prob before capacity dropping."}
+        default=False,
+        metadata={
+            "help": "if true orders token by the gate prob before capacity dropping."
+        },
     )
     ddp_rank: int = II("distributed_training.distributed_rank")
     deepnorm: Optional[bool] = field(
@@ -208,7 +196,6 @@ class BertConfig(FairseqDataclass):
 
 @register_model("mlm", dataclass=BertConfig)
 class BertModel(BaseFairseqModel):
-
     def __init__(self, args, encoder):
         super().__init__()
         self.args = args
@@ -240,7 +227,11 @@ def build_model(cls, args, task):
         )
 
         lm_head = cls.build_lm_head(
-            args, args.encoder_embed_dim, len(task.dictionary), args.activation_fn, weight=embed_tokens.weight
+            args,
+            args.encoder_embed_dim,
+            len(task.dictionary),
+            args.activation_fn,
+            weight=embed_tokens.weight,
         )
 
         config = EncoderConfig()
@@ -269,15 +260,17 @@ def build_lm_head(cls, args, embed_dim, output_dim, activation_fn, weight):
     def output_layer(self, features, masked_tokens=None):
         return self.encoder.output_projection(features, masked_tokens=masked_tokens)
 
-    def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs):
+    def register_classification_head(
+        self, name, num_classes=None, inner_dim=None, **kwargs
+    ):
         """Register a classification head."""
         if name in self.classification_heads:
             prev_num_classes = self.classification_heads[name].out_proj.out_features
             prev_inner_dim = self.classification_heads[name].dense.out_features
             if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                 logger.warning(
                     're-registering head "{}" with num_classes {} (prev: {}) '
-                    'and inner_dim {} (prev: {})'.format(
+                    "and inner_dim {} (prev: {})".format(
                         name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                     )
                 )
@@ -295,55 +288,64 @@ def register_question_answering_head(self, name, num_classes=None):
         )
 
     def upgrade_state_dict_named(self, state_dict, name):
-        prefix = name + '.' if name != '' else ''
+        prefix = name + "." if name != "" else ""
 
         # upgrade children modules
         super().upgrade_state_dict_named(state_dict, name)
 
         # Handle new classification heads present in the state dict.
         current_head_names = (
-            [] if not hasattr(self, 'classification_heads')
+            []
+            if not hasattr(self, "classification_heads")
             else self.classification_heads.keys()
         )
         keys_to_delete = []
         for k in state_dict.keys():
-            if not k.startswith(prefix + 'classification_heads.'):
+            if not k.startswith(prefix + "classification_heads."):
                 continue
 
-            head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
-            num_classes = state_dict[prefix + 'classification_heads.' + head_name + '.out_proj.weight'].size(0)
-            inner_dim = state_dict[prefix + 'classification_heads.' + head_name + '.dense.weight'].size(0)
+            head_name = k[len(prefix + "classification_heads.") :].split(".")[0]  # noqa: E203
+            num_classes = state_dict[
+                prefix + "classification_heads." + head_name + ".out_proj.weight"
+            ].size(0)
+            inner_dim = state_dict[
+                prefix + "classification_heads." + head_name + ".dense.weight"
+            ].size(0)
 
-            if getattr(self.args, 'load_checkpoint_heads', False):
+            if getattr(self.args, "load_checkpoint_heads", False):
                 if head_name not in current_head_names:
                     self.register_classification_head(head_name, num_classes, inner_dim)
             else:
                 if head_name not in current_head_names:
                     logger.warning(
-                        'deleting classification head ({}) from checkpoint '
-                        'not present in current model: {}'.format(head_name, k)
+                        "deleting classification head ({}) from checkpoint "
+                        "not present in current model: {}".format(head_name, k)
                     )
                     keys_to_delete.append(k)
                 elif (
-                    num_classes != self.classification_heads[head_name].out_proj.out_features
-                    or inner_dim != self.classification_heads[head_name].dense.out_features
+                    num_classes
+                    != self.classification_heads[head_name].out_proj.out_features
+                    or inner_dim
+                    != self.classification_heads[head_name].dense.out_features
                 ):
                     logger.warning(
-                        'deleting classification head ({}) from checkpoint '
-                        'with different dimensions than current model: {}'.format(head_name, k)
+                        "deleting classification head ({}) from checkpoint "
+                        "with different dimensions than current model: {}".format(
+                            head_name, k
+                        )
                     )
                     keys_to_delete.append(k)
         for k in keys_to_delete:
             del state_dict[k]
 
         # Copy any newly-added classification heads into the state dict
         # with their current weights.
-        if hasattr(self, 'classification_heads'):
+        if hasattr(self, "classification_heads"):
             cur_state = self.classification_heads.state_dict()
             for k, v in cur_state.items():
-                if prefix + 'classification_heads.' + k not in state_dict:
-                    logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
-                    state_dict[prefix + 'classification_heads.' + k] = v
+                if prefix + "classification_heads." + k not in state_dict:
+                    logger.info("Overwriting " + prefix + "classification_heads." + k)
+                    state_dict[prefix + "classification_heads." + k] = v
 
     def forward(
         self,
@@ -354,7 +356,9 @@ def forward(
         masked_tokens=None,
         **kwargs
     ):
-        encoder_out = self.encoder(src_tokens, features_only=True, return_all_hiddens=return_all_hiddens)
+        encoder_out = self.encoder(
+            src_tokens, features_only=True, return_all_hiddens=return_all_hiddens
+        )
         x, extra = encoder_out["encoder_out"], encoder_out
         x = x.transpose(0, 1)
 
@@ -455,7 +459,7 @@ def base_unilm_architecture(args):
     args.encoder_input_dim = getattr(args, "encoder_input_dim", args.encoder_embed_dim)
 
     # Model training is not stable without this
-    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
+    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
     args.no_encoder_final_norm = getattr(args, "no_encoder_final_norm", False)
 
     args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
