@@ -19,6 +19,9 @@
 
 logger: logging.Logger = logging.getLogger()
 
+log_grad_norm: bool = False
+use_64bit_grad_norm: bool = False
+
 
 @unique
 class GradientClipping(Enum):
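The two new switches are module-level globals rather than constructor arguments, so they are flipped by assigning to the module before the optimizer is used. A minimal usage sketch; the import path is an assumption (wherever this clipping module lives), not something shown in the diff:

```python
# Hypothetical usage: the module path below is assumed, not taken from the diff.
import torchrec.optim.clipping as clipping

clipping.log_grad_norm = True        # log per-rank gradient norms every step
clipping.use_64bit_grad_norm = True  # accumulate gradient norms in float64
```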
@@ -59,6 +62,7 @@ def __init__(
         self._norm_type = norm_type
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
+        self._step_num = 0
 
         # Group parameters by model parallelism process group if global clipping is enabled.
         # Otherwise, all parameters are treated as replicated and will be clipped locally.
@@ -129,6 +133,7 @@ def step(self, closure: Any = None) -> None:
             torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)
 
         super().step(closure)
+        self._step_num += 1
 
     @torch.no_grad()
     def clip_grad_norm_(self) -> None:
@@ -165,6 +170,8 @@ def clip_grad_norm_(self) -> None:
                 )
             )
 
+        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+
         # Process replicated parameters and gradients
         if self._replicate_params:
             replicated_grads = [
@@ -189,6 +196,22 @@ def clip_grad_norm_(self) -> None:
                     else total_grad_norm + replicated_grad_norm
                 )
             )
+            square_replicated_grad_norm = replicated_grad_norm
+        else:
+            square_replicated_grad_norm = 0
+
+        global log_grad_norm
+        if log_grad_norm:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
+                # pyre-ignore[58]
+                grad_norm = total_grad_norm ** (1.0 / norm_type)
+            else:
+                grad_norm = 0
+
+            rank = dist.get_rank()
+            logger.info(
+                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
+            )
 
         # Aggregation
         if total_grad_norm is None:
@@ -212,10 +235,18 @@ def _batch_cal_norm(
         """Helper function that calculates the norm of a list of gradients in batches. If process_groups
         are passed in, the norm will be aggregated across all ranks in the process group.
         """
-        grad_norms = torch.linalg.vector_norm(
-            torch.stack(torch._foreach_norm(grad_list, norm_type)),
-            norm_type,
-        )
+
+        global use_64bit_grad_norm
+        if use_64bit_grad_norm:
+            grad_norms = torch.linalg.vector_norm(
+                torch.stack(torch._foreach_norm(grad_list, norm_type, dtype=torch.float64)),
+                norm_type,
+            )
+        else:
+            grad_norms = torch.linalg.vector_norm(
+                torch.stack(torch._foreach_norm(grad_list, norm_type)),
+                norm_type,
+            )
 
         if norm_type == torch.inf:
             if process_groups is not None:
@@ -227,6 +258,9 @@ def _batch_cal_norm(
                 for pg in process_groups:
                     dist.all_reduce(grad_norms, group=pg)
 
+        if use_64bit_grad_norm:
+            grad_norms = grad_norms.to(torch.float32)
+
         return grad_norms
 
 
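When `use_64bit_grad_norm` is set, only the accumulation dtype changes; the result is cast back to float32 before being returned, so the rest of the clipping math is untouched. A self-contained sketch of the same idea, with made-up tensors and no process groups:

```python
import torch

# Many small gradient tensors; norms are accumulated either in the tensors'
# native float32 or in float64, mirroring the two branches in the diff.
grads = [torch.full((1000,), 1e-4) for _ in range(100)]

norm_type = 2.0
norms_f32 = torch.stack(torch._foreach_norm(grads, norm_type))
norms_f64 = torch.stack(torch._foreach_norm(grads, norm_type, dtype=torch.float64))

total_f32 = torch.linalg.vector_norm(norms_f32, norm_type)
total_f64 = torch.linalg.vector_norm(norms_f64, norm_type).to(torch.float32)

# Both are the global L2 norm; the float64 path guards against accumulation
# error when there are many (or very large) gradient tensors.
print(total_f32, total_f64)
```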