Skip to content

Commit fade86a

Browse files
gnahzgPaulZhang12
authored andcommitted
skip empty rank lookup (#2293)
Summary: Pull Request resolved: #2293 We have observe issue when run in predictor side with empty ranks (rank with TBEs). In this diff, we try to skip the creation of lookup for empty rank to remove all invalid operations for empty rank. Reviewed By: dstaay-fb Differential Revision: D61020328 fbshipit-source-id: 2278a1f13d2981030a01919c080d7ef010bfa010
1 parent 592fd66 commit fade86a

File tree

4 files changed

+108
-9
lines changed

4 files changed

+108
-9
lines changed

torchrec/distributed/dist_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,7 @@ def forward(self, tensors: List[torch.Tensor]) -> torch.Tensor:
10501050
Returns:
10511051
Awaitable[torch.Tensor]: awaitable of the merged embeddings.
10521052
"""
1053-
assert len(tensors) == self._world_size
1053+
assert len(tensors) <= self._world_size
10541054

10551055
is_target_device_cpu: bool = self._device.type == "cpu"
10561056

torchrec/distributed/embedding_lookup.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,21 +1014,43 @@ def __init__(
10141014
"meta" if device is not None and device.type == "meta" else "cuda"
10151015
)
10161016

1017+
self._is_empty_rank: List[bool] = []
10171018
for rank in range(world_size):
1018-
self._embedding_lookups_per_rank.append(
1019-
# TODO add position weighted module support
1020-
MetaInferGroupedPooledEmbeddingsLookup(
1021-
grouped_configs=grouped_configs_per_rank[rank],
1022-
device=rank_device(device_type, rank),
1023-
fused_params=fused_params,
1019+
empty_rank = len(grouped_configs_per_rank[rank]) == 0
1020+
self._is_empty_rank.append(empty_rank)
1021+
if not empty_rank:
1022+
self._embedding_lookups_per_rank.append(
1023+
# TODO add position weighted module support
1024+
MetaInferGroupedPooledEmbeddingsLookup(
1025+
grouped_configs=grouped_configs_per_rank[rank],
1026+
device=rank_device(device_type, rank),
1027+
fused_params=fused_params,
1028+
)
10241029
)
1025-
)
10261030

10271031
def get_tbes_to_register(
10281032
self,
10291033
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
10301034
return get_tbes_to_register_from_iterable(self._embedding_lookups_per_rank)
10311035

1036+
def forward(
1037+
self,
1038+
input_dist_outputs: InputDistOutputs,
1039+
) -> List[torch.Tensor]:
1040+
embeddings: List[torch.Tensor] = []
1041+
sparse_features = [
1042+
input_dist_outputs.features[i]
1043+
for i, is_empty in enumerate(self._is_empty_rank)
1044+
if not is_empty
1045+
]
1046+
# syntax for torchscript
1047+
for i, embedding_lookup in enumerate(
1048+
self._embedding_lookups_per_rank,
1049+
):
1050+
sparse_features_rank = sparse_features[i]
1051+
embeddings.append(embedding_lookup.forward(sparse_features_rank))
1052+
return embeddings
1053+
10321054

10331055
class InferGroupedEmbeddingsLookup(
10341056
InferGroupedLookupMixin,

torchrec/distributed/test_utils/infer_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,11 +825,15 @@ def shard_qebc(
825825
expected_shards: Optional[List[List[Tuple[Tuple[int, int, int, int], str]]]] = None,
826826
plan: Optional[ShardingPlan] = None,
827827
ebc_fqn: str = "_module.sparse.ebc",
828+
shard_score_ebc: bool = False,
828829
) -> torch.nn.Module:
829830
sharder = TestQuantEBCSharder(
830831
sharding_type=sharding_type.value,
831832
kernel_type=EmbeddingComputeKernel.QUANT.value,
832-
shardable_params=[table.name for table in mi.tables],
833+
shardable_params=(
834+
[table.name for table in mi.tables]
835+
+ ([table.name for table in mi.weighted_tables] if shard_score_ebc else [])
836+
),
833837
)
834838
if not plan:
835839
# pyre-ignore

torchrec/distributed/tests/test_infer_shardings.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,79 @@ def test_tw(self, weight_dtype: torch.dtype, device_type: str) -> None:
204204
ShardingType.TABLE_WISE.value,
205205
)
206206

207+
@unittest.skipIf(
208+
torch.cuda.device_count() <= 1,
209+
"Not enough GPUs available",
210+
)
211+
# pyre-ignore
212+
@given(
213+
weight_dtype=st.sampled_from([torch.qint8]),
214+
device_type=st.sampled_from(["cuda"]),
215+
)
216+
@settings(max_examples=4, deadline=None)
217+
def test_tw_ebc_full_rank_weighted_ebc_with_empty_rank(
218+
self, weight_dtype: torch.dtype, device_type: str
219+
) -> None:
220+
set_propogate_device(True)
221+
num_embeddings = 256
222+
emb_dim = 16
223+
world_size = 2
224+
batch_size = 4
225+
local_device = torch.device(f"{device_type}:0")
226+
mi = create_test_model(
227+
num_embeddings,
228+
emb_dim,
229+
world_size,
230+
batch_size,
231+
dense_device=local_device,
232+
sparse_device=local_device,
233+
quant_state_dict_split_scale_bias=True,
234+
weight_dtype=weight_dtype,
235+
num_features=6, # 6 sparse features on ebc
236+
num_weighted_features=1, # only 1 weighted sparse feature on weighted_ebc
237+
)
238+
239+
non_sharded_model = mi.quant_model
240+
241+
sharded_model = shard_qebc(
242+
mi,
243+
sharding_type=ShardingType.TABLE_WISE,
244+
device=local_device,
245+
expected_shards=None,
246+
shard_score_ebc=True,
247+
)
248+
249+
self.assertEqual(
250+
len(
251+
sharded_model._module.sparse.ebc._lookups[0]._embedding_lookups_per_rank
252+
),
253+
2,
254+
)
255+
self.assertEqual(
256+
len(
257+
sharded_model._module.sparse.weighted_ebc._lookups[
258+
0
259+
]._embedding_lookups_per_rank
260+
),
261+
1,
262+
)
263+
264+
inputs = [
265+
model_input_to_forward_args(inp.to(local_device))
266+
for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
267+
]
268+
269+
sharded_model.load_state_dict(non_sharded_model.state_dict())
270+
271+
sharded_output = sharded_model(*inputs[0])
272+
non_sharded_output = non_sharded_model(*inputs[0])
273+
assert_close(sharded_output, non_sharded_output)
274+
275+
gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
276+
gm_script = torch.jit.script(gm)
277+
gm_script_output = gm_script(*inputs[0])
278+
assert_close(sharded_output, gm_script_output)
279+
207280
@unittest.skipIf(
208281
torch.cuda.device_count() <= 1,
209282
"Not enough GPUs available",

0 commit comments

Comments
 (0)