-
Notifications
You must be signed in to change notification settings - Fork 525
Delta tracker DMP integration #3064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
This pull request was exported from Phabricator. Differential Revision: D76202371 |
818dd82
to
2f65757
Compare
Summary: ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
Summary: ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
2f65757
to
f726e9e
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
1 similar comment
This pull request was exported from Phabricator. Differential Revision: D76202371 |
f726e9e
to
1785a5e
Compare
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
This pull request was exported from Phabricator. Differential Revision: D76202371 |
1785a5e
to
9505130
Compare
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
9505130
to
c04f2da
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
c04f2da
to
61b208e
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
61b208e
to
678b3cf
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
678b3cf
to
57dc963
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
57dc963
to
eb3a4ab
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
da5577e
to
fd63a1f
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
fd63a1f
to
c41f3ee
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
c41f3ee
to
ce3ff0e
Compare
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_hook in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_hook for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom hook provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking hooks / Callable * Added hook registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
ce3ff0e
to
27f7722
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
27f7722
to
3748948
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
3748948
to
190845f
Compare
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable * Added callable registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable * Added callable registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable * Added callable registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
190845f
to
4905da2
Compare
This pull request was exported from Phabricator. Differential Revision: D76202371 |
4905da2
to
7fbc766
Compare
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker * Configurable options include tracking_mode, delete_on_read, auto_compact, and fqns_to_skip **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable * Added callable registration methods in embedding modules * Implemented tracking support for different optimizer states (momentum, Adam states) ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
This pull request was exported from Phabricator. Differential Revision: D76202371 |
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable * Added callable registration methods in embedding modules ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
7fbc766
to
774f11e
Compare
Summary: Pull Request resolved: pytorch#3064 ## This Diff Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios. ### Key Components: **ModelTrackerConfig Integration**: * Added ModelTrackerConfig parameter to DMP constructor * When provided, automatically initializes ModelDeltaTracker **Custom Callables for Tracking**: * Added custom post_lookup_tracker_fn in ShardedModule to capture IDs and embeddings after lookup operations. This provides tracking ids/states natively into torchrec without registering any nn.Module specific hooks. * Added post_odist_tracker_fn for auto-compaction of tracked data. This custom callable provides native support for overlapping compaction with odist. * Implemented pre_forward callables in DMP for operations like batch index incrementation **Model Parallel API Enhancements**: * Added `get_model_tracker()` method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module. * Added `get_delta()` method as a convenience API to retrieve delta rows from dmp_module. **Embedding Module Changes**: * Enhanced ShardedEmbeddingBag and ShardedEmbedding to support tracking callable * Added callable registration methods in embedding modules ## ModelDeltaTracker Context ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for: 1. Identifying which embedding rows were accessed during model execution 2. Retrieving the latest delta or unique rows for a model 3. Computing top-k changed embeddings 4. Supporting streaming updated embeddings between systems during online training For more details see diff:D75853147 or PR pytorch#3057 Differential Revision: D76202371
This pull request was exported from Phabricator. Differential Revision: D76202371 |
774f11e
to
f4f132f
Compare
Summary:
This Diff
Adds ModelDeltaTracker integration with DMP (DistributedModelParallel) and sharded modules. This integration enables tracking of embedding IDs, embeddings, and optimizer states during model execution, which is particularly useful for online training scenarios.
Key Components:
ModelTrackerConfig Integration:
Custom Callables for Tracking:
Model Parallel API Enhancements:
get_model_tracker()
method to DistributedModelParallel for direct access to the ModelDeltaTracker instance. This API give the flexibility to integrate model tracker into required components directly without needing to access the dmp_module.get_delta()
method as a convenience API to retrieve delta rows from dmp_module.Embedding Module Changes:
ModelDeltaTracker Context
ModelDeltaTracker is a utility for tracking and retrieving unique IDs and their corresponding embeddings or states from embedding modules in model using Torchrec. It's particularly useful for:
For more details see diff:D75853147 or PR #3057
Differential Revision: D76202371