@@ -1667,7 +1667,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1695,7 +1695,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1744,7 +1744,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1762,7 +1762,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1860,7 +1860,7 @@ def test_model_detach(self) -> None:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1873,7 +1873,7 @@ def test_model_detach(self) -> None:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1964,3 +1964,133 @@ def test_model_detach(self) -> None:
         # Check pipeline exhausted
         preproc_input = pipeline.progress(dataloader)
         self.assertIsNone(preproc_input)
+
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+            ]
+        ),
+    )
+    def test_pipelining_prefetch(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        cache_precision: DataType,
+        load_factor: float,
+    ) -> None:
+        model = self._setup_model()
+
+        fused_params = {
+            "cache_load_factor": load_factor,
+            "cache_precision": cache_precision,
+            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
+        }
+        fused_params_pipelined = {
+            **fused_params,
+            "prefetch_pipeline": True,
+        }
+
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        num_batches = 12
+        data = self._generate_data(
+            num_batches=num_batches,
+            batch_size=32,
+        )
+
+        non_pipelined_outputs = []
+        for batch in data:
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+            non_pipelined_outputs.append(pred)
+
+        def gpu_preproc(x: StageOut) -> StageOut:
+            return x
+
+        sdd = SparseDataDistUtil[ModelInput](
+            model=sharded_model_pipelined,
+            data_dist_stream=torch.cuda.Stream(),
+            apply_jit=False,
+            prefetch_stream=torch.cuda.Stream(),
+        )
+
+        pipeline_stages = [
+            PipelineStage(
+                name="data_copy",
+                runnable=partial(get_h2d_func, device=self.device),
+                stream=torch.cuda.Stream(),
+            ),
+            PipelineStage(
+                name="start_sparse_data_dist",
+                runnable=sdd.start_sparse_data_dist,
+                stream=sdd.data_dist_stream,
+                fill_callback=sdd.wait_sparse_data_dist,
+            ),
+            PipelineStage(
+                name="prefetch",
+                runnable=sdd.prefetch,
+                # pyre-ignore
+                stream=sdd.prefetch_stream,
+                fill_callback=sdd.load_prefetch,
+            ),
+        ]
+        pipeline = StagedTrainPipeline(
+            pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
+        )
+        dataloader = iter(data)
+
+        pipelined_out = []
+        num_batches_processed = 0
+
+        while model_in := pipeline.progress(dataloader):
+            num_batches_processed += 1
+            optim_pipelined.zero_grad()
+            loss, pred = sharded_model_pipelined(model_in)
+            loss.backward()
+            optim_pipelined.step()
+            pipelined_out.append(pred)
+
+        self.assertEqual(num_batches_processed, num_batches)
+
+        self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
+        for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
+            torch.testing.assert_close(out, ref_out)
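For reference, the sketch below shows how a caller outside the test harness might wire up the renamed `data_dist_stream` keyword together with the new prefetch stage. It is a minimal illustration only, not part of this diff: the model, device, and host-to-device copy callable are assumed to be supplied by the caller, and the import paths may differ between TorchRec versions.

```python
# Illustrative sketch only -- assumes a sharded TorchRec model and an H2D copy
# callable are provided by the caller; import paths may vary by TorchRec version.
from functools import partial
from typing import Callable

import torch
from torchrec.distributed.train_pipeline import (
    PipelineStage,
    SparseDataDistUtil,
    StagedTrainPipeline,
)


def build_prefetch_pipeline(
    sharded_model: torch.nn.Module,
    device: torch.device,
    h2d_func: Callable,  # placeholder for a host-to-device copy helper
) -> StagedTrainPipeline:
    # The constructor keyword is now `data_dist_stream` (previously `stream`);
    # passing `prefetch_stream` enables the additional prefetch stage.
    sdd = SparseDataDistUtil(
        model=sharded_model,
        data_dist_stream=torch.cuda.Stream(),
        apply_jit=False,
        prefetch_stream=torch.cuda.Stream(),
    )
    stages = [
        PipelineStage(
            name="data_copy",
            runnable=partial(h2d_func, device=device),
            stream=torch.cuda.Stream(),
        ),
        PipelineStage(
            name="start_sparse_data_dist",
            runnable=sdd.start_sparse_data_dist,
            stream=sdd.data_dist_stream,  # renamed attribute, matching the keyword
            fill_callback=sdd.wait_sparse_data_dist,
        ),
        PipelineStage(
            name="prefetch",
            runnable=sdd.prefetch,
            stream=sdd.prefetch_stream,
            fill_callback=sdd.load_prefetch,
        ),
    ]
    return StagedTrainPipeline(
        pipeline_stages=stages, compute_stream=torch.cuda.current_stream()
    )
```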