Skip to content

Commit 876222e

Browse files
Ziliang Zhaofacebook-github-bot
authored andcommitted
Create an interface to enable eviction policy (TREC part) (#2481)
Summary: Pull Request resolved: #2481 To support various types of eviction policy, the HashZchManagedCollisionModule needs to be able to calculate a score (e.g., TTL) for each incoming ID and pass it to the kernel. The latter will make an informed eviction decision based on the existing score associated with the identity (stored in metadata), and a reference value (e.g., current timestamp). Reviewed By: dracifer, shruthign Differential Revision: D64163927 fbshipit-source-id: b5e65fbf0568e7a4950988128fd19c1ffa4426ca
1 parent 54ec8aa commit 876222e

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

torchrec/distributed/mc_modules.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,8 @@ def compute(
646646
table: JaggedTensor(
647647
values=kjt.values(),
648648
lengths=kjt.lengths(),
649+
# TODO: improve this temp solution by passing real weights
650+
weights=torch.tensor(kjt.length_per_key()),
649651
)
650652
}
651653
mcm = self._managed_collision_modules[table]
@@ -660,6 +662,8 @@ def compute(
660662
table: JaggedTensor(
661663
values=features.values(),
662664
lengths=features.lengths(),
665+
# TODO: improve this temp solution by passing real weights
666+
weights=torch.tensor(kjt.length_per_key()),
663667
)
664668
}
665669
mcm = self._managed_collision_modules[table]
@@ -673,6 +677,7 @@ def compute(
673677
keys=fns,
674678
values=values,
675679
lengths=features.lengths(),
680+
# original weights instead of features splits
676681
weights=features.weights_or_none(),
677682
)
678683
)

torchrec/modules/mc_modules.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,11 @@ def _mcc_lazy_init(
7676
return (features, created_feature_order, features_order)
7777

7878

79+
@torch.fx.wrap
80+
def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor:
81+
return torch.tensor(kjt.length_per_key())
82+
83+
7984
@torch.no_grad()
8085
def dynamic_threshold_filter(
8186
id_counts: torch.Tensor,
@@ -368,6 +373,7 @@ def forward(
368373
table: JaggedTensor(
369374
values=kjt.values(),
370375
lengths=kjt.lengths(),
376+
weights=_get_length_per_key(kjt),
371377
)
372378
}
373379
mc_input = mc_module(mc_input)

0 commit comments

Comments
 (0)