File tree Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Expand file tree Collapse file tree 1 file changed +10
-1
lines changed Original file line number Diff line number Diff line change 35
35
from torchrec .modules .utils import deterministic_dedup
36
36
37
37
38
+ @torch .fx .wrap
39
+ def index_select_view (
40
+ output : torch .Tensor ,
41
+ unbucketize_permute : torch .Tensor ,
42
+ dim : int ,
43
+ ) -> torch .Tensor :
44
+ return output [unbucketize_permute ].view (- 1 , dim )
45
+
46
+
38
47
class TensorPoolAwaitable (LazyAwaitable [torch .Tensor ]):
39
48
def __init__ (
40
49
self ,
@@ -441,7 +450,7 @@ def forward(self, ids: torch.Tensor) -> torch.Tensor:
441
450
442
451
output = self ._lookup_values_dist (lookup_list )
443
452
444
- return output [ unbucketize_permute ]. view ( - 1 , self ._dim )
453
+ return index_select_view ( output , unbucketize_permute , self ._dim )
445
454
446
455
# pyre-ignore
447
456
def _update_values_dist (self , ctx : ObjectPoolShardingContext , values : torch .Tensor ):
You can’t perform that action at this time.
0 commit comments