
Commit 919bbcb

sarckk authored and facebook-github-bot committed
Support MCH for semi-sync (assuming no eviction) (#2753)
Summary:
Pull Request resolved: #2753

ZCH modules return a tuple of awaitables for embeddings and remapped KJTs. Update semi-sync training code to account for this.

Reviewed By: dstaay-fb

Differential Revision: D69861054

fbshipit-source-id: 2bec964209ce84e973e2c37c8aa7465f129a1e24
1 parent 0108153 commit 919bbcb
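
For readers unfamiliar with the change, the sketch below (illustration only, not TorchRec pipeline code; the _Awaitable stand-in and _wait_sparse_output helper are invented here) shows the shape difference the semi-sync pipeline now has to handle: a plain sharded EmbeddingBagCollection produces a single awaitable of embeddings, whereas a ZCH/managed-collision module produces a tuple of awaitables, one for the embeddings and one for the remapped KeyedJaggedTensor.

from typing import Any, Tuple, Union


class _Awaitable:
    # Stand-in for a torchrec awaitable, used only in this sketch.
    def __init__(self, value: Any) -> None:
        self._value = value

    def wait(self) -> Any:
        return self._value


def _wait_sparse_output(out: Union[_Awaitable, Tuple[_Awaitable, _Awaitable]]) -> Any:
    if isinstance(out, tuple):
        # ZCH / MC-EBC case: (embeddings awaitable, remapped-KJT awaitable)
        embeddings = out[0].wait()
        remapped_kjt = out[1].wait()
        return embeddings, remapped_kjt
    # Plain EBC case: a single awaitable of embeddings
    return out.wait()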

File tree

6 files changed: +273, -20 lines changed

torchrec/distributed/composable/tests/test_ddp.py

Lines changed: 2 additions & 0 deletions
@@ -105,11 +105,13 @@ def _run(cls, rank: int, world_size: int, path: str) -> None:
                 weighted_tables=weighted_tables,
                 dense_device=ctx.device,
             )
+            # pyre-ignore
             m.sparse.ebc = trec_shard(
                 module=m.sparse.ebc,
                 device=ctx.device,
                 plan=column_wise(ranks=list(range(world_size))),
             )
+            # pyre-ignore
             m.sparse.weighted_ebc = trec_shard(
                 module=m.sparse.weighted_ebc,
                 device=ctx.device,

torchrec/distributed/composable/tests/test_fsdp.py

Lines changed: 2 additions & 0 deletions
@@ -83,11 +83,13 @@ def _run( # noqa
                 m.sparse.parameters(),
                 {"lr": 0.01},
             )
+            # pyre-ignore
             m.sparse.ebc = trec_shard(
                 module=m.sparse.ebc,
                 device=ctx.device,
                 plan=row_wise(),
             )
+            # pyre-ignore
             m.sparse.weighted_ebc = trec_shard(
                 module=m.sparse.weighted_ebc,
                 device=ctx.device,

torchrec/distributed/test_utils/test_model.py

Lines changed: 210 additions & 7 deletions
@@ -26,7 +26,18 @@
 )
 from torchrec.distributed.fused_embedding import FusedEmbeddingCollectionSharder
 from torchrec.distributed.fused_embeddingbag import FusedEmbeddingBagCollectionSharder
-from torchrec.distributed.types import QuantizedCommCodecs
+from torchrec.distributed.mc_embedding_modules import (
+    BaseManagedCollisionEmbeddingCollectionSharder,
+)
+from torchrec.distributed.mc_embeddingbag import (
+    ShardedManagedCollisionEmbeddingBagCollection,
+)
+from torchrec.distributed.mc_modules import ManagedCollisionCollectionSharder
+from torchrec.distributed.types import (
+    ParameterSharding,
+    QuantizedCommCodecs,
+    ShardingEnv,
+)
 from torchrec.distributed.utils import CopyableMixin
 from torchrec.modules.activation import SwishLayerNorm
 from torchrec.modules.embedding_configs import (
@@ -39,6 +50,12 @@
 from torchrec.modules.feature_processor import PositionWeightedProcessor
 from torchrec.modules.feature_processor_ import PositionWeightedModuleCollection
 from torchrec.modules.fp_embedding_modules import FeatureProcessedEmbeddingBagCollection
+from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
+from torchrec.modules.mc_modules import (
+    DistanceLFU_EvictionPolicy,
+    ManagedCollisionCollection,
+    MCHManagedCollisionModule,
+)
 from torchrec.modules.regroup import KTRegroupAsDict
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
 from torchrec.streamable import Pipelineable
@@ -1351,6 +1368,7 @@ def __init__(
         feature_processor_modules: Optional[Dict[str, torch.nn.Module]] = None,
         over_arch_clazz: Type[nn.Module] = TestOverArch,
         postproc_module: Optional[nn.Module] = None,
+        zch: bool = False,
     ) -> None:
         super().__init__(
             tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1362,12 +1380,20 @@ def __init__(
         if weighted_tables is None:
             weighted_tables = []
         self.dense = TestDenseArch(num_float_features, dense_device)
-        self.sparse = TestSparseArch(
-            tables,
-            weighted_tables,
-            sparse_device,
-            max_feature_lengths,
-        )
+        if zch:
+            self.sparse: nn.Module = TestSparseArchZCH(
+                tables,
+                weighted_tables,
+                torch.device("meta"),
+                return_remapped=True,
+            )
+        else:
+            self.sparse = TestSparseArch(
+                tables,
+                weighted_tables,
+                sparse_device,
+                max_feature_lengths,
+            )
 
         embedding_names = (
             list(embedding_groups.values())[0] if embedding_groups else None
@@ -1687,6 +1713,64 @@ def compute_kernels(
         return [self._kernel_type]
 
 
+class TestMCSharder(ManagedCollisionCollectionSharder):
+    def __init__(
+        self,
+        sharding_type: str,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        self._sharding_type = sharding_type
+        super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        return [self._sharding_type]
+
+
+class TestEBCSharderMCH(
+    BaseManagedCollisionEmbeddingCollectionSharder[
+        ManagedCollisionEmbeddingBagCollection
+    ]
+):
+    def __init__(
+        self,
+        sharding_type: str,
+        kernel_type: str,
+        fused_params: Optional[Dict[str, Any]] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__(
+            TestEBCSharder(
+                sharding_type, kernel_type, fused_params, qcomm_codecs_registry
+            ),
+            TestMCSharder(sharding_type, qcomm_codecs_registry),
+            qcomm_codecs_registry=qcomm_codecs_registry,
+        )
+
+    @property
+    def module_type(self) -> Type[ManagedCollisionEmbeddingBagCollection]:
+        return ManagedCollisionEmbeddingBagCollection
+
+    def shard(
+        self,
+        module: ManagedCollisionEmbeddingBagCollection,
+        params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedManagedCollisionEmbeddingBagCollection:
+        if device is None:
+            device = torch.device("cuda")
+        return ShardedManagedCollisionEmbeddingBagCollection(
+            module,
+            params,
+            # pyre-ignore [6]
+            ebc_sharder=self._e_sharder,
+            mc_sharder=self._mc_sharder,
+            env=env,
+            device=device,
+        )
+
+
 class TestFusedEBCSharder(FusedEmbeddingBagCollectionSharder):
     def __init__(
         self,
@@ -2188,3 +2272,122 @@ def forward(self, input: ModelInput) -> ModelInput:
         modified_input = copy.deepcopy(input)
         modified_input.idlist_features = self.fp_proc(modified_input.idlist_features)
         return modified_input
+
+
+class TestSparseArchZCH(nn.Module):
+    """
+    Basic nn.Module for testing MCH EmbeddingBagCollection
+
+    Args:
+        tables
+        weighted_tables
+        device
+        return_remapped
+
+    Call Args:
+        features
+        weighted_features
+        batch_size
+
+    Returns:
+        KeyedTensor
+
+    Example::
+
+        TestSparseArch()
+    """
+
+    def __init__(
+        self,
+        tables: List[EmbeddingBagConfig],
+        weighted_tables: List[EmbeddingBagConfig],
+        device: torch.device,
+        return_remapped: bool = False,
+    ) -> None:
+        super().__init__()
+        self._return_remapped = return_remapped
+
+        mc_modules = {}
+        for table in tables:
+            mc_modules[table.name] = MCHManagedCollisionModule(
+                zch_size=table.num_embeddings,
+                input_hash_size=4000,
+                device=device,
+                # TODO: If eviction interval is set to
+                # a low number (e.g. 2), semi-sync pipeline test will
+                # fail with in-place modification error during
+                # loss.backward(). This is because during semi-sync training,
+                # we run embedding module forward after autograd graph
+                # is constructed, but if MCH eviction happens, the
+                # variable used in autograd will have been modified
+                eviction_interval=1000,
+                eviction_policy=DistanceLFU_EvictionPolicy(),
+            )
+
+        self.ebc: ManagedCollisionEmbeddingBagCollection = (
+            ManagedCollisionEmbeddingBagCollection(
+                EmbeddingBagCollection(
+                    tables=tables,
+                    device=device,
+                ),
+                ManagedCollisionCollection(
+                    managed_collision_modules=mc_modules,
+                    embedding_configs=tables,
+                ),
+                return_remapped_features=self._return_remapped,
+            )
+        )
+
+        self.weighted_ebc: Optional[ManagedCollisionEmbeddingBagCollection] = None
+        if weighted_tables:
+            weighted_mc_modules = {}
+            for table in weighted_tables:
+                weighted_mc_modules[table.name] = MCHManagedCollisionModule(
+                    zch_size=table.num_embeddings,
+                    input_hash_size=4000,
+                    device=device,
+                    # TODO: Support MCH evictions during semi-sync
+                    eviction_interval=1000,
+                    eviction_policy=DistanceLFU_EvictionPolicy(),
+                )
+            self.weighted_ebc: ManagedCollisionEmbeddingBagCollection = (
+                ManagedCollisionEmbeddingBagCollection(
+                    EmbeddingBagCollection(
+                        tables=weighted_tables,
+                        device=device,
+                        is_weighted=True,
+                    ),
+                    ManagedCollisionCollection(
+                        managed_collision_modules=weighted_mc_modules,
+                        embedding_configs=weighted_tables,
+                    ),
+                    return_remapped_features=self._return_remapped,
+                )
+            )
+
+    def forward(
+        self,
+        features: KeyedJaggedTensor,
+        weighted_features: Optional[KeyedJaggedTensor] = None,
+        batch_size: Optional[int] = None,
+    ) -> KeyedTensor:
+        """
+        Runs forward and MC EBC and optionally, weighted MC EBC,
+        then merges the results into one KeyedTensor
+
+        Args:
+            features
+            weighted_features
+            batch_size
+        Returns:
+            KeyedTensor
+        """
+        ebc, _ = self.ebc(features)
+        ebc = _post_ebc_test_wrap_function(ebc)
+        w_ebc, _ = (
+            self.weighted_ebc(weighted_features)
+            if self.weighted_ebc is not None and weighted_features is not None
+            else None
+        )
+        result = _post_sparsenn_forward(ebc, None, w_ebc, batch_size)
+        return result
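
To make the tuple unpacking in TestSparseArchZCH.forward (ebc, _ = self.ebc(features)) concrete, here is a minimal, hedged sketch of the same pattern on an unsharded ManagedCollisionEmbeddingBagCollection; the table sizes, hash size, and input values below are made up for illustration and are not part of this diff.

import torch
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingBagCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

table = EmbeddingBagConfig(
    name="t1", embedding_dim=8, num_embeddings=16, feature_names=["f1"]
)
mc_ebc = ManagedCollisionEmbeddingBagCollection(
    EmbeddingBagCollection(tables=[table], device=torch.device("cpu")),
    ManagedCollisionCollection(
        managed_collision_modules={
            "t1": MCHManagedCollisionModule(
                zch_size=16,
                input_hash_size=4000,
                device=torch.device("cpu"),
                eviction_interval=1000,
                eviction_policy=DistanceLFU_EvictionPolicy(),
            )
        },
        embedding_configs=[table],
    ),
    return_remapped_features=True,
)
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([1234, 56, 789]),
    lengths=torch.tensor([2, 1]),
)
# The unsharded MC-EBC already returns a tuple: (pooled embeddings, remapped KJT),
# which is why the test module unpacks and discards the second element.
pooled, remapped = mc_ebc(kjt)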

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 7 additions & 3 deletions
@@ -17,7 +17,7 @@
 from unittest.mock import MagicMock
 
 import torch
-from hypothesis import given, settings, strategies as st, Verbosity
+from hypothesis import assume, given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch._dynamo.testing import reduce_to_scalar_loss
 from torch._dynamo.utils import counters
@@ -1531,7 +1531,7 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
         not torch.cuda.is_available(),
         "Not enough GPUs, this test requires at least one GPU",
     )
-    @settings(max_examples=4, deadline=None)
+    @settings(max_examples=8, deadline=None)
     # pyre-ignore[56]
     @given(
         start_batch=st.sampled_from([0, 6]),
@@ -1547,17 +1547,21 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
                 EmbeddingComputeKernel.FUSED.value,
             ]
         ),
+        zch=st.booleans(),
     )
     def test_equal_to_non_pipelined(
         self,
         start_batch: int,
         stash_gradients: bool,
         sharding_type: str,
         kernel_type: str,
+        zch: bool,
     ) -> None:
         """
         Checks that pipelined training is equivalent to non-pipelined training.
         """
+        # ZCH only supports row-wise currently
+        assume(not zch or (zch and sharding_type != ShardingType.TABLE_WISE.value))
         torch.autograd.set_detect_anomaly(True)
         data = self._generate_data(
             num_batches=12,
@@ -1572,7 +1576,7 @@ def test_equal_to_non_pipelined(
             **fused_params,
         }
 
-        model = self._setup_model()
+        model = self._setup_model(zch=zch)
         sharded_model, optim = self._generate_sharded_model_and_optimizer(
             model, sharding_type, kernel_type, fused_params
        )
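
As context for the new assume() call above, here is a standalone illustration of the hypothesis pattern (the test name and strategies are invented for this example): examples that violate the assumption are discarded instead of failing, which is how the ZCH variants skip table-wise sharding.

from hypothesis import assume, given, strategies as st


@given(zch=st.booleans(), sharding_type=st.sampled_from(["row_wise", "table_wise"]))
def test_zch_rowwise_only_example(zch: bool, sharding_type: str) -> None:
    # Mirrors the new assume() above: drop table-wise examples when zch is True.
    assume(not zch or sharding_type != "table_wise")
    assert not (zch and sharding_type == "table_wise")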

torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py

Lines changed: 13 additions & 1 deletion
@@ -21,6 +21,7 @@
 from torchrec.distributed.test_utils.test_model import (
     ModelInput,
     TestEBCSharder,
+    TestEBCSharderMCH,
     TestSparseNN,
 )
 from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
@@ -96,13 +97,15 @@ def _setup_model(
         model_type: Type[nn.Module] = TestSparseNN,
         enable_fsdp: bool = False,
         postproc_module: Optional[nn.Module] = None,
+        zch: bool = False,
     ) -> nn.Module:
         unsharded_model = model_type(
             tables=self.tables,
             weighted_tables=self.weighted_tables,
             dense_device=self.device,
             sparse_device=torch.device("meta"),
             postproc_module=postproc_module,
+            zch=zch,
         )
         if enable_fsdp:
             unsharded_model.over.dhn_arch.linear0 = FSDP(
@@ -135,6 +138,11 @@ def _generate_sharded_model_and_optimizer(
             kernel_type=kernel_type,
             fused_params=fused_params,
         )
+        mc_sharder = TestEBCSharderMCH(
+            sharding_type=sharding_type,
+            kernel_type=kernel_type,
+            fused_params=fused_params,
+        )
         sharded_model = DistributedModelParallel(
             module=copy.deepcopy(model),
             env=ShardingEnv.from_process_group(self.pg),
@@ -144,7 +152,11 @@ def _generate_sharded_model_and_optimizer(
                 cast(
                     ModuleSharder[nn.Module],
                     sharder,
-                )
+                ),
+                cast(
+                    ModuleSharder[nn.Module],
+                    mc_sharder,
+                ),
             ],
         )
         # default fused optimizer is SGD w/ lr=0.1; we need to drop params
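
Both sharders end up in the sharders list because DistributedModelParallel selects a sharder for each sparse submodule by the sharder's module_type; without TestEBCSharderMCH, the ManagedCollisionEmbeddingBagCollection submodules created by the zch=True path would have no matching sharder. A simplified stand-in for that lookup (not DMP's actual implementation; _SharderLike and _pick_sharder are invented for this sketch):

from typing import Dict, List, Optional, Type

import torch.nn as nn


class _SharderLike:
    # Minimal stand-in with the one attribute this sketch relies on.
    module_type: Type[nn.Module] = nn.Module


def _pick_sharder(
    module: nn.Module, sharders: List[_SharderLike]
) -> Optional[_SharderLike]:
    # Simplified illustration of a per-module-type lookup; the real logic lives
    # in torchrec.distributed (planner / DistributedModelParallel).
    by_type: Dict[Type[nn.Module], _SharderLike] = {s.module_type: s for s in sharders}
    return by_type.get(type(module))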

0 commit comments