
Commit cf48289

Partial rewrite
1 parent 1dff78a commit cf48289

1 file changed, 108 insertions(+), 33 deletions(-)

src/accelerate/accelerator.py

Lines changed: 108 additions & 33 deletions
@@ -1409,31 +1409,31 @@ def prepare(self, *args, device_placement=None):
             args = self._prepare_ao(*args)
 
         # Compile needs to be done before gathering old params: investigate why?
-        if self.is_fsdp2 and model_index is not None:
-            new_args = list(args)
+        # if self.is_fsdp2 and model_index is not None:
+        #     new_args = list(args)
 
-            new_args[model_index] = compile_regions(new_args[model_index])
-            args = tuple(new_args)
+        #     new_args[model_index] = compile_regions(new_args[model_index])
+        #     args = tuple(new_args)
 
-        if should_fix_optimizer:
-            # 1. grabbing old model parameters
-            old_named_params = self._get_named_parameters(
-                *args, drop_refs=fsdp2_should_fix_optimizer
-            )  # Drop refs for FSDP2, to enable reallocation of parameters further in `fully_shard`
+        # if should_fix_optimizer:
+        #     # 1. grabbing old model parameters
+        #     old_named_params = self._get_named_parameters(
+        #         *args, drop_refs=fsdp2_should_fix_optimizer
+        #     )  # Drop refs for FSDP2, to enable reallocation of parameters further in `fully_shard`
 
             # `FSDP2` by default expects `Optimizer` to be created after the model is prepared,
             # however that goes against `Accelerate's` design of `bring your own`
             # this is a workaround to make memory footprint match if `Optimizer` is created before preparing the model
-            if fsdp2_should_fix_optimizer:
-                old_named_params = fsdp2_canonicalize_names(old_named_params)
-                for obj in args:
-                    if isinstance(obj, torch.optim.Optimizer):
-                        for param_group in obj.param_groups:
-                            for i, p in enumerate(param_group["params"]):
-                                # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
-                                # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
-                                param_group["params"][i] = torch.empty_like(p)
-                                param_group["params"][i].data_ptr = p.data_ptr()
+        # if fsdp2_should_fix_optimizer:
+        #     old_named_params = fsdp2_canonicalize_names(old_named_params)
+        #     for obj in args:
+        #         if isinstance(obj, torch.optim.Optimizer):
+        #             for param_group in obj.param_groups:
+        #                 for i, p in enumerate(param_group["params"]):
+        #                     # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
+        #                     # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
+        #                     param_group["params"][i] = torch.empty_like(p)
+        #                     param_group["params"][i].data_ptr = p.data_ptr()
 
         if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
             if (self.device.type == "cpu" or self.device.type == "xpu") and self.state.use_ipex:
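
For reference, the block commented out above implements the optimizer "fix" that the new FSDP2 path reuses later in this commit: each optimizer slot is swapped for an empty placeholder that only carries the original parameter's data_ptr, so the real storage can be freed and resharded while the old-to-new mapping stays recoverable. A minimal standalone sketch of that trick (toy code, not part of the commit; all names below are illustrative):

import torch

# Toy sketch of the data_ptr placeholder trick (not accelerate code).
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for param_group in optimizer.param_groups:
    for i, p in enumerate(param_group["params"]):
        placeholder = torch.empty_like(p)    # optimizer no longer references the real storage
        placeholder.data_ptr = p.data_ptr()  # shadow the method with the old storage address
        param_group["params"][i] = placeholder

# Each slot now holds a throwaway tensor, but placeholder.data_ptr still
# identifies which original parameter it stood for.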
@@ -1446,27 +1446,29 @@ def prepare(self, *args, device_placement=None):
             result = self._prepare_deepspeed(*args)
         elif self.distributed_type == DistributedType.MEGATRON_LM:
             result = self._prepare_megatron_lm(*args)
+        elif self.is_fsdp2:
+            result = self._prepare_fsdp2(*args)
         else:
             if self.fp8_backend == "MSAMP":
                 args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
             result = tuple(
                 self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
             )
             result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
-        if should_fix_optimizer:
-            # 2. grabbing new model parameters
-            new_named_params = self._get_named_parameters(*result)
-            if fsdp2_should_fix_optimizer:
-                new_named_params = fsdp2_canonicalize_names(new_named_params)
-            # 3. building a map from the first to the second
-            mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
-            # 4. using that map to update the parameters of the optimizer
-            for obj in result:
-                if isinstance(obj, torch.optim.Optimizer):
-                    if not fsdp2_should_fix_optimizer:
-                        obj._switch_parameters(mapping)
-                    else:
-                        fsdp2_switch_optimizer_parameters(obj, mapping)
+        # if should_fix_optimizer:
+        #     # 2. grabbing new model parameters
+        #     new_named_params = self._get_named_parameters(*result)
+        #     if fsdp2_should_fix_optimizer:
+        #         new_named_params = fsdp2_canonicalize_names(new_named_params)
+        #     # 3. building a map from the first to the second
+        #     mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+        #     # 4. using that map to update the parameters of the optimizer
+        #     for obj in result:
+        #         if isinstance(obj, torch.optim.Optimizer):
+        #             if not fsdp2_should_fix_optimizer:
+        #                 obj._switch_parameters(mapping)
+        #             else:
+        #                 fsdp2_switch_optimizer_parameters(obj, mapping)
 
         for item in result:
             if any(
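
The new `elif self.is_fsdp2` branch routes FSDP2 to a dedicated `_prepare_fsdp2` helper instead of the generic per-object loop plus the `should_fix_optimizer` post-processing. From user code the entry point is unchanged; a hedged usage sketch, assuming FSDP2 is selected through the FSDP plugin's `fsdp_version` field (check the installed accelerate version for the exact arguments):

import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin

# Assumption: FSDP2 is enabled via fsdp_version=2 on the plugin, as in
# recent accelerate releases; adjust to your version/launcher config.
fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# With this commit, the FSDP2 case is handled by Accelerator._prepare_fsdp2.
model, optimizer = accelerator.prepare(model, optimizer)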
@@ -1477,6 +1479,79 @@ def prepare(self, *args, device_placement=None):
 
         return result if len(result) > 1 else result[0]
 
+    def _prepare_fsdp2(self, *args):
+        _custom_prepare_classes = (
+            torch.nn.Module,
+            torch.optim.Optimizer,
+        )
+        device_placement = [None for _ in args]
+
+        result = [
+            self._prepare_one(obj, first_pass=True, device_placement=d)
+            if not isinstance(obj, _custom_prepare_classes)
+            else obj
+            for obj, d in zip(args, device_placement)
+        ]
+
+        result = tuple(
+            self._prepare_one(obj, device_placement=d) if not isinstance(obj, _custom_prepare_classes) else obj
+            for obj, d in zip(result, device_placement)
+        )
+
+        models = []
+        optimizers = []
+
+        for i, obj in enumerate(result):
+            if isinstance(obj, torch.nn.Module):
+                models.append((i, obj))
+            elif isinstance(obj, torch.optim.Optimizer):
+                optimizers.append((i, obj))
+
+        if len(optimizers) <= 0 and len(models) <= 0:
+            return result
+
+        model_index, model = models[0]
+        optimizer_index, optimizer = optimizers[0]
+
+        new_result = list(result)
+
+        new_result[model_index] = compile_regions(result[model_index])
+        result = new_result
+        # result = tuple(new_result)
+
+        old_named_params = self._get_named_parameters(*tuple(result), drop_refs=True)
+
+        old_named_params = fsdp2_canonicalize_names(old_named_params)
+        for obj in result:
+            if isinstance(obj, torch.optim.Optimizer):
+                for param_group in obj.param_groups:
+                    for i, p in enumerate(param_group["params"]):
+                        # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
+                        # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
+                        param_group["params"][i] = torch.empty_like(p)
+                        param_group["params"][i].data_ptr = p.data_ptr()
+
+        self._models.append(model)
+
+        model = fsdp2_prepare_model(self, model)
+
+        if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+            del self._models[-2]
+
+        optimizer = self._prepare_one(optimizer, device_placement=device_placement[optimizer_index])
+        result[optimizer_index] = optimizer
+
+        new_named_params = self._get_named_parameters(*result)
+        new_named_params = fsdp2_canonicalize_names(new_named_params)
+        # 3. building a map from the first to the second
+        mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+        # 4. using that map to update the parameters of the optimizer
+        for obj in result:
+            if isinstance(obj, torch.optim.Optimizer):
+                fsdp2_switch_optimizer_parameters(obj, mapping)
+
+        return result
+
     def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
         """
         Prepares a PyTorch model for training in any distributed setup. It is recommended to use

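The tail of the new `_prepare_fsdp2` (the numbered steps 3 and 4) rebuilds the optimizer around the resharded parameters by matching old `data_ptr`s to the new named parameters. A standalone toy of that mapping-and-swap idea (not the `fsdp2_switch_optimizer_parameters` implementation; names below are illustrative):

import torch

# Toy: pretend preparation re-created the parameters, then repoint the optimizer.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 1./2. record old params by storage address, then get the "new" parameters
old_named = {name: p.data_ptr() for name, p in model.named_parameters()}
new_model = torch.nn.Linear(4, 4)  # stand-in for the sharded/compiled model
new_named = dict(new_model.named_parameters())

# 3. build a map from each old data_ptr to the corresponding new parameter
mapping = {ptr: new_named[name] for name, ptr in old_named.items()}

# 4. point every optimizer slot at the new parameter with the matching data_ptr
for param_group in optimizer.param_groups:
    param_group["params"] = [mapping[p.data_ptr()] for p in param_group["params"]]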