Commit 54ec8aa

ckluk2 authored and facebook-github-bot committed
Log the grad norms in clipping.py (#2489)
Summary:
Pull Request resolved: #2489

Log the grad norms for debugging purposes.

Reviewed By: iamzainhuda

Differential Revision: D64275984

fbshipit-source-id: 83326a7fe6f1684b494c50d0599248ecb7c93e9b
1 parent bce8ae3 commit 54ec8aa
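
The two new switches are plain module-level globals, so a training script can flip them before the optimizer starts stepping. A minimal usage sketch, assuming the module path torchrec.optim.clipping and that logging is configured to surface INFO records (illustrative, not part of this commit):

# Hypothetical usage sketch: enable the new debug flags before training starts.
import logging

import torchrec.optim.clipping as clipping

logging.basicConfig(level=logging.INFO)  # so logger.info(...) in clipping.py is visible

clipping.log_grad_norm = True        # log per-rank squared norms and the total norm each step
clipping.use_64bit_grad_norm = True  # accumulate the norm in float64, then cast back to float32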

File tree

1 file changed: +38 −4 lines changed


torchrec/optim/clipping.py

Lines changed: 38 additions & 4 deletions
@@ -19,6 +19,9 @@

 logger: logging.Logger = logging.getLogger()

+log_grad_norm: bool = False
+use_64bit_grad_norm: bool = False
+

 @unique
 class GradientClipping(Enum):
@@ -59,6 +62,7 @@ def __init__(
         self._norm_type = norm_type
         self._check_meta: bool = True
         self._enable_global_grad_clip = enable_global_grad_clip
+        self._step_num = 0

         # Group parameters by model parallelism process group if global clipping is enabled.
         # Otherwise, all parameters are treated as replicated and will be clipped locally.
@@ -129,6 +133,7 @@ def step(self, closure: Any = None) -> None:
             torch.nn.utils.clip_grad_value_(self._replicate_params, self._max_gradient)

         super().step(closure)
+        self._step_num += 1

     @torch.no_grad()
     def clip_grad_norm_(self) -> None:
@@ -165,6 +170,8 @@ def clip_grad_norm_(self) -> None:
                )
            )

+        square_sharded_grad_norm = total_grad_norm if total_grad_norm is not None else 0
+
        # Process replicated parameters and gradients
        if self._replicate_params:
            replicated_grads = [
@@ -189,6 +196,22 @@ def clip_grad_norm_(self) -> None:
                    else total_grad_norm + replicated_grad_norm
                )
            )
+            square_replicated_grad_norm = replicated_grad_norm
+        else:
+            square_replicated_grad_norm = 0
+
+        global log_grad_norm
+        if log_grad_norm:
+            if total_grad_norm is not None and self._norm_type != torch.inf:
+                # pyre-ignore[58]
+                grad_norm = total_grad_norm ** (1.0 / norm_type)
+            else:
+                grad_norm = 0
+
+            rank = dist.get_rank()
+            logger.info(
+                f"Clipping [rank={rank}, step={self._step_num}]: square_sharded_grad_norm = {square_sharded_grad_norm}, square_replicated_grad_norm = {square_replicated_grad_norm}, total_grad_norm = {grad_norm}"
+            )

        # Aggregation
        if total_grad_norm is None:
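
As the hunk above reads, for a finite norm_type the accumulated total_grad_norm holds the sum of the p-th powers of the sharded and replicated group norms, and the log line reports those pieces alongside the recovered norm total_grad_norm ** (1.0 / norm_type). A standalone sketch of that arithmetic for norm_type = 2 (illustrative only, not the torchrec code path; the tensors are stand-ins):

# Standalone sketch: how the logged quantities relate for norm_type = 2.
import torch

norm_type = 2.0
sharded_grads = [torch.randn(10), torch.randn(20)]   # stand-ins for sharded parameter grads
replicated_grads = [torch.randn(5)]                  # stand-ins for replicated parameter grads

# Each group contributes the p-th power of its norm; these are the "square_*" values in the log.
square_sharded = torch.linalg.vector_norm(torch.cat(sharded_grads), norm_type) ** norm_type
square_replicated = torch.linalg.vector_norm(torch.cat(replicated_grads), norm_type) ** norm_type

# The final gradient norm is the p-th root of the summed pieces, matching
# grad_norm = total_grad_norm ** (1.0 / norm_type) in the diff above.
total = (square_sharded + square_replicated) ** (1.0 / norm_type)
print(square_sharded.item(), square_replicated.item(), total.item())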
@@ -212,10 +235,18 @@ def _batch_cal_norm(
     """Helper function that calculates the norm of a list of gradients in batches. If process_groups
     are passed in, the norm will be aggregated across all ranks in the process group.
     """
-    grad_norms = torch.linalg.vector_norm(
-        torch.stack(torch._foreach_norm(grad_list, norm_type)),
-        norm_type,
-    )
+
+    global use_64bit_grad_norm
+    if use_64bit_grad_norm:
+        grad_norms = torch.linalg.vector_norm(
+            torch.stack(torch._foreach_norm(grad_list, norm_type, dtype=torch.float64)),
+            norm_type,
+        )
+    else:
+        grad_norms = torch.linalg.vector_norm(
+            torch.stack(torch._foreach_norm(grad_list, norm_type)),
+            norm_type,
+        )

     if norm_type == torch.inf:
         if process_groups is not None:
@@ -227,6 +258,9 @@ def _batch_cal_norm(
            for pg in process_groups:
                dist.all_reduce(grad_norms, group=pg)

+    if use_64bit_grad_norm:
+        grad_norms = grad_norms.to(torch.float32)
+
     return grad_norms
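
The use_64bit_grad_norm branch computes the per-tensor norms and the final reduction in float64, then casts the result back to float32 before returning. A minimal standalone sketch of the same idea using only public PyTorch APIs (illustrative; not the torchrec code path):

# Standalone sketch: carrying the norm reduction in float64, then casting back.
import torch

torch.manual_seed(0)
grads = [torch.randn(1_000_000) * 1e-4 for _ in range(8)]  # many small-magnitude gradients

# float32 path: per-tensor norms stacked, then reduced, entirely in float32.
norm_fp32 = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g, 2) for g in grads]), 2
)

# float64 path: the same reduction in double precision, cast back at the end,
# mirroring the use_64bit_grad_norm branch above.
norm_fp64 = torch.linalg.vector_norm(
    torch.stack([torch.linalg.vector_norm(g, 2, dtype=torch.float64) for g in grads]), 2
).to(torch.float32)

print(norm_fp32.item(), norm_fp64.item())  # the fp64 path limits rounding drift in the accumulation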
0 commit comments
