File tree Expand file tree Collapse file tree 1 file changed +7
-10
lines changed Expand file tree Collapse file tree 1 file changed +7
-10
lines changed Original file line number Diff line number Diff line change @@ -1172,16 +1172,13 @@ def _create_output_dist(self) -> None:
1172
1172
for i , name in enumerate (self ._uncombined_embedding_names ):
1173
1173
embedding_name_order .setdefault (name , i )
1174
1174
1175
- def sort_key (input : Tuple [int , str ]) -> Tuple [int , int ]:
1176
- index , name = input
1177
- return (embedding_name_order [name ], embedding_shard_offsets [index ])
1178
-
1179
- permute_indices = [
1180
- i
1181
- for i , _ in sorted (
1182
- enumerate (self ._uncombined_embedding_names ), key = sort_key
1183
- )
1184
- ]
1175
+ permute_indices = sorted (
1176
+ range (len (self ._uncombined_embedding_names )),
1177
+ key = lambda i : (
1178
+ embedding_name_order [self ._uncombined_embedding_names [i ]],
1179
+ embedding_shard_offsets [i ],
1180
+ ),
1181
+ )
1185
1182
self ._permute_op : PermutePooledEmbeddings = PermutePooledEmbeddings (
1186
1183
self ._uncombined_embedding_dims , permute_indices , self ._device
1187
1184
)
You can’t perform that action at this time.
0 commit comments