
Commit 1dff78a

Working version
1 parent 2daac20 commit 1dff78a

2 files changed (+14, -13 lines)

src/accelerate/accelerator.py

Lines changed: 8 additions & 12 deletions
@@ -1404,17 +1404,17 @@ def prepare(self, *args, device_placement=None):
         fsdp2_should_fix_optimizer = self.is_fsdp2
         should_fix_optimizer = tpu_should_fix_optimizer or fsdp2_should_fix_optimizer

-        # We need to specifically prepare AO (possibly other FP8 backends, haven't tested yet) here, as fsdp2 is very picky about the order of preparation
+        # We need to specifically prepare AO (possibly other FP8 backends, haven't tested yet) here, as fsdp2 is very picky about the order of preparation
+        if self.is_fsdp2 and self.fp8_backend == "AO":
+            args = self._prepare_ao(*args)

+        # Compile needs to be done before gathering old params: investigate why?
         if self.is_fsdp2 and model_index is not None:
             new_args = list(args)

             new_args[model_index] = compile_regions(new_args[model_index])
             args = tuple(new_args)

-        if self.is_fsdp2 and self.fp8_backend == "AO":
-            args = self._prepare_ao(*args)
-
         if should_fix_optimizer:
             # 1. grabbing old model parameters
             old_named_params = self._get_named_parameters(
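
Condensed, this hunk makes the FSDP2 path run AO FP8 conversion first, regional compilation second, and only then gather the pre-preparation parameters. A minimal sketch of that ordering, assuming the helpers named in the diff (`_prepare_ao`, `compile_regions`, `_get_named_parameters`) and omitting the other branches of `prepare`; the sketch function itself is illustrative, not part of the commit:

from accelerate.utils.other import compile_regions  # the module this diff imports it from


def fsdp2_prepare_order_sketch(accelerator, args, model_index):
    # 1. AO FP8 conversion must touch the model before anything else,
    #    since FSDP2 is picky about the order of preparation.
    if accelerator.is_fsdp2 and accelerator.fp8_backend == "AO":
        args = accelerator._prepare_ao(*args)

    # 2. Regional compilation runs next, before old parameters are gathered.
    if accelerator.is_fsdp2 and model_index is not None:
        new_args = list(args)
        new_args[model_index] = compile_regions(new_args[model_index])
        args = tuple(new_args)

    # 3. Only afterwards is the pre-preparation parameter snapshot taken, so the
    #    optimizer fix-up can map old parameters onto their prepared replacements.
    old_named_params = accelerator._get_named_parameters(*args)
    return args, old_named_params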
@@ -1425,6 +1425,7 @@ def prepare(self, *args, device_placement=None):
             # however that goes against `Accelerate's` design of `bring your own`
             # this is a workaround to make memory footprint match if `Optimizer` is created before preparing the model
             if fsdp2_should_fix_optimizer:
+                old_named_params = fsdp2_canonicalize_names(old_named_params)
                 for obj in args:
                     if isinstance(obj, torch.optim.Optimizer):
                         for param_group in obj.param_groups:
@@ -1758,11 +1759,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         if self.delayed_fp8_autocast:
             model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
         # torch.compile should be called last and only if the model isn't already compiled
-        if (
-            self.state.dynamo_plugin.backend != DynamoBackend.NO
-            and not is_compiled_module(model)
-            and not self.is_fsdp2
-        ):
+        if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
             if self.state.dynamo_plugin.use_regional_compilation:
                 model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
             else:
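
With the `not self.is_fsdp2` guard removed, the only thing preventing a double compile is the `is_compiled_module(model)` check. For reference, a hedged sketch of such a check (not Accelerate's exact helper): `torch.compile` wraps a module in `torch._dynamo.eval_frame.OptimizedModule`, so detecting that wrapper is sufficient.

import torch
import torch._dynamo


def is_compiled_module_sketch(module: torch.nn.Module) -> bool:
    # torch.compile returns an OptimizedModule that wraps the original
    # nn.Module; seeing that wrapper means the model is already compiled.
    return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)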
@@ -3574,6 +3571,7 @@ def clear(self, *objects):

     def _get_named_parameters(self, *args, drop_refs=False):
         named_parameters = {}
+        accessor_mapping = {}
         for obj in args:
             if isinstance(obj, torch.nn.Module):
                 obj = extract_model_from_parallel(obj)
@@ -3583,9 +3581,7 @@ def _get_named_parameters(self, *args, drop_refs=False):
                 if self.fp8_backend == "AO":
                     from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor

-                    accessor_mapping = {
-                        WeightWithDynamicFloat8CastTensor: "_tensor",
-                    }
+                    accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"

                 named_parameters.update(
                     {
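
`accessor_mapping` is now always a dict and is only populated for the AO backend, mapping a wrapper class to the attribute that holds the real tensor (torchao's `WeightWithDynamicFloat8CastTensor` to `"_tensor"`). A hedged sketch of how such a mapping can be consumed; the dict comprehension that actually uses it lies outside this hunk, and `resolve_param` is an illustrative name, not part of the commit:

def resolve_param(param, accessor_mapping):
    # If the parameter is an instance of a registered wrapper class, return the
    # underlying tensor stored on the mapped attribute; otherwise return the
    # parameter unchanged.
    for wrapper_cls, attr_name in accessor_mapping.items():
        if isinstance(param, wrapper_cls):
            return getattr(param, attr_name)
    return param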

src/accelerate/utils/fsdp_utils.py

Lines changed: 6 additions & 1 deletion
@@ -27,7 +27,7 @@
 from .constants import FSDP_MODEL_NAME, OPTIMIZER_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from .dataclasses import get_module_class_from_name
 from .modeling import get_non_persistent_buffers, is_peft_model
-from .other import get_module_children_bottom_up, is_compiled_module, save
+from .other import compile_regions, get_module_children_bottom_up, is_compiled_module, save
 from .versions import is_torch_version


@@ -615,6 +615,10 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(),
     }

+    # This is slow and high mem usage???
+    # model = torch.compile(model)
+    # model = compile_regions(model)
+
     model_has_params4bit = False
     for name, param in model.named_parameters():
         # this is a temporary fix whereby loading models with bnb params cannot be moved from
@@ -769,4 +773,5 @@ def fsdp2_canonicalize_names(named_params: dict) -> dict:
     named_params = {
         k.replace("_orig_mod.", "") if k.startswith("_orig_mod.") else k: v for k, v in named_params.items()
     }
+    named_params = {k.replace("._orig_mod", ""): v for k, v in named_params.items()}
     return named_params
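
As a usage illustration with hypothetical parameter names: the existing replacement strips a whole-model `torch.compile` prefix, and the newly added one strips the `._orig_mod` components that per-submodule (regional) compilation inserts, so either way the names match the uncompiled module:

>>> fsdp2_canonicalize_names({
...     "_orig_mod.encoder.weight": "p0",          # whole-model torch.compile prefix
...     "layers.0._orig_mod.q_proj.weight": "p1",  # regional compilation wrapper
... })
{'encoder.weight': 'p0', 'layers.0.q_proj.weight': 'p1'}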
