Support TP + FSDPv2 / HSDP or just FSDPv2 / HSDP #3395

Closed · wants to merge 7 commits
90 changes: 88 additions & 2 deletions src/accelerate/accelerator.py
@@ -56,6 +56,7 @@
DynamoBackend,
FP8RecipeKwargs,
FullyShardedDataParallelPlugin,
FullyShardedDataParallelPlugin2,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
@@ -107,6 +108,7 @@
save_fsdp_model,
save_fsdp_optimizer,
wait_for_everyone,
prepare_nd_device_mesh,
)
from .utils.constants import (
BETA_TP_AVAILABLE_PYTORCH_VERSION,
@@ -195,6 +197,9 @@ class Accelerator:
fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
fsdp2_plugin ([`~utils.FullyShardedDataParallelPlugin2`], *optional*):
Tweak your FSDPv2 related args using this argument. This argument is optional and can be configured directly
using *accelerate config*
torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
Tweak your torch tensor parallel. This argument is optional and can be configured directly using
*accelerate config*
@@ -267,6 +272,7 @@ def __init__(
dataloader_config: DataLoaderConfiguration | None = None,
deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
fsdp2_plugin: FullyShardedDataParallelPlugin2 | None = None,
torch_tp_plugin: TorchTensorParallelPlugin | None = None,
megatron_lm_plugin: MegatronLMPlugin | None = None,
rng_types: list[str | RNGType] | None = None,
@@ -373,6 +379,15 @@ def __init__(
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")

if os.environ.get("ACCELERATE_USE_FSDP2", "false") == "true" or isinstance(
fsdp2_plugin, FullyShardedDataParallelPlugin2
):
if not is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION):
raise ValueError(f"FSDPv2 requires PyTorch >= {BETA_TP_AVAILABLE_PYTORCH_VERSION}")

if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"FSDPv2 requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")

if fsdp_plugin is None: # init from env variables
fsdp_plugin = (
FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None
@@ -391,6 +406,17 @@
raise TypeError("`torch_tp_plugin` must be a TorchTensorParallelPlugin object.")
os.environ["ACCELERATE_USE_TP"] = "true"

if fsdp2_plugin is None:
fsdp2_plugin = (
FullyShardedDataParallelPlugin2()
if os.environ.get("ACCELERATE_USE_FSDP2", "false") == "true"
else None
)
else:
if not isinstance(fsdp2_plugin, FullyShardedDataParallelPlugin2):
raise TypeError("`fsdp2_plugin` must be a FullyShardedDataParallelPlugin2 object.")
os.environ["ACCELERATE_USE_FSDP2"] = "true"

if megatron_lm_plugin is None: # init from env variables
megatron_lm_plugin = (
MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None
@@ -456,6 +482,7 @@ def __init__(
dynamo_plugin=dynamo_plugin,
deepspeed_plugin=deepspeed_plugins,
fsdp_plugin=fsdp_plugin,
fsdp2_plugin=fsdp2_plugin,
torch_tp_plugin=torch_tp_plugin,
megatron_lm_plugin=megatron_lm_plugin,
_from_accelerator=True,
@@ -1420,6 +1447,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if device_placement is None:
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
self._models.append(model)
device_mesh = None

# TODO: Look at enabling native TP training directly with a proper config
if (
@@ -1487,6 +1515,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
elif device_placement and not self.verify_device_map(model):
model = model.to(self.device)
if not evaluation_mode:
# The device mesh is prepared up front so that device preparation can be extended to
# any combination of parallelisms and handed off cleanly to the code that applies each
# parallelism. Extending prepare_nd_device_mesh should be enough to support future
# combinations; for now it handles any combination of TP and FSDP/HSDP.
device_mesh = prepare_nd_device_mesh(
self.state.torch_tp_plugin.tp_size if self.state.torch_tp_plugin is not None else 1,
self.state.fsdp2_plugin is not None,
)
if self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
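For reference, prepare_nd_device_mesh is implemented in accelerate.utils.pytorch_utils and its body is not part of this diff. Below is a hypothetical sketch of one plausible shape, assuming the ("dp", "fsdp", "tp") mesh dim names that the code above slices via device_mesh["tp"] and device_mesh["dp", "fsdp"], a CUDA backend, and a replicate degree (dp_replicate_size) that defaults to 1 unless HSDP is configured:

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def prepare_nd_device_mesh(tp_size: int = 1, use_fsdp2: bool = False, dp_replicate_size: int = 1):
    # Ranks not consumed by TP or by HSDP replication form the FSDP shard dimension.
    world_size = dist.get_world_size()
    fsdp_size = world_size // (tp_size * dp_replicate_size) if use_fsdp2 else 1
    dp_size = world_size // (tp_size * fsdp_size)
    return init_device_mesh(
        "cuda",
        (dp_size, fsdp_size, tp_size),
        mesh_dim_names=("dp", "fsdp", "tp"),
    )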
@@ -1507,7 +1544,8 @@
)
if self.ddp_handler is not None:
self.ddp_handler.register_comm_hook(model)
elif self.distributed_type == DistributedType.TP:
elif self.distributed_type == DistributedType.TP or self.distributed_type == DistributedType.FSDP2_TP:
self.state.torch_tp_plugin.torch_device_mesh = device_mesh["tp"]
if hasattr(model, "supports_tp_plan") and not model.supports_tp_plan:
if not compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION):
raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")
@@ -1517,6 +1555,41 @@
and _tp_plan attribute to model class."
)
model.tensor_parallel(self.state.torch_tp_plugin.torch_device_mesh["tp"])
if self.distributed_type == DistributedType.FSDP2 or self.distributed_type == DistributedType.FSDP2_TP:
self.state.fsdp2_plugin.torch_device_mesh = device_mesh["dp", "fsdp"]
from torch.distributed._composable.fsdp import fully_shard, FSDPModule

# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
# don't wrap it again
# In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
# is a FSDP model, don't wrap it again
# We check for FSDPModule instead of FSDP class for FSDP v2
is_type_fsdp = isinstance(model, FSDPModule) or (
is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule)
)

if not is_type_fsdp:
fsdp2_kwargs = {
"mp_policy": self.state.fsdp2_plugin.mp_policy,
"reshard_after_forward": self.state.fsdp2_plugin.reshard_after_forward,
"offload_policy": self.state.fsdp2_plugin.offload_policy,
# ignored_params is a fairly recent feature, so it is left out for now
# "ignored_params": fsdp2_plugin.ignored_params,
"mesh": self.state.fsdp2_plugin.torch_device_mesh,
}

# shard each decoder block individually before sharding the full model
# (assumes a transformers-style `model.model.layers` layout)
for layer in model.model.layers:
fully_shard(layer, **fsdp2_kwargs)
fully_shard(model, **fsdp2_kwargs)
# if the previous and current models are the same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
self._models[-1] = model

# TODO: support activation_checkpointing for FSDP2 and nd-parallel cases

elif self.distributed_type == DistributedType.FSDP:
# We need to fix the optimizer *before* sharding the model
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
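The mp_policy and offload_policy passed to fully_shard above come from the FSDP2 plugin. A minimal sketch of how they plausibly map onto the policy dataclasses of torch's composable FSDP and onto the new --fsdp2_* launch flags added below; this mapping is an assumption, not something shown in this diff:

import torch
from torch.distributed._composable.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,    # --fsdp2_mp_param_dtype bf16
    reduce_dtype=torch.bfloat16,   # --fsdp2_mp_reduce_dtype bf16
    output_dtype=torch.bfloat16,   # --fsdp2_mp_output_dtype bf16
    cast_forward_inputs=True,      # --fsdp2_cast_forward_inputs true
)
offload_policy = CPUOffloadPolicy(pin_memory=True)  # --fsdp2_cpu_offload / --fsdp2_cpu_offload_pin_memory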
@@ -2153,6 +2226,11 @@ def prepare_data_loader(
return data_loader
if device_placement is None:
device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False

nd_device_mesh = prepare_nd_device_mesh(
self.state.torch_tp_plugin.tp_size if self.state.torch_tp_plugin is not None else 1,
self.state.fsdp2_plugin is not None,
)
prepared_data_loader = prepare_data_loader(
data_loader,
self.device,
@@ -2168,7 +2246,7 @@
data_seed=self.dataloader_config.data_seed,
non_blocking=self.non_blocking,
use_stateful_dataloader=self.use_stateful_dataloader,
torch_device_mesh=self.state.torch_tp_plugin.torch_device_mesh if self.state.torch_tp_plugin else None,
torch_device_mesh=nd_device_mesh,
)
self._dataloaders.append(prepared_data_loader)
return prepared_data_loader
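Note that the dataloader now receives the full n-d mesh rather than only the TP mesh. Illustrative only, not the PR's dataloader logic: with a ("dp", "fsdp", "tp") mesh and the default row-major rank layout (tp fastest-varying), the batch shard seen by a rank could be derived as below, so that ranks differing only in their TP coordinate read identical data.

import torch.distributed as dist

def dp_shard_index(tp_size: int) -> tuple[int, int]:
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_shards = world_size // tp_size  # one distinct shard per (dp, fsdp) coordinate
    shard_index = rank // tp_size       # ranks that differ only in tp share a shard
    return shard_index, num_shards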
@@ -2417,6 +2495,14 @@ def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
for model in self._models:
if parameters == [p for p in model.parameters()]:
return model.clip_grad_norm_(max_norm, norm_type)
elif self.distributed_type == DistributedType.FSDP2:
self.unscale_gradients()
parameters = [p for p in parameters]
for model in self._models:
if parameters == [p for p in model.parameters()]:
return torch.nn.utils.clip_grad_norm_(
parameters=parameters, max_norm=max_norm, norm_type=norm_type
)
elif self.distributed_type == DistributedType.DEEPSPEED:
# `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed
# We cannot return the gradient norm because DeepSpeed does it.
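Unlike the FSDP (v1) branch, which routes through the module-level model.clip_grad_norm_, the FSDP2 branch above falls back to the standard torch.nn.utils.clip_grad_norm_; FSDP2 parameters and gradients are DTensors, which the plain utility is expected to handle.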
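A minimal end-to-end usage sketch of the new API surface. It assumes FullyShardedDataParallelPlugin2 can be constructed with defaults and that TorchTensorParallelPlugin accepts tp_size; the checkpoint name is a placeholder. A causal LM is used because the per-layer sharding loop above assumes a transformers-style model.model.layers layout.

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin2, TorchTensorParallelPlugin
from transformers import AutoModelForCausalLM

accelerator = Accelerator(
    fsdp2_plugin=FullyShardedDataParallelPlugin2(),
    torch_tp_plugin=TorchTensorParallelPlugin(tp_size=2),  # together these resolve to DistributedType.FSDP2_TP
)
model = AutoModelForCausalLM.from_pretrained("org/decoder-only-model")  # placeholder checkpoint
model = accelerator.prepare(model)  # applies tensor_parallel, then per-layer fully_shard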
61 changes: 61 additions & 0 deletions src/accelerate/commands/launch.py
@@ -73,6 +73,7 @@
"tpu": "TPU",
"use_deepspeed": "DeepSpeed Arguments",
"use_fsdp": "FSDP Arguments",
"use_fsdp2": "FSDP2 Arguments",
"use_tp": "PyTorch TP Arguments",
"use_megatron_lm": "Megatron-LM Arguments",
"fp8_backend": "FP8 Arguments",
@@ -262,6 +263,12 @@ def launch_command_parser(subparsers=None):
action="store_true",
help="Whether to use fsdp.",
)
paradigm_args.add_argument(
"--use_fsdp2",
default=False,
action="store_true",
help="Whether to use fsdpv2.",
)
paradigm_args.add_argument(
"--use_tp",
default=False,
@@ -604,6 +611,56 @@ def launch_command_parser(subparsers=None):
help="PyTorch Tensor Parallelism (TP) degree. Set a value greater than 1 to activate. (useful only when `use_tp` flag is passed)",
)

# fsdp2 args
fsdp2_args = parser.add_argument_group(
"FSDP2 Arguments", "Arguments related to Fully Shared Data Parallelism Version 2."
)
fsdp2_args.add_argument(
"--fsdp2_reshard_after_forward",
default="true",
type=str,
help="Decides Whether (true|false) to reshard parameters after forward pass.",
)
fsdp2_args.add_argument(
"--fsdp2_cpu_offload",
default="false",
type=str,
help="Decides Whether (true|false) to offload to CPU.",
)
fsdp2_args.add_argument(
"--fsdp2_cpu_offload_pin_memory",
default="false",
type=str,
help="Decides Whether (true|false) to pin memory during CPU offload.",
)
fsdp2_args.add_argument(
"--fsdp2_mp_param_dtype",
default="no",
type=str,
choices=["no", "fp16", "bf16", "fp8"],
help="Parameter datatype to be used in mixed precision training.",
)
fsdp2_args.add_argument(
"--fsdp2_mp_reduce_dtype",
default="no",
type=str,
choices=["no", "fp16", "bf16", "fp8"],
help="Dtype for gradient reduction to be used in mixed precision training.",
)
fsdp2_args.add_argument(
"--fsdp2_mp_output_dtype",
default="no",
type=str,
choices=["no", "fp16", "bf16", "fp8"],
help="Dtype for forward outputs to be used in mixed precision training.",
)
fsdp2_args.add_argument(
"--fsdp2_cast_forward_inputs",
default="false",
type=str,
help="Decides Whether (true|false) to cast forward inputs in mixed precision training.",
)

# megatron_lm args
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
megatron_lm_args.add_argument(
@@ -1004,6 +1061,7 @@ def _validate_launch_command(args):
and not args.tpu_use_cluster
and not args.use_deepspeed
and not args.use_fsdp
and not args.use_fsdp2
and not args.use_tp
and not args.use_megatron_lm
):
@@ -1023,6 +1081,7 @@
args.tpu = defaults.distributed_type == DistributedType.XLA
args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
args.use_tp = defaults.distributed_type == DistributedType.TP
args.use_fsdp2 = defaults.distributed_type == DistributedType.FSDP2
args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
if args.gpu_ids is None:
Expand Down Expand Up @@ -1182,6 +1241,8 @@ def launch_command(args):
deepspeed_launcher(args)
elif args.use_fsdp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_fsdp2 and not args.cpu:
multi_gpu_launcher(args)
elif args.use_tp and not args.cpu:
multi_gpu_launcher(args)
elif args.use_megatron_lm and not args.cpu:
6 changes: 6 additions & 0 deletions src/accelerate/state.py
@@ -851,6 +851,7 @@ def __init__(
dynamo_plugin=None,
deepspeed_plugin=None,
fsdp_plugin=None,
fsdp2_plugin=None,
torch_tp_plugin=None,
megatron_lm_plugin=None,
_from_accelerator: bool = False,
@@ -867,6 +868,7 @@
self.deepspeed_plugins = None
self.use_ipex = None
self.torch_tp_plugin = torch_tp_plugin
self.fsdp2_plugin = fsdp2_plugin
mixed_precision = (
parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no")
if mixed_precision is None
@@ -926,6 +928,10 @@ def __init__(
self.megatron_lm_plugin = megatron_lm_plugin
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.TP
if os.environ.get("ACCELERATE_USE_FSDP2", "false") == "true" or self.fsdp2_plugin is not None:
self.distributed_type = DistributedType.FSDP2
if os.environ.get("ACCELERATE_USE_TP", "false") == "true" or self.torch_tp_plugin is not None:
self.distributed_type = DistributedType.FSDP2_TP
elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
if is_ipex_available():
# check if user disables it explicitly
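Net effect of the block above (indentation is not preserved in this view; the inner TP check is presumably nested inside the FSDP2 check): ACCELERATE_USE_FSDP2 alone resolves to DistributedType.FSDP2, FSDP2 combined with ACCELERATE_USE_TP (or a torch_tp_plugin) resolves to DistributedType.FSDP2_TP, and TP alone keeps DistributedType.TP.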
2 changes: 2 additions & 0 deletions src/accelerate/utils/__init__.py
@@ -44,6 +44,7 @@
DynamoBackend,
FP8RecipeKwargs,
FullyShardedDataParallelPlugin,
FullyShardedDataParallelPlugin2,
GradientAccumulationPlugin,
GradScalerKwargs,
InitProcessGroupKwargs,
@@ -272,3 +273,4 @@
convert_model,
has_transformer_engine_layers,
)
from .pytorch_utils import prepare_nd_device_mesh