@@ -1409,31 +1409,31 @@ def prepare(self, *args, device_placement=None):
         args = self._prepare_ao(*args)

         # Compile needs to be done before gathering old params: investigate why?
-        if self.is_fsdp2 and model_index is not None:
-            new_args = list(args)
+        # if self.is_fsdp2 and model_index is not None:
+        #     new_args = list(args)

-            new_args[model_index] = compile_regions(new_args[model_index])
-            args = tuple(new_args)
+        #     new_args[model_index] = compile_regions(new_args[model_index])
+        #     args = tuple(new_args)

-        if should_fix_optimizer:
-            # 1. grabbing old model parameters
-            old_named_params = self._get_named_parameters(
-                *args, drop_refs=fsdp2_should_fix_optimizer
-            )  # Drop refs for FSDP2, to enable reallocation of parameters further in `fully_shard`
+        # if should_fix_optimizer:
+        #     # 1. grabbing old model parameters
+        #     old_named_params = self._get_named_parameters(
+        #         *args, drop_refs=fsdp2_should_fix_optimizer
+        #     )  # Drop refs for FSDP2, to enable reallocation of parameters further in `fully_shard`

         # `FSDP2` by default expects `Optimizer` to be created after the model is prepared,
         # however that goes against `Accelerate's` design of `bring your own`
         # this is a workaround to make memory footprint match if `Optimizer` is created before preparing the model
-        if fsdp2_should_fix_optimizer:
-            old_named_params = fsdp2_canonicalize_names(old_named_params)
-            for obj in args:
-                if isinstance(obj, torch.optim.Optimizer):
-                    for param_group in obj.param_groups:
-                        for i, p in enumerate(param_group["params"]):
-                            # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
-                            # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
-                            param_group["params"][i] = torch.empty_like(p)
-                            param_group["params"][i].data_ptr = p.data_ptr()
+        # if fsdp2_should_fix_optimizer:
+        #     old_named_params = fsdp2_canonicalize_names(old_named_params)
+        #     for obj in args:
+        #         if isinstance(obj, torch.optim.Optimizer):
+        #             for param_group in obj.param_groups:
+        #                 for i, p in enumerate(param_group["params"]):
+        #                     # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
+        #                     # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
+        #                     param_group["params"][i] = torch.empty_like(p)
+        #                     param_group["params"][i].data_ptr = p.data_ptr()

         if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
             if (self.device.type == "cpu" or self.device.type == "xpu") and self.state.use_ipex:
@@ -1446,27 +1446,29 @@ def prepare(self, *args, device_placement=None):
             result = self._prepare_deepspeed(*args)
         elif self.distributed_type == DistributedType.MEGATRON_LM:
             result = self._prepare_megatron_lm(*args)
+        elif self.is_fsdp2:
+            result = self._prepare_fsdp2(*args)
         else:
             if self.fp8_backend == "MSAMP":
                 args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
             )
             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
-        if should_fix_optimizer:
-            # 2. grabbing new model parameters
-            new_named_params = self._get_named_parameters(*result)
-            if fsdp2_should_fix_optimizer:
-                new_named_params = fsdp2_canonicalize_names(new_named_params)
-            # 3. building a map from the first to the second
-            mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
-            # 4. using that map to update the parameters of the optimizer
-            for obj in result:
-                if isinstance(obj, torch.optim.Optimizer):
-                    if not fsdp2_should_fix_optimizer:
-                        obj._switch_parameters(mapping)
-                    else:
-                        fsdp2_switch_optimizer_parameters(obj, mapping)
+        # if should_fix_optimizer:
+        #     # 2. grabbing new model parameters
+        #     new_named_params = self._get_named_parameters(*result)
+        #     if fsdp2_should_fix_optimizer:
+        #         new_named_params = fsdp2_canonicalize_names(new_named_params)
+        #     # 3. building a map from the first to the second
+        #     mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+        #     # 4. using that map to update the parameters of the optimizer
+        #     for obj in result:
+        #         if isinstance(obj, torch.optim.Optimizer):
+        #             if not fsdp2_should_fix_optimizer:
+        #                 obj._switch_parameters(mapping)
+        #             else:
+        #                 fsdp2_switch_optimizer_parameters(obj, mapping)

         for item in result:
             if any(
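
The numbered steps 2-4 commented out here reappear inside the new `_prepare_fsdp2` in the next hunk. For readers tracing the flow, here is a self-contained sketch of the whole old-pointer to new-parameter dance on a toy model; a plain parameter re-materialization stands in for `fully_shard`, and every name below is made up for illustration rather than taken from Accelerate:

```python
import torch

# Toy setup; illustrative only, not Accelerate's actual API.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1. Remember the old parameters by name, keeping only their data_ptr so the
#    original tensors can be freed once the model re-materializes them.
old_named_params = {name: p.data_ptr() for name, p in model.named_parameters()}

# 2. Swap every optimizer entry for an empty placeholder that remembers the old pointer.
for group in optimizer.param_groups:
    for i, p in enumerate(group["params"]):
        placeholder = torch.empty_like(p)
        placeholder.data_ptr = p.data_ptr()  # shadow the method with the old pointer value
        group["params"][i] = placeholder

# --- stand-in for `fully_shard`: re-materialize every parameter with fresh storage ---
for name, p in list(model.named_parameters()):
    module_path, _, attr = name.rpartition(".")
    owner = model.get_submodule(module_path) if module_path else model
    setattr(owner, attr, torch.nn.Parameter(p.detach().clone()))

# 3. Build a map from old pointer -> freshly created parameter.
new_named_params = dict(model.named_parameters())
mapping = {ptr: new_named_params[name] for name, ptr in old_named_params.items()}

# 4. Point the optimizer at the new parameters via the pointers stashed on the placeholders.
for group in optimizer.param_groups:
    group["params"] = [mapping[placeholder.data_ptr] for placeholder in group["params"]]

assert optimizer.param_groups[0]["params"][0] is model.weight  # optimizer now tracks the new params
```

The point of the exercise is that the user's optimizer object is never rebuilt; only its `params` entries are swapped, which is what preserves the "bring your own" design mentioned in the comments above.
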
@@ -1477,6 +1479,79 @@ def prepare(self, *args, device_placement=None):

         return result if len(result) > 1 else result[0]

+    def _prepare_fsdp2(self, *args):
+        _custom_prepare_classes = (
+            torch.nn.Module,
+            torch.optim.Optimizer,
+        )
+        device_placement = [None for _ in args]
+
+        result = [
+            self._prepare_one(obj, first_pass=True, device_placement=d)
+            if not isinstance(obj, _custom_prepare_classes)
+            else obj
+            for obj, d in zip(args, device_placement)
+        ]
+
+        result = tuple(
+            self._prepare_one(obj, device_placement=d) if not isinstance(obj, _custom_prepare_classes) else obj
+            for obj, d in zip(result, device_placement)
+        )
+
+        models = []
+        optimizers = []
+
+        for i, obj in enumerate(result):
+            if isinstance(obj, torch.nn.Module):
+                models.append((i, obj))
+            elif isinstance(obj, torch.optim.Optimizer):
+                optimizers.append((i, obj))
+
+        if len(optimizers) <= 0 and len(models) <= 0:
+            return result
+
+        model_index, model = models[0]
+        optimizer_index, optimizer = optimizers[0]
+
+        new_result = list(result)
+
+        new_result[model_index] = compile_regions(result[model_index])
+        result = new_result
+        # result = tuple(new_result)
+
+        old_named_params = self._get_named_parameters(*tuple(result), drop_refs=True)
+
+        old_named_params = fsdp2_canonicalize_names(old_named_params)
+        for obj in result:
+            if isinstance(obj, torch.optim.Optimizer):
+                for param_group in obj.param_groups:
+                    for i, p in enumerate(param_group["params"]):
+                        # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
+                        # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
+                        param_group["params"][i] = torch.empty_like(p)
+                        param_group["params"][i].data_ptr = p.data_ptr()
+
+        self._models.append(model)
+
+        model = fsdp2_prepare_model(self, model)
+
+        if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+            del self._models[-2]
+
+        optimizer = self._prepare_one(optimizer, device_placement=device_placement[optimizer_index])
+        result[optimizer_index] = optimizer
+
+        new_named_params = self._get_named_parameters(*result)
+        new_named_params = fsdp2_canonicalize_names(new_named_params)
+        # 3. building a map from the first to the second
+        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+        # 4. using that map to update the parameters of the optimizer
+        for obj in result:
+            if isinstance(obj, torch.optim.Optimizer):
+                fsdp2_switch_optimizer_parameters(obj, mapping)
+
+        return result
+
     def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
         """
         Prepares a PyTorch model for training in any distributed setup. It is recommended to use