@@ -1404,17 +1404,17 @@ def prepare(self, *args, device_placement=None):
         fsdp2_should_fix_optimizer = self.is_fsdp2
         should_fix_optimizer = tpu_should_fix_optimizer or fsdp2_should_fix_optimizer
 
         # We need to specifically prepare AO (possibly other FP8 backends, haven't tested yet) here, as fsdp2 is very picky about the order of preparation
+        if self.is_fsdp2 and self.fp8_backend == "AO":
+            args = self._prepare_ao(*args)
 
+        # Compile needs to be done before gathering old params: investigate why?
         if self.is_fsdp2 and model_index is not None:
             new_args = list(args)
 
             new_args[model_index] = compile_regions(new_args[model_index])
             args = tuple(new_args)
 
-        if self.is_fsdp2 and self.fp8_backend == "AO":
-            args = self._prepare_ao(*args)
-
         if should_fix_optimizer:
             # 1. grabbing old model parameters
             old_named_params = self._get_named_parameters(
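The moved `_prepare_ao` call above reflects the ordering constraint the comment mentions: torchao's float8 module swap has to happen on plain, unsharded `nn.Linear` modules, before FSDP2 wraps them. Below is a rough sketch of that order outside of Accelerate, assuming torchao's public float8 helpers (`convert_to_float8_training`, `Float8LinearConfig`) and torch's FSDP2 `fully_shard`; the exact flags are illustrative, not what Accelerate does internally.

import torch
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 8))

# 1. Swap eligible nn.Linear layers for float8 training layers while they are
#    still regular, unsharded modules (assumed torchao API, shown for ordering).
convert_to_float8_training(model, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True))

# 2. Only afterwards shard with FSDP2; `fully_shard` needs an initialized process
#    group / device mesh, so it is left commented out here.
# from torch.distributed.fsdp import fully_shard  # torch >= 2.6 FSDP2 API (assumption)
# fully_shard(model)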
@@ -1425,6 +1425,7 @@ def prepare(self, *args, device_placement=None):
             # however that goes against `Accelerate's` design of `bring your own`
             # this is a workaround to make memory footprint match if `Optimizer` is created before preparing the model
             if fsdp2_should_fix_optimizer:
+                old_named_params = fsdp2_canonicalize_names(old_named_params)
                 for obj in args:
                     if isinstance(obj, torch.optim.Optimizer):
                         for param_group in obj.param_groups:
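For context on this hunk: `old_named_params` is captured (and, for FSDP2, canonicalized via the added `fsdp2_canonicalize_names` call) so that an optimizer built before `prepare()` can have its stale parameter references swapped for the prepared, sharded ones. The sketch below illustrates that general pattern; `remap_optimizer_params` and the name-keyed lookup are hypothetical, while Accelerate itself matches parameters by storage pointer.

import torch

def remap_optimizer_params(optimizer, old_named_params, new_named_params):
    # Map each pre-preparation parameter object to the post-preparation parameter
    # that carries the same (canonicalized) name.
    old_to_new = {
        id(old_param): new_named_params[name]
        for name, old_param in old_named_params.items()
        if name in new_named_params
    }
    # Swap the stale references inside the optimizer's param groups in place, so
    # optimizer state ends up keyed on the prepared parameters.
    for param_group in optimizer.param_groups:
        param_group["params"] = [old_to_new.get(id(p), p) for p in param_group["params"]]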
@@ -1758,11 +1759,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         if self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled
-        if (
-            self.state.dynamo_plugin.backend != DynamoBackend.NO
-            and not is_compiled_module(model)
-            and not self.is_fsdp2
-        ):
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
             if self.state.dynamo_plugin.use_regional_compilation:
                 model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
             else:
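Dropping `and not self.is_fsdp2` means FSDP2 models are no longer excluded from `torch.compile` in `prepare_model`. Below is a usage sketch under the assumption that the installed Accelerate exposes `fsdp_version=2` on `FullyShardedDataParallelPlugin` and accepts a `dynamo_plugin` keyword on `Accelerator` (older releases configure dynamo through `accelerate config` instead), so treat the exact keyword arguments as illustrative; launch with `accelerate launch` on a multi-GPU node.

import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin, TorchDynamoPlugin

# Illustrative configuration: FSDP2 sharding combined with (regional) torch.compile.
accelerator = Accelerator(
    fsdp_plugin=FullyShardedDataParallelPlugin(fsdp_version=2),
    dynamo_plugin=TorchDynamoPlugin(backend="inductor", use_regional_compilation=True),
)

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU(), torch.nn.Linear(16, 2))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # created before prepare()

# With this change, prepare() compiles the model even on the FSDP2 path and then
# fixes up the optimizer's parameter references.
model, optimizer = accelerator.prepare(model, optimizer)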
@@ -3574,6 +3571,7 @@ def clear(self, *objects):
 
     def _get_named_parameters(self, *args, drop_refs=False):
         named_parameters = {}
+        accessor_mapping = {}
         for obj in args:
             if isinstance(obj, torch.nn.Module):
                 obj = extract_model_from_parallel(obj)
@@ -3583,9 +3581,7 @@ def _get_named_parameters(self, *args, drop_refs=False):
                 if self.fp8_backend == "AO":
                     from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
 
-                    accessor_mapping = {
-                        WeightWithDynamicFloat8CastTensor: "_tensor",
-                    }
+                    accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
 
                 named_parameters.update(
                     {
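The `accessor_mapping` now initialized at the top of `_get_named_parameters` and filled in the AO branch maps a wrapper tensor subclass to the attribute that holds the plain tensor (torchao's `WeightWithDynamicFloat8CastTensor` keeps it in `_tensor`). Here is a minimal sketch of how such a mapping can be applied when collecting parameters by storage pointer; `resolve_param` and `named_data_ptrs` are hypothetical helpers, not methods of this class.

import torch

def resolve_param(param, accessor_mapping):
    # If the parameter is a known wrapper subclass, follow the configured
    # attribute to the underlying tensor; otherwise use the parameter itself.
    for wrapper_type, attr in accessor_mapping.items():
        if isinstance(param, wrapper_type):
            return getattr(param, attr)
    return param

def named_data_ptrs(module, accessor_mapping):
    # Map parameter names to storage pointers, e.g. for matching old and new
    # parameters without holding references to the tensors themselves.
    return {
        name: resolve_param(p, accessor_mapping).data_ptr()
        for name, p in module.named_parameters()
    }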