@@ -1850,13 +1850,12 @@ class FullyShardedDataParallelPlugin2:
     reshard_after_forward: Optional[bool] = field(
         default=None,
         metadata={
-            "help":
-            "If reshard_after_forward is True, the parameters are sharded on every forward pass and all-gathered during backward pass."
-            "reshard_after_forward in conjunction with device mesh dimension would mean different strategies like the following:"
-            "reshard_after_forward=True and 1D device mesh mean full shard"
-            "reshard_after_forward=True and 2D device mesh mean hybrid shard"
-            "reshard_after_forward=False and 1D device mesh mean shard grad and optimizer states only"
-            "reshard_after_forward=False and 2D device mesh mean hybrid shard with grad and optim sharding only"
+            "help": "If reshard_after_forward is True, the parameters are sharded on every forward pass and all-gathered during the backward pass. "
+            "reshard_after_forward in conjunction with the device mesh dimension selects the sharding strategy: "
+            "reshard_after_forward=True and a 1D device mesh means full shard; "
+            "reshard_after_forward=True and a 2D device mesh means hybrid shard; "
+            "reshard_after_forward=False and a 1D device mesh means shard grad and optimizer states only; "
+            "reshard_after_forward=False and a 2D device mesh means hybrid shard with grad and optim sharding only."
         },
     )
     offload_policy: Optional[Union[dict, "torch.distributed._composable.OffloadPolicy"]] = field(
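For context on the help text added in the hunk above: the strategy is picked by the combination of `reshard_after_forward` and the dimensionality of the mesh handed to `fully_shard`. A minimal sketch, not part of this diff, assuming a 16-GPU (2 nodes x 8 GPUs) job launched with torchrun and a PyTorch version that ships `torch.distributed._composable.fsdp.fully_shard`:

```python
# Sketch only: how reshard_after_forward and mesh dimensionality combine.
# The 2x8 topology and the toy model are assumptions for illustration.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._composable.fsdp import fully_shard

model = torch.nn.Linear(1024, 1024).cuda()

# reshard_after_forward=True + 1D mesh -> full shard (params resharded after every forward)
mesh_1d = init_device_mesh("cuda", (16,), mesh_dim_names=("fsdp",))
fully_shard(model, mesh=mesh_1d, reshard_after_forward=True)

# reshard_after_forward=True + 2D mesh -> hybrid shard
# (shard within a node, replicate across nodes)
# mesh_2d = init_device_mesh("cuda", (2, 8), mesh_dim_names=("dp", "fsdp"))
# fully_shard(model, mesh=mesh_2d, reshard_after_forward=True)
```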
@@ -1865,45 +1864,45 @@ class FullyShardedDataParallelPlugin2:
             "help": "A config to enable CPU offload. If passing in a `dict`, it should have the following key: `pin_memory`."
         },
     )
-    mp_policy: Optional[Union[dict, "torch.distributed._composable.MixedPrecisionPolicy"]] = (
-        field(
-            default=None,
-            metadata={
-                "help": "A config to enable mixed precision training with FullyShardedDataParallelv2. If passing in a `dict`, it"
+    mp_policy: Optional[Union[dict, "torch.distributed._composable.MixedPrecisionPolicy"]] = field(
+        default=None,
+        metadata={
+            "help": "A config to enable mixed precision training with FullyShardedDataParallelv2. If passing in a `dict`, it "
             "should have the following keys: `param_dtype`, `reduce_dtype`, `output_dtype`, and `cast_forward_inputs`. "
-        },
-    )
+        },
     )
     ignored_params: Optional[set["torch.nn.Parameter"]] = field(
         default=None,
-        metadata={
-            "help": "The set of parameters that we don't want to shard with FSDP."
-        },
+        metadata={"help": "The set of parameters that we don't want to shard with FSDP."},
     )
 
     def __post_init__(self):
         env_prefix = "FSDP2_"
         if self.reshard_after_forward is None:
             self.reshard_after_forward = str_to_bool(os.environ.get(env_prefix + "RESHARD_AFTER_FORWARD", "True")) == 1
-
+
         self.set_offload_policy()
         self.set_mp_policy()
 
         from torch.distributed.device_mesh import init_device_mesh
+
         dp_mesh_dim_name = "dp"
         fsdp_mesh_dim_name = "fsdp"
         device = "cuda"  # support for other devices has to be investigated
         num_nodes = torch.distributed.get_world_size() // torch.cuda.device_count()
         nproc_per_node = torch.cuda.device_count()
-        self.torch_device_mesh = init_device_mesh(device, (num_nodes,nproc_per_node), mesh_dim_names=(dp_mesh_dim_name, fsdp_mesh_dim_name))
+        self.torch_device_mesh = init_device_mesh(
+            device, (num_nodes, nproc_per_node), mesh_dim_names=(dp_mesh_dim_name, fsdp_mesh_dim_name)
+        )
 
     def set_offload_policy(self, pin_memory=None):
         """
         Set the offload policy
         """
         from torch.distributed._composable.fsdp import CPUOffloadPolicy
+
         env_prefix = "FSDP2_"
-
+
         if self.offload_policy is None:
             fsdp2_cpu_offload = str_to_bool(os.environ.get(env_prefix + "CPU_OFFLOAD", "False")) == 1
             if fsdp2_cpu_offload:
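A rough sketch of what the `__post_init__` in this hunk ends up building, assuming a 2-node x 8-GPU torchrun job with the process group already initialized (the numbers are illustrative, not from the diff):

```python
# Illustrative sketch of the 2D (dp, fsdp) mesh construction done in __post_init__ above.
# Assumes torchrun has set up a 2-node x 8-GPU default process group.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

num_nodes = dist.get_world_size() // torch.cuda.device_count()  # e.g. 16 // 8 == 2
nproc_per_node = torch.cuda.device_count()                      # e.g. 8

# Replicate across nodes ("dp"), shard within a node ("fsdp").
mesh = init_device_mesh("cuda", (num_nodes, nproc_per_node), mesh_dim_names=("dp", "fsdp"))
print(mesh["fsdp"].size())  # 8 -> size of the intra-node sharding group
```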
@@ -1918,6 +1917,7 @@ def set_mp_policy(self, param_dtype=None, reduce_dtype=None, output_dtype=None,
         Set mixed precision policy
         """
         from torch.distributed._composable.fsdp import MixedPrecisionPolicy
+
         env_prefix = "FSDP2_"
         mixed_precision_mapping = {
             "fp8": torch.bfloat16,
@@ -1926,18 +1926,22 @@ def set_mp_policy(self, param_dtype=None, reduce_dtype=None, output_dtype=None,
             "fp32": torch.float32,
         }
 
-        # current_env["FSDP2_MP_PARAM_DTYPE"] = str(args.fsdp2_mp_param_dtype).lower()
-        # current_env["FSDP2_MP_REDUCE_DTYPE"] = str(args.fsdp2_mp_reduce_dtype).lower()
-        # current_env["FSDP2_MP_OUTPUT_DTYPE"] = str(args.fsdp2_mp_output_dtype).lower()
-        # current_env["FSDP2_CAST_FORWARD_INPUTS"] = str(args.fsdp2_cast_forward_inputs).lower()
-
         if self.mp_policy is None:
             param_dtype = mixed_precision_mapping.get(os.environ.get(env_prefix + "MP_PARAM_DTYPE", param_dtype), None)
-            reduce_dtype = mixed_precision_mapping.get(os.environ.get(env_prefix + "MP_REDUCE_DTYPE", reduce_dtype), None)
-            output_dtype = mixed_precision_mapping.get(os.environ.get(env_prefix + "MP_OUTPUT_DTYPE", output_dtype), None)
+            reduce_dtype = mixed_precision_mapping.get(
+                os.environ.get(env_prefix + "MP_REDUCE_DTYPE", reduce_dtype), None
+            )
+            output_dtype = mixed_precision_mapping.get(
+                os.environ.get(env_prefix + "MP_OUTPUT_DTYPE", output_dtype), None
+            )
             if not cast_forward_inputs:
                 cast_forward_inputs = str_to_bool(os.environ.get(env_prefix + "CAST_FORWARD_INPUTS", "False")) == 1
-            self.mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype, output_dtype=output_dtype, cast_forward_inputs=cast_forward_inputs)
+            self.mp_policy = MixedPrecisionPolicy(
+                param_dtype=param_dtype,
+                reduce_dtype=reduce_dtype,
+                output_dtype=output_dtype,
+                cast_forward_inputs=cast_forward_inputs,
+            )
 
         if isinstance(self.mp_policy, dict):
             self.mp_policy = MixedPrecisionPolicy(**self.mp_policy)
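The env-var resolution in `set_mp_policy` above boils down to something like the following sketch. It is not part of the diff; only the mapping entries visible in this hunk are reproduced (the plugin's full mapping covers more dtypes), and the environment values set here are just an example:

```python
# Sketch of the FSDP2_* dtype resolution performed in set_mp_policy above.
import os
import torch
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

# Abridged mapping: only the entries shown in this hunk.
mapping = {"fp8": torch.bfloat16, "fp32": torch.float32}
os.environ.setdefault("FSDP2_MP_PARAM_DTYPE", "fp32")  # example value

policy = MixedPrecisionPolicy(
    param_dtype=mapping.get(os.environ.get("FSDP2_MP_PARAM_DTYPE"), None),    # torch.float32
    reduce_dtype=mapping.get(os.environ.get("FSDP2_MP_REDUCE_DTYPE"), None),  # None -> FSDP default
    output_dtype=mapping.get(os.environ.get("FSDP2_MP_OUTPUT_DTYPE"), None),  # None -> no output cast
    cast_forward_inputs=os.environ.get("FSDP2_CAST_FORWARD_INPUTS", "False").lower() == "true",
)
```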
@@ -1973,47 +1977,6 @@ def __post_init__(self):
         self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,))
 
 
-@dataclass
-class DeviceMeshHandler:
-    """
-    This handler is used to create and hold device mesh state throughout the training
-    and dynamically support any combination of parallelisms.
-    """
-
-    tp_size: int = field(
-        default=1,
-        metadata={"help": "tensor parallel size will be used in the device mesh preparation with other parallelisms."},
-    )
-
-    use_fsdp: bool = field(
-        default=False,
-        metadata={"help": "fsdp v2 will be used with other parallelisms for device mesh preparation."},
-    )
-
-    torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None)
-
-    def __post_init__(self):
-        self.tp_size = self.tp_size if os.environ.get("TP_SIZE", "1") == "1" else int(os.environ.get("TP_SIZE", "1"))
-        self.use_fsdp = self.use_fsdp or str_to_bool(os.environ.get("ACCELERATE_USE_FSDP2", "False")) == 1
-        if self.tp_size == 1:
-            raise ValueError("Provide TP degree > 1.")
-
-        if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION):
-            raise ValueError(
-                f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel."
-            )
-        from torch.distributed.device_mesh import init_device_mesh
-
-        mesh_dim_names = ("tp",)
-        mesh_dims = (self.tp_size,)
-        if self.use_fsdp:
-            num_nodes = torch.distributed.get_world_size() // torch.cuda.device_count()
-            nproc_per_node = torch.cuda.device_count()
-            mesh_dim_names = ("dp", "fsdp",) + mesh_dim_names
-            mesh_dims = (num_nodes, nproc_per_node,) + mesh_dims
-        device = "cuda"  # support for other devices has to be investigated
-        self.torch_device_mesh = init_device_mesh(device, mesh_dims, mesh_dim_names=mesh_dim_names)
-
 @dataclass
 class MegatronLMPlugin:
     """