 from torchrec.distributed.embedding_kernel import (
     BaseEmbedding,
     create_virtual_sharded_tensors,
+    create_virtual_table_local_metadata,
     get_state_dict,
 )
 from torchrec.distributed.embedding_types import (
@@ -206,7 +207,9 @@ def _populate_zero_collision_tbe_params(
     bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]

     tbe_params["kv_zch_params"] = KVZCHParams(
-        bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes
+        bucket_offsets=bucket_offsets,
+        bucket_sizes=bucket_sizes,
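+        # NOTE: optimizer state offloading to the KV backend is explicitly disabled for now.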
+        enable_optimizer_offloading=False,
     )


@@ -283,6 +286,53 @@ def __init__( # noqa C901
             table_name_to_weight_count_per_rank
         )

+        # pyre-ignore [33]
+        state: Dict[Any, Any] = {}
+        param_group: Dict[str, Any] = {
+            "params": [],
+            "lr": emb_module.get_learning_rate(),
+        }
+
+        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
+
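+        # Collect the locally materialized weight-id tensors (if any) to pass as
+        # sorted_id_tensor when fetching the per-table optimizer state below.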
+        sorted_id_tensors = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+
+        all_optimizer_states = emb_module.get_optimizer_state(
+            sorted_id_tensor=sorted_id_tensors
+        )
+        opt_param_list = [param["momentum1"] for param in all_optimizer_states]
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
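+        # Mark shard placement as CPU on the copied table configs used to build the
+        # virtual sharded tensors for the optimizer state.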
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        opt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, opt_param_list, self._pg
+        )
+
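+        # Register one parameter per table and attach its "momentum1" state under
+        # the key "<table_name>.momentum1".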
+        for (
+            emb_config,
+            sharded_weight,
+            opt_sharded_t,
+        ) in zip(
+            emb_table_config_copy,
+            sharded_embedding_weights_by_table,
+            opt_sharded_t_list,
+        ):
+            param_key = emb_config.name + ".weight"
+            state[sharded_weight] = {}
+            param_group["params"].append(sharded_weight)
+            params[param_key] = sharded_weight
+
+            state[sharded_weight][f"{emb_config.name}.momentum1"] = opt_sharded_t
+
+        super().__init__(params, state, [param_group])
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
@@ -292,6 +342,61 @@ def step(self, closure: Any = None) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])

+    def set_sharded_embedding_weight_ids(
+        self, sharded_embedding_weight_ids: Optional[List[ShardedTensor]]
+    ) -> None:
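+        # Stash the weight-id ShardedTensors used when fetching optimizer state from the emb module.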
+        self._sharded_embedding_weight_ids = sharded_embedding_weight_ids
+
+    def _post_state_dict_hook(self, curr_state: Dict[str, Any]) -> None:
+        logger.info("update optimizer state dict in state_dict_post_hook")
+        embedding_weight_ids = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
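+        # Fetch the per-table optimizer state for the locally held weight ids.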
+        all_optimizer_states = self._emb_module.get_optimizer_state(
+            embedding_weight_ids,
+            no_snapshot=False,
+            should_flush=False,  # embedding weights were already flushed, no need to flush again here
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        # The order of table_config is deterministic, so use it as the outer loop for a consistent traversal order across ranks.
+        for table_config, opt_states in zip(
+            emb_table_config_copy,
+            all_optimizer_states,
+        ):
+            for key, sharded_t_dict in curr_state.items():
+                # update the zero-collision table's optimizer state
+                if f".{table_config.name}.weight" in key:
+                    for (_, opt_state_t), (sharded_t_k, sharded_t) in zip(
+                        opt_states.items(), sharded_t_dict.items()
+                    ):
+                        logger.info(
+                            f"update optimizer state for table {table_config.name} with state shape {opt_state_t.shape}, rank={self._my_rank}, weight_count_per_rank={self._table_name_to_weight_count_per_rank.get(table_config.name, None)}"
+                        )
+                        sharded_t.local_shards()[0].tensor = opt_state_t
+                        create_virtual_table_local_metadata(
+                            # pyre-ignore [6]
+                            table_config.local_metadata,
+                            opt_state_t,
+                            self._my_rank,
+                        )
+                        for shard in sharded_t.local_shards():
+                            shard.metadata = table_config.local_metadata
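+                        # Rebuild the ShardedTensor from the refreshed local shards and metadata.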
+                        new_sharded_t = ShardedTensor._init_from_local_shards(
+                            sharded_t.local_shards(),
+                            None,
+                            None,
+                            process_group=self._pg,
+                        )
+                        sharded_t_dict[sharded_t_k] = new_sharded_t
+

 class EmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(  # noqa C901
@@ -756,7 +861,6 @@ def _gen_named_parameters_by_table_fused(
         table_count = table_name_to_count.pop(table_name)
         if emb_module.weights_precision == SparseType.INT8:
             dim += emb_module.int8_emb_row_dim_offset
-        # pyre-fixme[29]: `Union[(self: TensorBase, indices: Union[None, _NestedSeque...
         offset = emb_module.weights_physical_offsets[t_idx]
         weights: torch.Tensor
         if location == EmbeddingLocation.DEVICE.value:
@@ -1330,7 +1434,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             return

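+        # Request a flush so the snapshot below reflects the latest cached embedding updates.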
         pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False
+            no_snapshot=False, should_flush=True
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -1381,12 +1485,16 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_state.fill_(-1)

     # pyre-ignore [15]
-    def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
+    def split_embedding_weights(
+        self, no_snapshot: bool = True, should_flush: bool = False
+    ) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
-        return self.emb_module.split_embedding_weights(no_snapshot)
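+        # should_flush is forwarded to the backing emb_module so callers can request
+        # a flush before reading the split weights.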
+        return self.emb_module.split_embedding_weights(
+            no_snapshot, should_flush=should_flush
+        )

     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         # reset split weights during training