
Commit 9264186

sarckk authored and facebook-github-bot committed
Enable prefetch stage for StagedTrainPipeline (#2239)
Summary: Pull Request resolved: #2239

Add the ability to run prefetch as a stage in `StagedTrainPipeline`.

Recommended usage for a 3-stage pipeline with data copy, sparse dist, and prefetch steps (changes required shown with arrows):

```
sdd = SparseDataDistUtil(
    model=self._model,
    data_dist_stream=torch.cuda.Stream(),
    prefetch_stream=torch.cuda.Stream(),  # <--- define prefetch stream
)

pipeline = [
    PipelineStage(
        name="data_copy",
        runnable=lambda batch, context: batch.to(
            self._device, non_blocking=True
        ),
        stream=torch.cuda.Stream(),
    ),
    PipelineStage(
        name="start_sparse_data_dist",
        runnable=sdd.start_sparse_data_dist,
        stream=sdd.data_dist_stream,
        fill_callback=sdd.wait_sparse_data_dist,
    ),
    PipelineStage(
        name="prefetch",
        runnable=sdd.prefetch,  # <--- add stage with runnable=sdd.prefetch
        stream=sdd.prefetch_stream,
        fill_callback=sdd.load_prefetch,  # <--- fill_callback of sdd.load_prefetch
    ),
]

return StagedTrainPipeline(pipeline_stages=pipeline)
```

Order of execution for the above pipeline:

Iteration #1:

- `_fill_pipeline()`:
  - batch 0: memcpy, start_sdd, wait_sdd (callback), prefetch, load_prefetch (callback)
  - batch 1: memcpy, start_sdd, wait_sdd (callback)
  - batch 2: memcpy
- `progress()`:
  - batch 3: memcpy
  - batch 2: start_sdd
  - batch 1: prefetch
- after pipeline `progress()`:
  - model(batch 0)
  - load_prefetch (prepares for model fwd on batch 1)
  - wait_sdd (prepares for batch 2 prefetch)

Iteration #2:

- `progress()`:
  - batch 4: memcpy
  - batch 3: start_sdd
  - batch 2: prefetch
- after pipeline `progress()`:
  - model(batch 1)
  - load_prefetch (prepares for model fwd on batch 2)
  - wait_sdd (prepares for batch 3 prefetch)

Reviewed By: zzzwen, joshuadeng

Differential Revision: D59786807

fbshipit-source-id: 6261c07cd6823bc541463d24ff867ab0e43631ea
1 parent 09d1ff2 commit 9264186
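
For orientation, here is a minimal sketch of driving such a pipeline in a training loop, mirroring the new test added below; `data`, `sharded_model`, and `optimizer` are placeholders for the caller's own setup:

```
# Minimal sketch, mirroring the new test below. `data`, `sharded_model`,
# and `optimizer` stand in for whatever the caller has already built.
pipeline = StagedTrainPipeline(
    pipeline_stages=pipeline_stages,
    compute_stream=torch.cuda.current_stream(),
)
dataloader = iter(data)

# progress() returns the next fully staged model input, or None once the
# dataloader is exhausted and the pipeline has drained.
while model_in := pipeline.progress(dataloader):
    optimizer.zero_grad()
    loss, pred = sharded_model(model_in)
    loss.backward()
    optimizer.step()
```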

File tree

3 files changed: +327 −50 lines changed


torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 136 additions & 6 deletions
```
@@ -1667,7 +1667,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1695,7 +1695,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1744,7 +1744,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1762,7 +1762,7 @@ def gpu_preproc(x: StageOut) -> StageOut:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1860,7 +1860,7 @@ def test_model_detach(self) -> None:
 
         sdd = SparseDataDistUtil[ModelInput](
             model=sharded_model_pipelined,
-            stream=torch.cuda.Stream(),
+            data_dist_stream=torch.cuda.Stream(),
             apply_jit=False,
         )
 
@@ -1873,7 +1873,7 @@ def test_model_detach(self) -> None:
             PipelineStage(
                 name="start_sparse_data_dist",
                 runnable=sdd.start_sparse_data_dist,
-                stream=sdd.stream,
+                stream=sdd.data_dist_stream,
                 fill_callback=sdd.wait_sparse_data_dist,
             ),
         ]
@@ -1964,3 +1964,133 @@ def test_model_detach(self) -> None:
         # Check pipeline exhausted
         preproc_input = pipeline.progress(dataloader)
         self.assertIsNone(preproc_input)
+
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    @settings(max_examples=4, deadline=None)
+    # pyre-ignore[56]
+    @given(
+        sharding_type=st.sampled_from(
+            [
+                ShardingType.TABLE_WISE.value,
+            ]
+        ),
+        kernel_type=st.sampled_from(
+            [
+                EmbeddingComputeKernel.FUSED_UVM_CACHING.value,
+            ]
+        ),
+        cache_precision=st.sampled_from(
+            [
+                DataType.FP16,
+                DataType.FP32,
+            ]
+        ),
+        load_factor=st.sampled_from(
+            [
+                0.2,
+                0.4,
+            ]
+        ),
+    )
+    def test_pipelining_prefetch(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        cache_precision: DataType,
+        load_factor: float,
+    ) -> None:
+        model = self._setup_model()
+
+        fused_params = {
+            "cache_load_factor": load_factor,
+            "cache_precision": cache_precision,
+            "stochastic_rounding": False,  # disable non-deterministic behavior when converting fp32<->fp16
+        }
+        fused_params_pipelined = {
+            **fused_params,
+            "prefetch_pipeline": True,
+        }
+
+        sharded_model, optim = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params
+        )
+        (
+            sharded_model_pipelined,
+            optim_pipelined,
+        ) = self._generate_sharded_model_and_optimizer(
+            model, sharding_type, kernel_type, fused_params_pipelined
+        )
+
+        copy_state_dict(
+            sharded_model.state_dict(), sharded_model_pipelined.state_dict()
+        )
+
+        num_batches = 12
+        data = self._generate_data(
+            num_batches=num_batches,
+            batch_size=32,
+        )
+
+        non_pipelined_outputs = []
+        for batch in data:
+            batch = batch.to(self.device)
+            optim.zero_grad()
+            loss, pred = sharded_model(batch)
+            loss.backward()
+            optim.step()
+            non_pipelined_outputs.append(pred)
+
+        def gpu_preproc(x: StageOut) -> StageOut:
+            return x
+
+        sdd = SparseDataDistUtil[ModelInput](
+            model=sharded_model_pipelined,
+            data_dist_stream=torch.cuda.Stream(),
+            apply_jit=False,
+            prefetch_stream=torch.cuda.Stream(),
+        )
+
+        pipeline_stages = [
+            PipelineStage(
+                name="data_copy",
+                runnable=partial(get_h2d_func, device=self.device),
+                stream=torch.cuda.Stream(),
+            ),
+            PipelineStage(
+                name="start_sparse_data_dist",
+                runnable=sdd.start_sparse_data_dist,
+                stream=sdd.data_dist_stream,
+                fill_callback=sdd.wait_sparse_data_dist,
+            ),
+            PipelineStage(
+                name="prefetch",
+                runnable=sdd.prefetch,
+                # pyre-ignore
+                stream=sdd.prefetch_stream,
+                fill_callback=sdd.load_prefetch,
+            ),
+        ]
+        pipeline = StagedTrainPipeline(
+            pipeline_stages=pipeline_stages, compute_stream=torch.cuda.current_stream()
+        )
+        dataloader = iter(data)
+
+        pipelined_out = []
+        num_batches_processed = 0
+
+        while model_in := pipeline.progress(dataloader):
+            num_batches_processed += 1
+            optim_pipelined.zero_grad()
+            loss, pred = sharded_model_pipelined(model_in)
+            loss.backward()
+            optim_pipelined.step()
+            pipelined_out.append(pred)
+
+        self.assertEqual(num_batches_processed, num_batches)
+
+        self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
+        for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
+            torch.testing.assert_close(out, ref_out)
```

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 11 additions & 38 deletions
```
@@ -33,6 +33,7 @@
 from torchrec.distributed.train_pipeline.utils import (
     _override_input_dist_forwards,
     _pipeline_detach_model,
+    _prefetch_embeddings,
     _rewrite_model,
     _start_data_dist,
     _start_embedding_lookup,
@@ -1101,46 +1102,18 @@ def _prefetch(self, batch: Optional[In]) -> None:
             batch.record_stream(
                 torch.get_device_module(self._device).current_stream()
             )
+            data_per_pipelined_module = _prefetch_embeddings(
+                batch,
+                self._context,
+                self._pipelined_modules,
+                self._device,
+                self._stream_context,
+                self._data_dist_stream,
+                self._default_stream,
+            )
             for sharded_module in self._pipelined_modules:
                 forward = sharded_module.forward
-                assert isinstance(forward, PrefetchPipelinedForward)
-
-                assert forward._name in self._context.input_dist_tensors_requests
-                request = self._context.input_dist_tensors_requests.pop(
-                    forward._name
-                )
-                assert isinstance(request, Awaitable)
-                with record_function("## wait_sparse_data_dist ##"):
-                    # Finish waiting on the dist_stream,
-                    # in case some delayed stream scheduling happens during the wait() call.
-                    with self._stream_context(self._data_dist_stream):
-                        data = request.wait()
-
-                # Make sure that both result of input_dist and context
-                # are properly transferred to the current stream.
-                module_context = self._context.module_contexts[forward._name]
-                if self._data_dist_stream is not None:
-                    torch.get_device_module(
-                        self._device
-                    ).current_stream().wait_stream(self._data_dist_stream)
-                    cur_stream = torch.get_device_module(
-                        self._device
-                    ).current_stream()
-
-                    assert isinstance(
-                        data, (torch.Tensor, Multistreamable)
-                    ), f"{type(data)} must implement Multistreamable interface"
-                    data.record_stream(cur_stream)
-                    data.record_stream(self._default_stream)
-
-                    module_context.record_stream(cur_stream)
-                    module_context.record_stream(self._default_stream)
-
-                sharded_module.prefetch(
-                    ctx=module_context,
-                    dist_input=data,
-                    forward_stream=self._default_stream,
-                )
+                data = data_per_pipelined_module[forward._name]
                 self._context.module_input_post_prefetch[forward._name] = data
                 self._context.module_contexts_post_prefetch[forward._name] = (
                     self._context.module_contexts.pop(forward._name)
```
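
For readers tracing the refactor: the per-module wait/record-stream/prefetch logic deleted above moved into the new `_prefetch_embeddings` helper in `utils.py`. Below is a condensed sketch of its plausible shape, reconstructed from the removed lines; the actual helper in torchrec may differ in details (e.g. profiling scopes and assertions are omitted here):

```
# Sketch only: reconstructed from the inlined code removed above, not copied
# from torchrec/distributed/train_pipeline/utils.py.
def _prefetch_embeddings(
    batch, context, pipelined_modules, device,
    stream_context, data_dist_stream, default_stream,
):
    data_per_pipelined_module = {}
    for sharded_module in pipelined_modules:
        forward = sharded_module.forward
        # Finish waiting on the module's input dist request on the
        # data-dist stream.
        request = context.input_dist_tensors_requests.pop(forward._name)
        with stream_context(data_dist_stream):
            data = request.wait()

        # Hand the dist output and module context over to the current
        # and default streams before prefetching.
        module_context = context.module_contexts[forward._name]
        if data_dist_stream is not None:
            cur_stream = torch.get_device_module(device).current_stream()
            cur_stream.wait_stream(data_dist_stream)
            data.record_stream(cur_stream)
            data.record_stream(default_stream)
            module_context.record_stream(cur_stream)
            module_context.record_stream(default_stream)

        # Kick off the embedding (UVM-cache) prefetch for this module.
        sharded_module.prefetch(
            ctx=module_context,
            dist_input=data,
            forward_stream=default_stream,
        )
        data_per_pipelined_module[forward._name] = data
    return data_per_pipelined_module
```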
