Commit d14d02c
Enable Consistent SHA256 Hashing with reduced Planner Context (pytorch#3091)
Summary: Pull Request resolved: pytorch#3091

Even though SHA256 hashing is used, we were still not seeing the expected identical hash generated from the original planner context inputs. The problem is that the Enumerator and StorageReservation objects we were originally trying to hash contain attributes that differ between processes/instances. To resolve this, we reduced the hashing context to only the specific attributes we need from the enumerator and storage reservation, namely:

* The output of enumerator.enumerate(...), which is used as the `search_space` in both the LP and OSS planners.
  * We store the output of enumerate as the attribute `last_stored_search_space`. **This assumes enumerate will have been called before we hash the planner context inputs.**
* The StorageReservation policy (i.e. whether `HeuristicalStorageReservation` or `FixedStorageReservation` is used).
* The StorageReservation initialization attributes:
  * `_percentage`
  * `_parameter_multiplier` for `HeuristicalStorageReservation`
  * `_dense_tensor_estimate` for `HeuristicalStorageReservation`

Created helper functions:

* `hash_planner_context_inputs`, to be called both in planner.hash_planner_context_inputs and at the manifold loading call site (see D75723272).
* `hash_sha256_to_int`, to be passed in as the default hash function in `hash_planner_context_inputs`.

Also created a multiprocess unit test to quickly check that consistent hashes are generated across different processes given the same input.

Differential Revision: D76303748
1 parent 65b82f9 commit d14d02c
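The torchrec/distributed/planner/types.py hunk that adds these two helpers is not shown on this page. The sketch below is only a rough guess at their shape, inferred from the summary above and from the call site in planners.py further down; the exact signatures, the way the reservation attributes are extracted, and the serialization step are all assumptions rather than the committed implementation.

```python
import hashlib
from typing import Any, Callable, List, Optional


def hash_sha256_to_int(hashable_list: List[Any]) -> int:
    """Serialize the inputs, SHA-256 them, and return the digest as an int (sketch)."""
    serialized = str(hashable_list).encode("utf-8")
    return int(hashlib.sha256(serialized).hexdigest(), 16)


def hash_planner_context_inputs(
    topology: Any,
    batch_size: int,
    enumerator: Any,
    storage_reservation: Any,
    constraints: Optional[Any] = None,
    hash_function: Callable[[List[Any]], int] = hash_sha256_to_int,
) -> int:
    """Hash only the reduced, process-stable planner context (sketch)."""
    # Assumes enumerate(...) already ran, so the cached search space exists.
    search_space = enumerator.last_stored_search_space
    assert search_space is not None, "call enumerate(...) before hashing"
    hashable_list = [
        topology,
        batch_size,
        search_space,
        # Reservation policy plus its initialization attributes.
        type(storage_reservation).__name__,
        getattr(storage_reservation, "_percentage", None),
        getattr(storage_reservation, "_parameter_multiplier", None),
        getattr(storage_reservation, "_dense_tensor_estimate", None),
        frozenset(constraints.items()) if constraints else None,
    ]
    return hash_function(hashable_list)
```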

File tree

6 files changed: +237 -35 lines changed

torchrec/distributed/planner/enumerators.py
Lines changed: 24 additions & 0 deletions

```diff
@@ -7,6 +7,7 @@
 
 # pyre-strict
 
+import copy
 import logging
 from typing import Dict, List, Optional, Set, Tuple, Union
 
@@ -102,6 +103,11 @@ def __init__(
             EmbeddingStorageEstimator(topology=topology, constraints=constraints),
         ]
 
+        # Initializing caching for enumerate
+        self._last_stored_search_space: Optional[List[ShardingOption]] = None
+        self._last_stored_module: Optional[nn.Module] = None
+        self._last_stored_sharders: Optional[List[ModuleSharder[nn.Module]]] = None
+
     def enumerate(
         self,
         module: nn.Module,
@@ -118,6 +124,12 @@ def enumerate(
             List[ShardingOption]: valid sharding options with values populated.
         """
 
+        if (
+            self._last_stored_module == module
+            and self._last_stored_sharders == sharders
+        ):
+            return self._last_stored_search_space  # pyre-ignore
+
         self._sharder_map = {
             sharder_name(sharder.module_type): sharder for sharder in sharders
         }
@@ -230,8 +242,20 @@ def enumerate(
 
         self.populate_estimates(sharding_options)
 
+        self._last_stored_module = module
+        self._last_stored_sharders = sharders
+
+        # Caching the search space with a copy of sharding options, to avoid unexpected modifications to the list
+        self._last_stored_search_space = copy.deepcopy(sharding_options)
         return sharding_options
 
+    @property
+    def last_stored_search_space(self) -> Optional[List[ShardingOption]]:
+        # NOTE: This is the last search space stored by enumerate(...); do not use
+        # this field in place of actually calling enumerate(...), as it will vary for each
+        # module/sharders passed in.
+        return self._last_stored_search_space
+
     def populate_estimates(self, sharding_options: List[ShardingOption]) -> None:
         for estimator in self._estimators:
             estimator.estimate(sharding_options, self._sharder_map)
```
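A hedged usage sketch of the caching behavior added above (the table, sharder, and topology setup mirrors the tests later in this commit; the exact values are illustrative):

```python
import torch
from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.types import Topology

topology = Topology(local_world_size=8, world_size=1, compute_device="cuda")
enumerator = EmbeddingEnumerator(topology=topology, batch_size=128)

module = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=160,
            num_embeddings=10_000,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),  # meta device: we only need the search space
)
sharders = [EmbeddingBagCollectionSharder()]

# First call computes the search space and caches a deep copy of it.
search_space = enumerator.enumerate(module, sharders)

# Subsequent calls with the same module/sharders return the cached copy,
# and the property exposes it for hashing without re-enumerating.
assert enumerator.last_stored_search_space is not None
```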

torchrec/distributed/planner/planners.py
Lines changed: 5 additions & 8 deletions

```diff
@@ -39,6 +39,7 @@
 )
 from torchrec.distributed.planner.types import (
     Enumerator,
+    hash_planner_context_inputs,
     ParameterConstraints,
     Partitioner,
     PerfModel,
@@ -280,25 +281,21 @@ def collective_plan(
             sharders,
         )
 
-    def hash_planner_context_inputs(self) -> str:
+    def hash_planner_context_inputs(self) -> int:
         """
         Generates a hash for all planner inputs except for partitioner, proposer, performance model, and stats.
         These are all the inputs needed to verify whether a previously generated sharding plan is still valid in a new context.
 
         Returns:
             Generates a hash capturing topology, batch size, enumerator, storage reservation, stats and constraints.
         """
-        hashable_list = [
+        return hash_planner_context_inputs(
             self._topology,
             self._batch_size,
             self._enumerator,
             self._storage_reservation,
-            frozenset(self._constraints.items()) if self._constraints else None,
-        ]
-        serialized_list = str(hashable_list).encode("utf-8")
-        hash_object = hashlib.sha256(serialized_list)
-        hash_digest = hash_object.hexdigest()
-        return hash_digest
+            self._constraints,
+        )
 
     def plan(
         self,
```
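Because the method now returns an int computed by the shared helper, the same value can be compared against a hash persisted alongside an earlier plan. A minimal hedged sketch; the `plan_is_reusable` helper and `stored_hash` parameter are illustrative, not part of this commit:

```python
from torchrec.distributed.planner import EmbeddingShardingPlanner


def plan_is_reusable(planner: EmbeddingShardingPlanner, stored_hash: int) -> bool:
    """Return True if a previously generated plan, saved with `stored_hash`,
    is still valid for this planner's current context.

    Assumes the planner's enumerator has already run enumerate(...), since the
    reduced context hashes the cached search space rather than the enumerator
    object itself.
    """
    return planner.hash_planner_context_inputs() == stored_hash
```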

torchrec/distributed/planner/storage_reservations.py
Lines changed: 15 additions & 0 deletions

```diff
@@ -163,6 +163,7 @@ class FixedPercentageStorageReservation(StorageReservation):
     def __init__(self, percentage: float) -> None:
         assert percentage >= 0 and percentage <= 1
         self._percentage: float = percentage
+        self._last_reserved_toplogy: Optional[Topology] = None
 
     def reserve(
         self,
@@ -174,8 +175,14 @@ def reserve(
     ) -> Topology:
         reserved_topology = copy.deepcopy(topology)
         _reserve_storage_percentage(reserved_topology, self._percentage)
+        self._last_reserved_toplogy = reserved_topology
         return reserved_topology
 
+    @property
+    def last_reserved_toplogy(self) -> Optional[Topology]:
+        "Cached value of the most recent output from the reserve() method."
+        return self._last_reserved_toplogy
+
 
 class HeuristicalStorageReservation(StorageReservation):
     """
@@ -206,6 +213,7 @@ def __init__(
 
         self._dense_storage: Optional[Storage] = None
         self._kjt_storage: Optional[Storage] = None
+        self._last_reserved_toplogy: Optional[Topology] = None
 
     def reserve(
         self,
@@ -215,6 +223,7 @@ def reserve(
         sharders: List[ModuleSharder[nn.Module]],
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
     ) -> Topology:
+        # TODO: enable proper caching of topology values through _last_reserved_toplogy
         reserved_topology = copy.deepcopy(topology)
 
         batch_inputs, shardable_modules = _get_batch_inputs_and_shardable_parameters(
@@ -262,8 +271,14 @@ def reserve(
             message=negative_storage_solution,
         )
 
+        self._last_reserved_toplogy = reserved_topology
         return reserved_topology
 
+    @property
+    def last_reserved_toplogy(self) -> Optional[Topology]:
+        "Cached value of the most recent output from the reserve() method."
+        return self._last_reserved_toplogy
+
 
 class InferenceStorageReservation(StorageReservation):
     """
```

torchrec/distributed/planner/tests/test_planners.py
Lines changed: 27 additions & 1 deletion

```diff
@@ -12,7 +12,7 @@
 
 import torch
 from torch import nn
-from torchrec import EmbeddingConfig
+from torchrec import EmbeddingBagCollection, EmbeddingConfig
 from torchrec.distributed.embedding import EmbeddingCollectionSharder
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -306,6 +306,22 @@ def test_passing_info_through_constraints(self) -> None:
 class TestEmbeddingShardingHashPlannerContextInputs(unittest.TestCase):
 
     def setUp(self) -> None:
+        eb_config = EmbeddingBagConfig(
+            name="table_0",
+            embedding_dim=160,
+            num_embeddings=10000,
+            feature_names=["f1"],
+            data_type=DataType.FP16,
+        )
+        module = EmbeddingBagCollection(
+            tables=[eb_config],
+            is_weighted=False,
+            device=torch.device(
+                "meta"
+            ),  # Using meta device for now since only getting search space
+        )
+        sharders = [EmbeddingBagCollectionSharder()]
+
         self.topology = Topology(
             local_world_size=8,
             world_size=1,
@@ -315,10 +331,20 @@ def setUp(self) -> None:
         self.enumerator = EmbeddingEnumerator(
             topology=self.topology, batch_size=self.batch_size
         )
+        self.enumerator.enumerate(module, sharders)  # pyre-ignore
+
         self.storage_reservation = HeuristicalStorageReservation(percentage=0.15)
         self.perf_model = NoopPerfModel(topology=self.topology)
         self.constraints = {"table1": ParameterConstraints()}
 
+        self.storage_reservation.reserve(
+            topology=self.topology,
+            batch_size=self.batch_size,
+            module=module,
+            sharders=sharders,  # pyre-ignore
+            constraints=self.constraints,
+        )
+
     def test_hash_equality(self) -> None:
         planner1 = EmbeddingShardingPlanner(
             topology=self.topology,
```

torchrec/distributed/planner/tests/test_types.py
Lines changed: 85 additions & 1 deletion

```diff
@@ -8,18 +8,30 @@
 # pyre-strict
 
 import unittest
-from typing import cast
+from typing import cast, Dict, Optional
 from unittest.mock import MagicMock
 
 import torch
+from torch import multiprocessing
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
+from torchrec.distributed.planner import EmbeddingShardingPlanner
+from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
+from torchrec.distributed.planner.perf_models import NoopPerfModel
+from torchrec.distributed.planner.storage_reservations import (
+    HeuristicalStorageReservation,
+)
 
 from torchrec.distributed.planner.types import (
     ParameterConstraints,
     Shard,
     ShardingOption,
     Topology,
 )
+from torchrec.distributed.test_utils.multi_process import (
+    MultiProcessContext,
+    MultiProcessTestBase,
+)
 from torchrec.distributed.types import (
     BoundsCheckMode,
     CacheAlgorithm,
@@ -348,3 +360,75 @@ def test_hash_inequality(self) -> None:
         self.assertNotEqual(
             hash(pc1), hash(pc2), "Hashes should be different for different instances"
         )
+
+
+def _test_hashing_consistency(
+    rank: int,
+    world_size: int,
+    backend: str,
+    return_hash_dict: Dict[str, int],
+    local_size: Optional[int] = None,
+) -> None:
+    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        topology = Topology(
+            local_world_size=8,
+            world_size=1,
+            compute_device="cuda",
+        )
+        batch_size = 128
+        enumerator = EmbeddingEnumerator(topology=topology, batch_size=batch_size)
+        eb_config = EmbeddingBagConfig(
+            name="table_0",
+            embedding_dim=160,
+            num_embeddings=10000,
+            feature_names=["f1"],
+            data_type=DataType.FP16,
+        )
+        module = EmbeddingBagCollection(
+            tables=[eb_config],
+            is_weighted=False,
+            device=torch.device(
+                "meta"
+            ),  # Using meta device for now since only getting search space
+        )
+        sharders = [EmbeddingBagCollectionSharder()]
+        enumerator.enumerate(module, sharders)  # pyre-ignore
+        storage_reservation = HeuristicalStorageReservation(percentage=0.15)
+        constraints = {"table1": ParameterConstraints()}
+
+        storage_reservation.reserve(
+            topology=topology,
+            batch_size=batch_size,
+            module=module,
+            sharders=sharders,  # pyre-ignore
+            constraints=constraints,
+        )
+        perf_model = NoopPerfModel(topology=topology)
+
+        planner1 = EmbeddingShardingPlanner(
+            topology=topology,
+            batch_size=batch_size,
+            enumerator=enumerator,
+            storage_reservation=storage_reservation,
+            performance_model=perf_model,
+            constraints=constraints,
+        )
+
+        h = planner1.hash_planner_context_inputs()
+        return_hash_dict[str(rank)] = h
+
+
+class TestConsistentHashingBetweenProcesses(MultiProcessTestBase):
+
+    def test_hash_consistency(self) -> None:
+        # planner
+        world_size = 2
+        return_hash_dict = multiprocessing.Manager().dict()
+        self._run_multi_process_test(
+            callable=_test_hashing_consistency,
+            world_size=world_size,
+            backend="nccl" if torch.cuda.is_available() else "gloo",
+            return_hash_dict=return_hash_dict,
+        )
+        hashes = return_hash_dict.values()
+        assert hashes[0] == hashes[1], "hash values are different."
```
