     Deque,
     Dict,
     Generic,
+    Iterable,
     Iterator,
     List,
     Optional,
@@ -911,17 +912,40 @@ def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
             if cast(int, context.index) % 2 == 0
             else self._embedding_odd_streams
         )
-        for stream, emb_tensors, detached_emb_tensors in zip(
+        assert len(context.embedding_features) == len(context.embedding_tensors)
+        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
             streams,
             context.embedding_tensors,
+            context.embedding_features,
             context.detached_embedding_tensors,
         ):
             with self._stream_context(stream):
                 grads = [tensor.grad for tensor in detached_emb_tensors]
                 if stream:
                     stream.wait_stream(default_stream)
-                # pyre-ignore
-                torch.autograd.backward(emb_tensors, grads)
+                # Some embeddings may never get used in the final loss computation,
+                # so the grads will be `None`. If we don't exclude these, it will fail
+                # with error: "grad can be implicitly created only for scalar outputs"
+                # Alternatively, if the tensor has only 1 element, pytorch can still
+                # figure out how to do autograd
+                embs_to_backprop, grads_to_use, invalid_features = [], [], []
+                assert len(embedding_features) == len(emb_tensors)
+                for features, tensor, grad in zip(
+                    embedding_features, emb_tensors, grads
+                ):
+                    if tensor.numel() == 1 or grad is not None:
+                        embs_to_backprop.append(tensor)
+                        grads_to_use.append(grad)
+                    else:
+                        if isinstance(features, Iterable):
+                            invalid_features.extend(features)
+                        else:
+                            invalid_features.append(features)
+                if invalid_features and context.index == 0:
+                    logger.warning(
+                        f"SemiSync, the following features have no gradients: {invalid_features}"
+                    )
+                torch.autograd.backward(embs_to_backprop, grads_to_use)

     def copy_batch_to_gpu(
         self,
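Note on the change above: torch.autograd.backward(outputs, grads) raises "grad can be implicitly created only for scalar outputs" whenever a None gradient is paired with a non-scalar output, which is what happens for embeddings that never reach the loss; single-element tensors are the exception, since PyTorch falls back to an implicit gradient of 1.0. Below is a minimal standalone sketch of that behavior and of the filtering approach; the tensor names are made up and are not part of this PR or of torchrec.

import torch

# Hypothetical stand-ins for the pipeline's embedding outputs.
used = torch.randn(4, 8, requires_grad=True)    # participates in the loss
unused = torch.randn(4, 8, requires_grad=True)  # never reaches the loss
scalar = torch.randn((), requires_grad=True)    # single-element tensor

emb_tensors = [used * 2, unused * 2, scalar * 2]
# Upstream grads as they would look after reading .grad from detached tensors:
# the unused embedding (and the scalar) have no gradient.
grads = [torch.ones(4, 8), None, None]

# Calling backward directly fails on the non-scalar tensor paired with a None grad:
#   torch.autograd.backward(emb_tensors, grads)
#   -> RuntimeError: grad can be implicitly created only for scalar outputs

# Keep tensors that either have a grad or are single-element (implicit grad of 1.0),
# mirroring the filtering done in the diff above.
kept = [
    (t, g) for t, g in zip(emb_tensors, grads) if t.numel() == 1 or g is not None
]
torch.autograd.backward([t for t, _ in kept], [g for _, g in kept])

print(used.grad.sum(), scalar.grad)  # gradients flow to `used` and `scalar` only
print(unused.grad)                   # None: excluded from backward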