Commit d2a3e56
DistributedModelParallel resharding Interface (#2945)
Summary:
Pull Request resolved: #2945
Finally! DMP interface for resharding, most of the changes here are to enable proper testing of DMP.
## Main changes:
### 1. DMP reshard API:
* which calls the underlying sharder for sharded module to reshard
### 2. Proper Testing:
* A multi-rank test which generates a full Model and utilizes DMP interface. Currently only tests TW.
* This test is called from `test_dynamic_sharding.py` -> `test_model_parallel.py` -> `test_sharding.py`, which follows the same structure as current DMP unit tests
* This is how the test tests for correctness:
```
1. Generate global model and inputs
2. Create 2 identical local models based on global model
3. Use planner to generate sharding plan for local model
4. Based on planner output, generate a second, different sharding plan
5. Shard both local models 1 and 2 through DMP with plan 1 and 2 respectively
6. Reshard (dynamic sharding API) model 1 with plan 2
7. Generate predictions for local models and compare them to global model prediction. Expect to be the same.
```
* This tests for `optimzier` being correctly saved in resharding
* The test is setup with other variables to-be-set once more functionalities are enabled with dynamic sharding, e.g. `variable_batch_size` etc.
### 3. Helper functions for testing
* `get_sharding_constructor_from_type` to enable setting sharding_type for each unit test.
* `compare_model_pred_one_step` only used for debugging to get more information on whether models are identical after resharding/running initial step
* `compare_model_weights` also for debugging
### 4. Bug fixes in `update_shards` call.
* namely input dist was not properly updated - this will cause error when I am testing the reshard function in the *middle of training*. As input dist depends on the shard placements.
Reviewed By: aliafzal
Differential Revision: D73049934
fbshipit-source-id: b0fad46b4fd846c204ebaba75a6f3984e74b26231 parent 714d996 commit d2a3e56
File tree
7 files changed
+875
-47
lines changed- torchrec/distributed
- sharding
- test_utils
- tests
7 files changed
+875
-47
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1531 | 1531 | | |
1532 | 1532 | | |
1533 | 1533 | | |
1534 | | - | |
1535 | 1534 | | |
1536 | | - | |
1537 | | - | |
1538 | | - | |
1539 | | - | |
1540 | | - | |
1541 | | - | |
1542 | | - | |
| 1535 | + | |
| 1536 | + | |
1543 | 1537 | | |
1544 | 1538 | | |
1545 | 1539 | | |
| |||
1603 | 1597 | | |
1604 | 1598 | | |
1605 | 1599 | | |
| 1600 | + | |
| 1601 | + | |
| 1602 | + | |
| 1603 | + | |
| 1604 | + | |
| 1605 | + | |
1606 | 1606 | | |
1607 | 1607 | | |
1608 | 1608 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
35 | 35 | | |
36 | 36 | | |
37 | 37 | | |
| 38 | + | |
38 | 39 | | |
39 | 40 | | |
40 | 41 | | |
| |||
612 | 613 | | |
613 | 614 | | |
614 | 615 | | |
| 616 | + | |
| 617 | + | |
| 618 | + | |
| 619 | + | |
| 620 | + | |
| 621 | + | |
| 622 | + | |
| 623 | + | |
| 624 | + | |
| 625 | + | |
| 626 | + | |
| 627 | + | |
| 628 | + | |
| 629 | + | |
| 630 | + | |
| 631 | + | |
| 632 | + | |
| 633 | + | |
| 634 | + | |
| 635 | + | |
| 636 | + | |
| 637 | + | |
| 638 | + | |
| 639 | + | |
| 640 | + | |
| 641 | + | |
| 642 | + | |
| 643 | + | |
| 644 | + | |
| 645 | + | |
| 646 | + | |
| 647 | + | |
| 648 | + | |
| 649 | + | |
| 650 | + | |
| 651 | + | |
| 652 | + | |
| 653 | + | |
| 654 | + | |
| 655 | + | |
| 656 | + | |
| 657 | + | |
| 658 | + | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
615 | 694 | | |
616 | 695 | | |
617 | 696 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
7 | 7 | | |
8 | 8 | | |
9 | 9 | | |
| 10 | + | |
10 | 11 | | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
| 18 | + | |
17 | 19 | | |
18 | 20 | | |
19 | 21 | | |
| |||
364 | 366 | | |
365 | 367 | | |
366 | 368 | | |
| 369 | + | |
| 370 | + | |
| 371 | + | |
| 372 | + | |
| 373 | + | |
| 374 | + | |
| 375 | + | |
| 376 | + | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
410 | 410 | | |
411 | 411 | | |
412 | 412 | | |
| 413 | + | |
| 414 | + | |
| 415 | + | |
| 416 | + | |
| 417 | + | |
| 418 | + | |
| 419 | + | |
| 420 | + | |
| 421 | + | |
| 422 | + | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
413 | 427 | | |
414 | 428 | | |
415 | 429 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
| 25 | + | |
25 | 26 | | |
26 | 27 | | |
27 | 28 | | |
| |||
190 | 191 | | |
191 | 192 | | |
192 | 193 | | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
| 236 | + | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
193 | 266 | | |
194 | 267 | | |
195 | 268 | | |
| |||
0 commit comments