Commit dbca437

sarckk authored and facebook-github-bot committed
Ignore non-existent grad tensors for semi-sync (#2490)
Summary:
Pull Request resolved: #2490

During semi-sync, we need to ignore embedding tensor grads that are `None`; otherwise `torch.autograd.backward` fails with the error `grad can be implicitly created only for scalar outputs`. This is a valid scenario if, for example, the embeddings are looked up but never used in the final loss computation.

Reviewed By: che-sh

Differential Revision: D63379382

fbshipit-source-id: 62d5b6153d6aab339ef774f3a319630b7a2cfe98
1 parent 1d8824b commit dbca437
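
For readers less familiar with the failure mode, here is a minimal, self-contained sketch (plain PyTorch, not the pipeline code; tensor names are illustrative) of why `torch.autograd.backward` rejects a `None` grad for a non-scalar tensor, and how filtering out such pairs, as this commit does, avoids the error.

```python
import torch

# Two "embedding" outputs produced by a forward pass; only one feeds the loss.
weights = torch.randn(3, 4, requires_grad=True)
emb_used = weights * 1.0      # participates in the loss
emb_unused = weights * 2.0    # looked up but never used in the loss

# Detached views whose .grad fields the loss backward may (or may not) populate.
det_used = emb_used.detach().requires_grad_()
det_unused = emb_unused.detach().requires_grad_()

loss = det_used.sum()
loss.backward()

grads = [det_used.grad, det_unused.grad]   # second entry is None

# This is the failure the commit guards against:
# torch.autograd.backward([emb_used, emb_unused], grads)
#   -> RuntimeError: grad can be implicitly created only for scalar outputs

# The fix: drop (tensor, grad) pairs whose grad is None, unless the tensor
# has a single element (in which case autograd can create the grad itself).
kept = [
    (t, g)
    for t, g in zip([emb_used, emb_unused], grads)
    if t.numel() == 1 or g is not None
]
if kept:
    tensors, kept_grads = zip(*kept)
    torch.autograd.backward(list(tensors), list(kept_grads))
```

The actual pipeline does the equivalent filtering per stream and also records which feature names produced the dropped tensors so it can log them once (only for the first batch).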

File tree

2 files changed: +37 −3 lines changed

- torchrec/distributed/train_pipeline/train_pipelines.py
- torchrec/distributed/train_pipeline/utils.py

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 27 additions & 3 deletions
```diff
@@ -21,6 +21,7 @@
     Deque,
     Dict,
     Generic,
+    Iterable,
     Iterator,
     List,
     Optional,
@@ -911,17 +912,40 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
             if cast(int, context.index) % 2 == 0
             else self._embedding_odd_streams
         )
-        for stream, emb_tensors, detached_emb_tensors in zip(
+        assert len(context.embedding_features) == len(context.embedding_tensors)
+        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
             streams,
             context.embedding_tensors,
+            context.embedding_features,
             context.detached_embedding_tensors,
         ):
             with self._stream_context(stream):
                 grads = [tensor.grad for tensor in detached_emb_tensors]
                 if stream:
                     stream.wait_stream(default_stream)
-                # pyre-ignore
-                torch.autograd.backward(emb_tensors, grads)
+                # Some embeddings may never get used in the final loss computation,
+                # so the grads will be `None`. If we don't exclude these, it will fail
+                # with error: "grad can be implicitly created only for scalar outputs"
+                # Alternatively, if the tensor has only 1 element, pytorch can still
+                # figure out how to do autograd
+                embs_to_backprop, grads_to_use, invalid_features = [], [], []
+                assert len(embedding_features) == len(emb_tensors)
+                for features, tensor, grad in zip(
+                    embedding_features, emb_tensors, grads
+                ):
+                    if tensor.numel() == 1 or grad is not None:
+                        embs_to_backprop.append(tensor)
+                        grads_to_use.append(grad)
+                    else:
+                        if isinstance(features, Iterable):
+                            invalid_features.extend(features)
+                        else:
+                            invalid_features.append(features)
+                if invalid_features and context.index == 0:
+                    logger.warning(
+                        f"SemiSync, the following features have no gradients: {invalid_features}"
+                    )
+                torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
```
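
The `tensor.numel() == 1` carve-out in the comment above reflects autograd's behavior: for a single-element output, a missing grad can be implicitly created (a ones tensor), so such tensors are safe to keep even when their grad is `None`. A small standalone sketch (plain PyTorch, names illustrative):

```python
import torch

w = torch.randn(5, requires_grad=True)
single = w.sum().reshape(1)          # numel() == 1; its grad may be left as None

# Passing None for a single-element tensor is fine: autograd fills in ones_like.
torch.autograd.backward([single], [None])
print(w.grad)                        # tensor of ones, shape (5,)

multi = w * 2.0                      # numel() > 1
# torch.autograd.backward([multi], [None]) would raise:
#   RuntimeError: grad can be implicitly created only for scalar outputs
```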

torchrec/distributed/train_pipeline/utils.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -128,6 +128,7 @@ class PrefetchTrainPipelineContext(TrainPipelineContext):
 class EmbeddingTrainPipelineContext(TrainPipelineContext):
     embedding_a2a_requests: Dict[str, Multistreamable] = field(default_factory=dict)
     embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
+    embedding_features: List[List[Union[str, List[str]]]] = field(default_factory=list)
     detached_embedding_tensors: List[List[torch.Tensor]] = field(default_factory=list)
 
 
@@ -408,6 +409,8 @@ def __call__(self, *input, **kwargs) -> Awaitable:
             # pyre-ignore [16]
             self._context.embedding_tensors.append(tensors)
             # pyre-ignore [16]
+            self._context.embedding_features.append(list(embeddings.keys()))
+            # pyre-ignore [16]
             self._context.detached_embedding_tensors.append(detached_tensors)
         else:
             assert isinstance(embeddings, KeyedTensor)
@@ -418,6 +421,13 @@ def __call__(self, *input, **kwargs) -> Awaitable:
             tensors.append(tensor)
             detached_tensors.append(detached_tensor)
             self._context.embedding_tensors.append(tensors)
+            # KeyedTensor is returned by EmbeddingBagCollections and its variants
+            # KeyedTensor holds dense data from multiple features and .values()
+            # returns a single concatenated dense tensor. To ensure that
+            # context.embedding_tensors[i] has the same length as
+            # context.embedding_features[i], we pass in a list with a single item:
+            # a list containing all the embedding feature names.
+            self._context.embedding_features.append([list(embeddings.keys())])
             self._context.detached_embedding_tensors.append(detached_tensors)
 
         return LazyNoWait(embeddings)
```
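
For context on the comment in the last hunk, here is a small sketch of the KeyedTensor convention it relies on, assuming the `torchrec.sparse.jagged_tensor.KeyedTensor` constructor (keys, length_per_key, values); the feature names and sizes are made up.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

# Two pooled features with embedding dims 4 and 8, batch size 2,
# concatenated along dim 1 into one dense tensor.
values = torch.randn(2, 12)
kt = KeyedTensor(keys=["f1", "f2"], length_per_key=[4, 8], values=values)

print(kt.keys())          # ['f1', 'f2']
print(kt.values().shape)  # torch.Size([2, 12]) -- a single concatenated tensor

# Because .values() is one tensor covering all features, embedding_tensors[i]
# gets a single-tensor list here, and embedding_features[i] mirrors it with a
# single entry: the full list of feature names, i.e. [['f1', 'f2']].
```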
