Skip to content

Commit 8b06801

Browse files
tissue3facebook-github-bot
authored andcommitted
put outputdist to remote ro
Differential Revision: D76324557
1 parent 4e43395 commit 8b06801

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

torchrec/distributed/tensor_pool.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@
3535
from torchrec.modules.utils import deterministic_dedup
3636

3737

38+
@torch.fx.wrap
39+
def index_select_view(
40+
output: torch.Tensor,
41+
unbucketize_permute: torch.Tensor,
42+
dim: int,
43+
) -> torch.Tensor:
44+
return output[unbucketize_permute].view(-1, dim)
45+
46+
3847
class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
3948
def __init__(
4049
self,
@@ -441,7 +450,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
441450

442451
output = self._lookup_values_dist(lookup_list)
443452

444-
return output[unbucketize_permute].view(-1, self._dim)
453+
return index_select_view(output, unbucketize_permute, self._dim)
445454

446455
# pyre-ignore
447456
def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):

0 commit comments

Comments
 (0)