Skip to content

Commit 306057a

Browse files
Bin Wen authored and PaulZhang12 committed
torchrec module (#2297)
Summary: Pull Request resolved: #2297. Adds a torchrec module that incorporates the hash ZCH kernel and supports both training and inference needs. Reviewed By: dstaay-fb, bixue2010. Differential Revision: D60942972. fbshipit-source-id: c3a7f6fa77a7edfa6881c2b55454cb1b44779832
1 parent fade86a commit 306057a

File tree

2 files changed

+96
-68
lines changed

2 files changed

+96
-68
lines changed

torchrec/distributed/mc_modules.py

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@
5151
ShardMetadata,
5252
)
5353
from torchrec.distributed.utils import append_prefix
54-
from torchrec.modules.mc_modules import (
55-
apply_mc_method_to_jt_dict,
56-
ManagedCollisionCollection,
57-
)
54+
from torchrec.modules.mc_modules import ManagedCollisionCollection
5855
from torchrec.modules.utils import construct_jagged_tensors
5956
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
6057

@@ -191,25 +188,28 @@ def _initialize_torch_state(self) -> None:
191188
if name not in shardable_buffers:
192189
continue
193190

191+
sharded_sizes = list(tensor.shape)
192+
sharded_sizes[0] = shard_size
193+
shard_offsets = [0] * len(sharded_sizes)
194+
shard_offsets[0] = shard_offset
195+
global_sizes = list(tensor.shape)
196+
global_sizes[0] = global_size
194197
self._model_parallel_mc_buffer_name_to_sharded_tensor[name] = (
195198
ShardedTensor._init_from_local_shards(
196199
[
197200
Shard(
198201
tensor=tensor,
199202
metadata=ShardMetadata(
200-
# pyre-ignore [6]
201-
shard_offsets=[shard_offset],
202-
# pyre-ignore [6]
203-
shard_sizes=[shard_size],
203+
shard_offsets=shard_offsets,
204+
shard_sizes=sharded_sizes,
204205
placement=(
205206
f"rank:{self._env.rank}/cuda:"
206207
f"{get_local_rank(self._env.world_size, self._env.rank)}"
207208
),
208209
),
209210
)
210211
],
211-
# pyre-ignore [6]
212-
torch.Size([global_size]),
212+
torch.Size(global_sizes),
213213
process_group=self._env.process_group,
214214
)
215215
)
@@ -256,9 +256,7 @@ def _create_managed_collision_modules(
256256
self, module: ManagedCollisionCollection
257257
) -> None:
258258

259-
self._mc_module_name_shard_metadata: DefaultDict[
260-
str, DefaultDict[str, List[int]]
261-
] = defaultdict(lambda: defaultdict(list))
259+
self._mc_module_name_shard_metadata: DefaultDict[str, List[int]] = defaultdict()
262260
self._feature_to_offset: Dict[str, int] = {}
263261

264262
for sharding in self._embedding_shardings:
@@ -392,15 +390,19 @@ def input_dist(
392390
self._has_uninitialized_input_dists = False
393391

394392
with torch.no_grad():
393+
features_dict = features.to_dict()
394+
output: Dict[str, JaggedTensor] = features_dict.copy()
395+
for table, mc_module in self._managed_collision_modules.items():
396+
feature_list: List[str] = self._table_to_features[table]
397+
mc_input: Dict[str, JaggedTensor] = {}
398+
for feature in feature_list:
399+
mc_input[feature] = features_dict[feature]
400+
mc_input = mc_module.preprocess(mc_input)
401+
output.update(mc_input)
402+
395403
# NOTE shared features not currently supported
396-
features = KeyedJaggedTensor.from_jt_dict(
397-
apply_mc_method_to_jt_dict(
398-
"preprocess",
399-
features.to_dict(),
400-
self._table_to_features,
401-
self._managed_collision_modules,
402-
)
403-
)
404+
features = KeyedJaggedTensor.from_jt_dict(output)
405+
404406
if self._features_order:
405407
features = features.permute(
406408
self._features_order,
@@ -456,19 +458,17 @@ def compute(
456458
-1, features.stride()
457459
)
458460
features_dict = features.to_dict()
459-
features_dict = apply_mc_method_to_jt_dict(
460-
"profile",
461-
features_dict=features_dict,
462-
table_to_features=self._table_to_features,
463-
managed_collisions=self._managed_collision_modules,
464-
)
465-
features_dict = apply_mc_method_to_jt_dict(
466-
"remap",
467-
features_dict=features_dict,
468-
table_to_features=self._table_to_features,
469-
managed_collisions=self._managed_collision_modules,
470-
)
471-
remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(features_dict))
461+
output: Dict[str, JaggedTensor] = features_dict.copy()
462+
for table, mc_module in self._managed_collision_modules.items():
463+
feature_list: List[str] = self._table_to_features[table]
464+
mc_input: Dict[str, JaggedTensor] = {}
465+
for feature in feature_list:
466+
mc_input[feature] = features_dict[feature]
467+
mc_input = mc_module.profile(mc_input)
468+
mc_input = mc_module.remap(mc_input)
469+
output.update(mc_input)
470+
471+
remapped_kjts.append(KeyedJaggedTensor.from_jt_dict(output))
472472

473473
return KJTList(remapped_kjts)
474474

@@ -522,6 +522,7 @@ def create_context(self) -> ManagedCollisionCollectionContext:
522522
return ManagedCollisionCollectionContext(sharding_contexts=[])
523523

524524
def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
525+
# TODO (bwen): this does not include `_hash_zch_identities`
525526
for fqn, _ in self.named_buffers():
526527
yield append_prefix(prefix, fqn)
527528

torchrec/modules/mc_modules.py

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
#!/usr/bin/env python3
1111

1212
import abc
13-
from collections import defaultdict
1413
from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union
1514

1615
import torch
@@ -30,23 +29,15 @@
3029

3130
@torch.fx.wrap
3231
def apply_mc_method_to_jt_dict(
32+
mc_module: nn.Module,
3333
method: str,
3434
features_dict: Dict[str, JaggedTensor],
35-
table_to_features: Dict[str, List[str]],
36-
managed_collisions: nn.ModuleDict,
3735
) -> Dict[str, JaggedTensor]:
3836
"""
3937
Applies an MC method to a dictionary of JaggedTensors, returning the updated dictionary with same ordering
4038
"""
41-
mc_output: Dict[str, JaggedTensor] = features_dict.copy()
42-
for table, features in table_to_features.items():
43-
mc_input: Dict[str, JaggedTensor] = {}
44-
for feature in features:
45-
mc_input[feature] = features_dict[feature]
46-
mc_module = managed_collisions[table]
47-
attr = getattr(mc_module, method)
48-
mc_output.update(attr(mc_input))
49-
return mc_output
39+
attr = getattr(mc_module, method)
40+
return attr(features_dict)
5041

5142

5243
@torch.no_grad()
@@ -153,6 +144,14 @@ def evict(self) -> Optional[torch.Tensor]:
153144
"""
154145
pass
155146

147+
@abc.abstractmethod
148+
def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
149+
pass
150+
151+
@abc.abstractmethod
152+
def profile(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
153+
pass
154+
156155
@abc.abstractmethod
157156
def forward(
158157
self,
@@ -203,6 +202,8 @@ class ManagedCollisionCollection(nn.Module):
203202
embedding_confgs (List[BaseEmbeddingConfig]): List of embedding configs, for each table with a managed collsion module
204203
"""
205204

205+
_table_to_features: Dict[str, List[str]]
206+
206207
def __init__(
207208
self,
208209
managed_collision_modules: Dict[str, ManagedCollisionModule],
@@ -216,10 +217,13 @@ def __init__(
216217
for config in embedding_configs
217218
for feature in config.feature_names
218219
}
219-
self._table_to_features: Dict[str, List[str]] = defaultdict(list)
220+
self._table_to_features = {}
220221

221222
self._compute_jt_dict_to_kjt = ComputeJTDictToKJT()
222223
for feature, table in self._feature_to_table.items():
224+
if table not in self._table_to_features:
225+
self._table_to_features[table] = []
226+
223227
self._table_to_features[table].append(feature)
224228

225229
table_to_config = {config.name: config for config in embedding_configs}
@@ -243,25 +247,18 @@ def forward(
243247
self,
244248
features: KeyedJaggedTensor,
245249
) -> KeyedJaggedTensor:
246-
features_dict = apply_mc_method_to_jt_dict(
247-
"preprocess",
248-
features_dict=features.to_dict(),
249-
table_to_features=self._table_to_features,
250-
managed_collisions=self._managed_collision_modules,
251-
)
252-
features_dict = apply_mc_method_to_jt_dict(
253-
"profile",
254-
features_dict=features_dict,
255-
table_to_features=self._table_to_features,
256-
managed_collisions=self._managed_collision_modules,
257-
)
258-
features_dict = apply_mc_method_to_jt_dict(
259-
"remap",
260-
features_dict=features_dict,
261-
table_to_features=self._table_to_features,
262-
managed_collisions=self._managed_collision_modules,
263-
)
264-
return self._compute_jt_dict_to_kjt(features_dict)
250+
features_dict = features.to_dict()
251+
output: Dict[str, JaggedTensor] = features_dict.copy()
252+
for table, mc_module in self._managed_collision_modules.items():
253+
feature_list: List[str] = self._table_to_features[table]
254+
mc_input: Dict[str, JaggedTensor] = {}
255+
for feature in feature_list:
256+
mc_input[feature] = features_dict[feature]
257+
mc_input = mc_module.preprocess(mc_input)
258+
mc_input = mc_module.profile(mc_input)
259+
mc_input = mc_module.remap(mc_input)
260+
output.update(mc_input)
261+
return self._compute_jt_dict_to_kjt(output)
265262

266263
def evict(self) -> Dict[str, Optional[torch.Tensor]]:
267264
evictions: Dict[str, Optional[torch.Tensor]] = {}
@@ -933,7 +930,17 @@ def _init_history_buffers(self, features: Dict[str, JaggedTensor]) -> None:
933930
self._history_metadata[metadata_name] = getattr(self, buffer_name)
934931

935932
@torch.no_grad()
936-
def preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
933+
def preprocess(
934+
self,
935+
features: Dict[str, JaggedTensor],
936+
) -> Dict[str, JaggedTensor]:
937+
return apply_mc_method_to_jt_dict(
938+
self,
939+
"_preprocess",
940+
features,
941+
)
942+
943+
def _preprocess(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
937944
if self._input_hash_func is None:
938945
return features
939946
preprocessed_features: Dict[str, JaggedTensor] = {}
@@ -1070,6 +1077,16 @@ def _coalesce_history(self) -> None:
10701077
def profile(
10711078
self,
10721079
features: Dict[str, JaggedTensor],
1080+
) -> Dict[str, JaggedTensor]:
1081+
return apply_mc_method_to_jt_dict(
1082+
self,
1083+
"_profile",
1084+
features,
1085+
)
1086+
1087+
def _profile(
1088+
self,
1089+
features: Dict[str, JaggedTensor],
10731090
) -> Dict[str, JaggedTensor]:
10741091
if not self.training:
10751092
return features
@@ -1115,7 +1132,17 @@ def profile(
11151132
return features
11161133

11171134
@torch.no_grad()
1118-
def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
1135+
def remap(
1136+
self,
1137+
features: Dict[str, JaggedTensor],
1138+
) -> Dict[str, JaggedTensor]:
1139+
return apply_mc_method_to_jt_dict(
1140+
self,
1141+
"_remap",
1142+
features,
1143+
)
1144+
1145+
def _remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
11191146

11201147
remapped_features: Dict[str, JaggedTensor] = {}
11211148
for name, feature in features.items():

0 commit comments

Comments
 (0)