Commit 71271d1

feat: support fsdpv2 and fsdpv2 + tp
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent cb7696f commit 71271d1

4 files changed, +76 -94 lines changed

src/accelerate/accelerator.py

Lines changed: 27 additions & 22 deletions
@@ -405,10 +405,12 @@ def __init__(
             if not isinstance(torch_tp_plugin, TorchTensorParallelPlugin):
                 raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
             os.environ["ACCELERATE_USE_TP"] = "true"
-
+
         if fsdp2_plugin is None:
             fsdp2_plugin = (
-                FullyShardedDataParallelPlugin2() if os.environ.get("ACCELERATE_USE_FSDP2", "false") == "true" else None
+                FullyShardedDataParallelPlugin2()
+                if os.environ.get("ACCELERATE_USE_FSDP2", "false") == "true"
+                else None
             )
         else:
             if not isinstance(fsdp2_plugin, FullyShardedDataParallelPlugin2):
@@ -1513,7 +1515,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
         elif device_placement and not self.verify_device_map(model):
             model = model.to(self.device)
         if not evaluation_mode:
-            device_mesh = prepare_nd_device_mesh(self.state.torch_tp_plugin.tp_size if self.state.torch_tp_plugin is not None else 1, self.state.fsdp2_plugin is not None)
+            # motivation behind preparing device mesh at the start is to easily extend
+            # device preparation for any combination of parallelisms and pass it on
+            # neatly to respective parallelism distribution code snippets.
+            # function prepare_nd_device_mesh should be enough to extend logic for future combinations
+            # for now prepare_nd_device_mesh handles any combination of TP and FSDP/HSDP
+            device_mesh = prepare_nd_device_mesh(
+                self.state.torch_tp_plugin.tp_size if self.state.torch_tp_plugin is not None else 1,
+                self.state.fsdp2_plugin is not None,
+            )
             if self.distributed_type in (
                 DistributedType.MULTI_GPU,
                 DistributedType.MULTI_MLU,
@@ -1548,6 +1558,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
             if self.distributed_type == DistributedType.FSDP2 or self.distributed_type == DistributedType.FSDP2_TP:
                 self.state.fsdp2_plugin.torch_device_mesh = device_mesh["dp", "fsdp"]
                 from torch.distributed._composable.fsdp import fully_shard, FSDPModule
+
                 # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
                 # don't wrap it again
                 # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
@@ -1558,7 +1569,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                 )

                 if not is_type_fsdp:
-                    fsdp2_kwargs = {
+                    fsdp2_kwargs = {
                         "mp_policy": self.state.fsdp2_plugin.mp_policy,
                         "reshard_after_forward": self.state.fsdp2_plugin.reshard_after_forward,
                         "offload_policy": self.state.fsdp2_plugin.offload_policy,
@@ -1570,25 +1581,14 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
                     for layer in model.model.layers:
                         fully_shard(layer, **fsdp2_kwargs)
                     fully_shard(model, **fsdp2_kwargs)
+                    # if the previous and current models are same, delete the previous one
+                    if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
+                        del self._models[-2]
+                    self._models[-1] = model

                 #######
-                # does existing activation_checkpointing API work out of the box with FSDP2?
+                # TODO: support activation_checkpointing for FSDP2 and nd parallel cases
                 #######
-                # if fsdp_plugin.activation_checkpointing:
-                #     from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-                #         CheckpointImpl,
-                #         apply_activation_checkpointing,
-                #         checkpoint_wrapper,
-                #     )
-
-                #     apply_activation_checkpointing(
-                #         model,
-                #         checkpoint_wrapper_fn=functools.partial(
-                #             checkpoint_wrapper,
-                #             checkpoint_impl=CheckpointImpl.NO_REENTRANT,
-                #         ),
-                #         auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
-                #     )

             elif self.distributed_type == DistributedType.FSDP:
                 # We need to fix the optimizer *before* sharding the model
@@ -2227,7 +2227,10 @@ def prepare_data_loader(
         if device_placement is None:
             device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False

-        nd_device_mesh = prepare_nd_device_mesh(self.state.torch_tp_plugin.tp_size if self.state.torch_tp_plugin is not None else 1, self.state.fsdp2_plugin is not None)
+        nd_device_mesh = prepare_nd_device_mesh(
+            self.state.torch_tp_plugin.tp_size if self.state.torch_tp_plugin is not None else 1,
+            self.state.fsdp2_plugin is not None,
+        )
         prepared_data_loader = prepare_data_loader(
             data_loader,
             self.device,
@@ -2497,7 +2500,9 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
             parameters = [p for p in parameters]
             for model in self._models:
                 if parameters == [p for p in model.parameters()]:
-                    return torch.nn.utils.clip_grad_norm_(parameters=parameters, max_norm=max_norm, norm_type=norm_type)
+                    return torch.nn.utils.clip_grad_norm_(
+                        parameters=parameters, max_norm=max_norm, norm_type=norm_type
+                    )
         elif self.distributed_type == DistributedType.DEEPSPEED:
             # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
             # We cannot return the gradient norm because DeepSpeed does it.
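
Taken together, the prepare_model changes boil down to: build an n-d device mesh up front, hand the ("dp", "fsdp") submesh to the FSDP2 plugin, then apply fully_shard block by block before wrapping the root module. Below is a minimal standalone sketch of that flow under stated assumptions: torch.distributed is already initialized (for example via torchrun), the model exposes model.model.layers as the hunk expects, and the mesh is passed to fully_shard explicitly here purely for illustration (the commit stores it on fsdp2_plugin instead).

# Hedged sketch of the FSDP2 wrapping flow above, not the Accelerator API itself.
# Assumes `torchrun --nproc-per-node=N` has already initialized the process group.
import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import MixedPrecisionPolicy, fully_shard

from accelerate.utils.pytorch_utils import prepare_nd_device_mesh  # helper touched by this commit


class ToyInner(nn.Module):
    def __init__(self):
        super().__init__()
        # mimic the `model.model.layers` layout the hunk iterates over
        self.layers = nn.ModuleList(nn.Linear(64, 64) for _ in range(4))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = ToyInner()


model = ToyModel().cuda()  # per-rank device selection elided for brevity
device_mesh = prepare_nd_device_mesh(tp_size=1, use_fsdp=True)  # ("dp", "fsdp") mesh

fsdp2_kwargs = {
    "mesh": device_mesh["dp", "fsdp"],  # illustration only; the commit stores this on fsdp2_plugin
    "reshard_after_forward": True,
    "mp_policy": MixedPrecisionPolicy(param_dtype=torch.bfloat16),
}

# shard each transformer block first, then the root module, as in the diff
for layer in model.model.layers:
    fully_shard(layer, **fsdp2_kwargs)
fully_shard(model, **fsdp2_kwargs)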

src/accelerate/commands/launch.py

Lines changed: 7 additions & 1 deletion
@@ -612,7 +612,9 @@ def launch_command_parser(subparsers=None):
     )

     # fsdp2 args
-    fsdp2_args = parser.add_argument_group("FSDP2 Arguments", "Arguments related to Fully Shared Data Parallelism Version 2.")
+    fsdp2_args = parser.add_argument_group(
+        "FSDP2 Arguments", "Arguments related to Fully Shared Data Parallelism Version 2."
+    )
     fsdp2_args.add_argument(
         "--fsdp2_reshard_after_forward",
         default="true",
@@ -1059,6 +1061,7 @@ def _validate_launch_command(args):
             and not args.tpu_use_cluster
             and not args.use_deepspeed
             and not args.use_fsdp
+            and not args.use_fsdp2
             and not args.use_tp
             and not args.use_megatron_lm
         ):
@@ -1078,6 +1081,7 @@ def _validate_launch_command(args):
             args.tpu = defaults.distributed_type == DistributedType.XLA
             args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
             args.use_tp = defaults.distributed_type == DistributedType.TP
+            args.use_fsdp2 = defaults.distributed_type == DistributedType.FSDP2
             args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
             args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
         if args.gpu_ids is None:
@@ -1237,6 +1241,8 @@ def launch_command(args):
         deepspeed_launcher(args)
     elif args.use_fsdp and not args.cpu:
         multi_gpu_launcher(args)
+    elif args.use_fsdp2 and not args.cpu:
+        multi_gpu_launcher(args)
    elif args.use_tp and not args.cpu:
        multi_gpu_launcher(args)
    elif args.use_megatron_lm and not args.cpu:
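
The launcher changes mirror the existing FSDP path: a run flagged for FSDP2, or a config whose distributed_type resolves to DistributedType.FSDP2, passes validation and is dispatched to multi_gpu_launcher. A quick, hedged way to sanity-check the new CLI surface is to build the parser directly; note that the --use_fsdp2 flag itself is not defined in these hunks, so its spelling is assumed here to follow --use_fsdp.

# Hedged sketch: inspect the parser built by launch_command_parser to confirm the
# FSDP2 argument group and routing flags behave as the hunks above suggest.
# Flag spellings are assumptions, not a documented interface.
from accelerate.commands.launch import launch_command_parser

parser = launch_command_parser()
args = parser.parse_args(
    [
        "--use_fsdp2",                       # assumed flag; routed to multi_gpu_launcher (hunk at 1244)
        "--fsdp2_reshard_after_forward", "true",
        "--num_processes", "8",
        "train.py",                          # hypothetical training script
    ]
)
print(args.use_fsdp2, args.fsdp2_reshard_after_forward)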

src/accelerate/utils/dataclasses.py

Lines changed: 32 additions & 69 deletions
@@ -1850,13 +1850,12 @@ class FullyShardedDataParallelPlugin2:
     reshard_after_forward: Optional[bool] = field(
         default=None,
         metadata={
-            "help":
-            "If reshard_after_forward is True, the parameters are sharded on every forward pass and all-gathered during backward pass."
-            "reshard_after_forward in conjunction with device mesh dimension would mean different strategies like the following:"
-            "reshard_after_forward=True and 1D device mesh mean full shard"
-            "reshard_after_forward=True and 2D device mesh mean hybrid shard"
-            "reshard_after_forward=False and 1D device mesh mean shard grad and optimizer states only"
-            "reshard_after_forward=False and 2D device mesh mean hybrid shard with grad and optim sharding only"
+            "help": "If reshard_after_forward is True, the parameters are sharded on every forward pass and all-gathered during backward pass."
+            "reshard_after_forward in conjunction with device mesh dimension would mean different strategies like the following:"
+            "reshard_after_forward=True and 1D device mesh mean full shard"
+            "reshard_after_forward=True and 2D device mesh mean hybrid shard"
+            "reshard_after_forward=False and 1D device mesh mean shard grad and optimizer states only"
+            "reshard_after_forward=False and 2D device mesh mean hybrid shard with grad and optim sharding only"
         },
     )
     offload_policy: Optional[Union[dict, "torch.distributed._composable.OffloadPolicy"]] = field(
@@ -1865,45 +1864,45 @@ class FullyShardedDataParallelPlugin2:
             "help": "A config to enable CPU offload. If passing in a `dict`, it should have the following key: `pin_memory`."
         },
     )
-    mp_policy: Optional[Union[dict, "torch.distributed._composable.MixedPrecisionPolicy"]] = (
-        field(
-            default=None,
-            metadata={
-                "help": "A config to enable mixed precision training with FullyShardedDataParallelv2. If passing in a `dict`, it"
+    mp_policy: Optional[Union[dict, "torch.distributed._composable.MixedPrecisionPolicy"]] = field(
+        default=None,
+        metadata={
+            "help": "A config to enable mixed precision training with FullyShardedDataParallelv2. If passing in a `dict`, it"
             "should have the following keys: `param_dtype`, `reduce_dtype`, `output_dtype`, and `cast_forward_inputs`. "
-            },
-        )
+        },
     )
     ignored_params: Optional[set["torch.nn.Parameter"]] = field(
         default=None,
-        metadata={
-            "help": "The set of parameters that we don't want to shard with FSDP."
-        },
+        metadata={"help": "The set of parameters that we don't want to shard with FSDP."},
     )

     def __post_init__(self):
         env_prefix = "FSDP2_"
         if self.reshard_after_forward is None:
             self.reshard_after_forward = str_to_bool(os.environ.get(env_prefix + "RESHARD_AFTER_FORWARD", "True")) == 1
-
+
         self.set_offload_policy()
         self.set_mp_policy()

         from torch.distributed.device_mesh import init_device_mesh
+
         dp_mesh_dim_name = "dp"
         fsdp_mesh_dim_name = "fsdp"
         device = "cuda"  # support for other devices has to be investigated
         num_nodes = torch.distributed.get_world_size() // torch.cuda.device_count()
         nproc_per_node = torch.cuda.device_count()
-        self.torch_device_mesh = init_device_mesh(device, (num_nodes,nproc_per_node), mesh_dim_names=(dp_mesh_dim_name, fsdp_mesh_dim_name))
+        self.torch_device_mesh = init_device_mesh(
+            device, (num_nodes, nproc_per_node), mesh_dim_names=(dp_mesh_dim_name, fsdp_mesh_dim_name)
+        )

     def set_offload_policy(self, pin_memory=None):
         """
         Set the offload policy
         """
         from torch.distributed._composable.fsdp import CPUOffloadPolicy
+
         env_prefix = "FSDP2_"
-
+
         if self.offload_policy is None:
             fsdp2_cpu_offload = str_to_bool(os.environ.get(env_prefix + "CPU_OFFLOAD", "False")) == 1
             if fsdp2_cpu_offload:
@@ -1918,6 +1917,7 @@ def set_mp_policy(self, param_dtype=None, reduce_dtype=None, output_dtype=None,
         Set mixed precision policy
         """
         from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+
         env_prefix = "FSDP2_"
         mixed_precision_mapping = {
             "fp8": torch.bfloat16,
@@ -1926,18 +1926,22 @@ def set_mp_policy(self, param_dtype=None, reduce_dtype=None, output_dtype=None,
             "fp32": torch.float32,
         }

-        # current_env["FSDP2_MP_PARAM_DTYPE"] = str(args.fsdp2_mp_param_dtype).lower()
-        # current_env["FSDP2_MP_REDUCE_DTYPE"] = str(args.fsdp2_mp_reduce_dtype).lower()
-        # current_env["FSDP2_MP_OUTPUT_DTYPE"] = str(args.fsdp2_mp_output_dtype).lower()
-        # current_env["FSDP2_CAST_FORWARD_INPUTS"] = str(args.fsdp2_cast_forward_inputs).lower()
-
         if self.mp_policy is None:
             param_dtype = mixed_precision_mapping.get(os.environ.get(env_prefix + "MP_PARAM_DTYPE", param_dtype), None)
-            reduce_dtype = mixed_precision_mapping.get(os.environ.get(env_prefix + "MP_REDUCE_DTYPE", reduce_dtype), None)
-            output_dtype = mixed_precision_mapping.get(os.environ.get(env_prefix + "MP_OUTPUT_DTYPE", output_dtype), None)
+            reduce_dtype = mixed_precision_mapping.get(
+                os.environ.get(env_prefix + "MP_REDUCE_DTYPE", reduce_dtype), None
+            )
+            output_dtype = mixed_precision_mapping.get(
+                os.environ.get(env_prefix + "MP_OUTPUT_DTYPE", output_dtype), None
+            )
             if not cast_forward_inputs:
                 cast_forward_inputs = str_to_bool(os.environ.get(env_prefix + "CAST_FORWARD_INPUTS", "False")) == 1
-            self.mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, output_dtype=output_dtype, cast_forward_inputs=cast_forward_inputs)
+            self.mp_policy = MixedPrecisionPolicy(
+                param_dtype=param_dtype,
+                reduce_dtype=reduce_dtype,
+                output_dtype=output_dtype,
+                cast_forward_inputs=cast_forward_inputs,
+            )

         if isinstance(self.mp_policy, dict):
             self.mp_policy = MixedPrecisionPolicy(**self.mp_policy)
@@ -1973,47 +1977,6 @@ def __post_init__(self):
         self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))


-@dataclass
-class DeviceMeshHandler:
-    """
-    This handler is used to create and hold device mesh state throughout the training
-    and dynamically support any combination of parallelisms.
-    """
-
-    tp_size: int = field(
-        default=1,
-        metadata={"help": "tensor parallel size will be used in the device mesh preparation with other parallelisms."},
-    )
-
-    use_fsdp: bool = field(
-        default=False,
-        metadata={"help": "fsdp v2 will be used with other parallelisms for device mesh preparation."},
-    )
-
-    torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
-
-    def __post_init__(self):
-        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
-        self.use_fsdp = self.use_fsdp or str_to_bool(os.environ.get("ACCELERATE_USE_FSDP2", "False")) == 1
-        if self.tp_size == 1:
-            raise ValueError("Provide TP degree > 1.")
-
-        if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
-            raise ValueError(
-                f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
-            )
-        from torch.distributed.device_mesh import init_device_mesh
-
-        mesh_dim_names = ("tp",)
-        mesh_dims = (self.tp_size,)
-        if self.use_fsdp:
-            num_nodes = torch.distributed.get_world_size() // torch.cuda.device_count()
-            nproc_per_node = torch.cuda.device_count()
-            mesh_dim_names = ("dp", "fsdp",) + mesh_dim_names
-            mesh_dims = (num_nodes, nproc_per_node, ) + mesh_dims
-        device = "cuda"  # support for other devices has to be investigated
-        self.torch_device_mesh = init_device_mesh(device, mesh_dims, mesh_dim_names=mesh_dim_names)
-
 @dataclass
 class MegatronLMPlugin:
     """

src/accelerate/utils/pytorch_utils.py

Lines changed: 10 additions & 2 deletions
@@ -1,10 +1,12 @@
 import torch

+
 def prepare_nd_device_mesh(tp_size=1, use_fsdp=False):
     """Returns a multi dimensional device mesh.
     Extend this function to support various combinations of parallelisms.
     """
     from torch.distributed.device_mesh import init_device_mesh
+
     mesh_dim_names = ()
     mesh_dims = ()
     if tp_size <= 1 and not use_fsdp:
@@ -15,7 +17,13 @@ def prepare_nd_device_mesh(tp_size=1, use_fsdp=False):
     if use_fsdp:
         num_nodes = torch.distributed.get_world_size() // torch.cuda.device_count()
         nproc_per_node = torch.cuda.device_count()
-        mesh_dim_names = ("dp", "fsdp",) + mesh_dim_names
-        mesh_dims = (num_nodes, nproc_per_node, ) + mesh_dims
+        mesh_dim_names = (
+            "dp",
+            "fsdp",
+        ) + mesh_dim_names
+        mesh_dims = (
+            num_nodes,
+            nproc_per_node // tp_size,
+        ) + mesh_dims
     device = "cuda"
     return init_device_mesh(device, mesh_dims, mesh_dim_names=mesh_dim_names)
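
The functional change in this file is that the FSDP mesh dimension is divided by the TP degree, so the product of the mesh dimensions stays equal to the world size when TP and FSDP2 are combined. The arithmetic-only sketch below (no process group required) shows the resulting layout for a hypothetical 2-node, 8-GPU job; the TP branch of prepare_nd_device_mesh is not shown in this hunk, so it is assumed here to follow the removed DeviceMeshHandler logic.

# Hedged back-of-the-envelope check of the mesh shape prepare_nd_device_mesh would
# build for a 2-node x 8-GPU job with TP degree 2 and FSDP2 enabled. The arithmetic
# mirrors the hunk above; the "tp" branch is an assumption based on the removed
# DeviceMeshHandler code.
world_size, gpus_per_node, tp_size, use_fsdp = 16, 8, 2, True

mesh_dim_names, mesh_dims = (), ()
if tp_size > 1:  # assumed TP branch
    mesh_dim_names, mesh_dims = ("tp",), (tp_size,)
if use_fsdp:
    num_nodes = world_size // gpus_per_node
    mesh_dim_names = ("dp", "fsdp") + mesh_dim_names
    mesh_dims = (num_nodes, gpus_per_node // tp_size) + mesh_dims

print(dict(zip(mesh_dim_names, mesh_dims)))  # -> {'dp': 2, 'fsdp': 4, 'tp': 2}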
