
Commit bce8ae3

iamzainhuda authored and facebook-github-bot committed
add GPU sync and numerical value tests (#2194)
Summary: Pull Request resolved: #2194

**TLDR: Significant clean up of RecMetrics tests and fixes for long-standing testing issues with particular metrics. Because RecMetrics is community sourced, the quality of metric and test implementations can vary widely, which leads to SEVs later on due to the large surface area RecMetrics operates on.**

Added GPU sync tests to simulate gathering metric states onto rank 0 and computing. Tests did not cover this case before, which has resulted in SEVs in the past because users aren't aware of how RecMetrics collects and computes metrics.

Added numerical value tests. Most metrics do not have these, which can cause issues down the line if metrics need to change or accommodate future requirements. We have also occasionally found inconsistencies with other methods, so it is always good to check here. We compare each metric to a reference implementation from the literature to ensure the values are as expected.

Fixed and cleaned up some tests, in particular enforcing a standard of quality that can be referenced for future metric implementations and tests.

Reviewed By: henrylhtsang

Differential Revision: D59173140

fbshipit-source-id: 7e9e8e6666264f7a05f252143a51d0e3e7034d9d
1 parent d317c0b commit bce8ae3

15 files changed: +566 −222 lines
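
A note on the numerical value tests described in the summary: each metric is checked against a simple reference implementation using torch.allclose. Below is a minimal sketch of that pattern, using calibration as the example; the reference_calibration helper and the exact formula are illustrative assumptions, not code from this commit.

import torch

def reference_calibration(
    predictions: torch.Tensor,
    labels: torch.Tensor,
    weights: torch.Tensor,
) -> torch.Tensor:
    # Assumed textbook definition: weighted sum of predictions divided by
    # the weighted sum of labels.
    return (predictions * weights).sum() / (labels * weights).sum()

# A value test would then assert the RecMetric output against this reference
# within floating-point tolerance, e.g.:
#   assert torch.allclose(metric_output, reference_calibration(preds, labels, weights))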

torchrec/metrics/test_utils/__init__.py

Lines changed: 31 additions & 9 deletions
@@ -365,7 +365,7 @@ def rec_metric_gpu_sync_test_launcher(
     entry_point: Callable[..., None],
     batch_size: int = BATCH_SIZE,
     batch_window_size: int = BATCH_WINDOW_SIZE,
-    **kwargs: Any,
+    **kwargs: Dict[str, Any],
 ) -> None:
     with tempfile.TemporaryDirectory() as tmpdir:
         lc = get_launch_config(
@@ -385,6 +385,7 @@ def rec_metric_gpu_sync_test_launcher(
             should_validate_update,
             batch_size,
             batch_window_size,
+            kwargs.get("n_classes", None),
         )
 
 
@@ -402,6 +403,7 @@ def sync_test_helper(
     batch_window_size: int = BATCH_WINDOW_SIZE,
     n_classes: Optional[int] = None,
     zero_weights: bool = False,
+    **kwargs: Dict[str, Any],
 ) -> None:
     rank = int(os.environ["RANK"])
     world_size = int(os.environ["WORLD_SIZE"])
@@ -413,13 +415,19 @@ def sync_test_helper(
 
     tasks = gen_test_tasks(task_names)
 
+    if n_classes:
+        # pyre-ignore[6]: Incompatible parameter type
+        kwargs["number_of_classes"] = n_classes
+
     auc = target_clazz(
         world_size=world_size,
         batch_size=batch_size,
         my_rank=rank,
         compute_on_all_ranks=compute_on_all_ranks,
         tasks=tasks,
         window_size=batch_window_size * world_size,
+        # pyre-ignore[6]: Incompatible parameter type
+        **kwargs,
     )
 
     weight_value: Optional[torch.Tensor] = None
@@ -466,10 +474,17 @@ def sync_test_helper(
     res = auc.compute()
 
     if rank == 0:
-        assert torch.allclose(
-            test_metrics[1][task_names[0]],
-            res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
-        )
+        # Serving Calibration uses Calibration naming inconsistently
+        if metric_name == "serving_calibration":
+            assert torch.allclose(
+                test_metrics[1][task_names[0]],
+                res[f"{metric_name}-{task_names[0]}|window_calibration"],
+            )
+        else:
+            assert torch.allclose(
+                test_metrics[1][task_names[0]],
+                res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
+            )
 
     # we also test the case where other rank has more tensors than rank 0
     auc.reset()
@@ -489,10 +504,17 @@ def sync_test_helper(
     res = auc.compute()
 
     if rank == 0:
-        assert torch.allclose(
-            test_metrics[1][task_names[0]],
-            res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
-        )
+        # Serving Calibration uses Calibration naming inconsistently
+        if metric_name == "serving_calibration":
+            assert torch.allclose(
+                test_metrics[1][task_names[0]],
+                res[f"{metric_name}-{task_names[0]}|window_calibration"],
+            )
+        else:
+            assert torch.allclose(
+                test_metrics[1][task_names[0]],
+                res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
+            )
 
     dist.destroy_process_group()
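
For context on what these sync tests exercise: with compute_on_all_ranks=False, metric states are gathered onto rank 0 before compute(), so only rank 0 holds the full windowed result, which is why the assertions in the helper above are guarded by the rank == 0 check. Below is a minimal sketch of that gather-then-compute pattern with plain torch.distributed; the flat-tensor state and the sum reduction are simplifying assumptions, not the actual RecMetric state layout.

from typing import List, Optional

import torch
import torch.distributed as dist

def gather_and_compute_on_rank_0(local_state: torch.Tensor) -> Optional[torch.Tensor]:
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Every rank contributes its local metric state; only rank 0 receives the list.
    gathered: Optional[List[torch.Tensor]] = (
        [torch.zeros_like(local_state) for _ in range(world_size)] if rank == 0 else None
    )
    dist.gather(local_state, gather_list=gathered, dst=0)

    if rank != 0:
        # Mirrors compute_on_all_ranks=False: non-zero ranks return no result.
        return None
    # Combine the per-rank partial states and finalize (here, a simple sum).
    return torch.stack(gathered).sum(dim=0)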
torchrec/metrics/tests/test_accuracy.py

Lines changed: 23 additions & 0 deletions
@@ -17,8 +17,10 @@
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
     RecTaskInfo,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -251,3 +253,24 @@ def test_accuracy(self) -> None:
         except AssertionError:
             print("Assertion error caught with data set ", inputs)
             raise
+
+
+class AccuracyGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = AccuracyMetric
+    task_name: str = "accuracy"
+
+    def test_sync_accuracy(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=AccuracyMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestAccuracyMetric,
+            metric_name=AccuracyGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )

torchrec/metrics/tests/test_auprc.py

Lines changed: 23 additions & 0 deletions
@@ -23,7 +23,9 @@
 )
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None:
         )
 
         self.assertIn("grouping_keys", auprc.get_required_inputs())
+
+
+class AUPRCGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = AUPRCMetric
+    task_name: str = "auprc"
+
+    def test_sync_auprc(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=AUPRCMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestAUPRCMetric,
+            metric_name=AUPRCGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )

torchrec/metrics/tests/test_calibration.py

Lines changed: 23 additions & 0 deletions
@@ -15,7 +15,9 @@
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None:
             world_size=WORLD_SIZE,
             entry_point=metric_test_helper,
         )
+
+
+class CalibrationGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = CalibrationMetric
+    task_name: str = "calibration"
+
+    def test_sync_calibration(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=CalibrationMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestCalibrationMetric,
+            metric_name=CalibrationGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )

torchrec/metrics/tests/test_ctr.py

Lines changed: 23 additions & 0 deletions
@@ -15,7 +15,9 @@
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None:
             world_size=WORLD_SIZE,
             entry_point=metric_test_helper,
         )
+
+
+class CTRGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = CTRMetric
+    task_name: str = "ctr"
+
+    def test_sync_ctr(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=CTRMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestCTRMetric,
+            metric_name=CTRGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )

torchrec/metrics/tests/test_mae.py

Lines changed: 23 additions & 0 deletions
@@ -15,7 +15,9 @@
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -74,3 +76,24 @@ def test_fused_mae(self) -> None:
             world_size=WORLD_SIZE,
             entry_point=metric_test_helper,
         )
+
+
+class MAEGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = MAEMetric
+    task_name: str = "mae"
+
+    def test_sync_mae(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=MAEMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestMAEMetric,
+            metric_name=MAEGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )

torchrec/metrics/tests/test_mse.py

Lines changed: 23 additions & 0 deletions
@@ -15,7 +15,9 @@
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -123,3 +125,24 @@ def test_fused_rmse(self) -> None:
             world_size=WORLD_SIZE,
             entry_point=metric_test_helper,
         )
+
+
+class MSEGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = MSEMetric
+    task_name: str = "mse"
+
+    def test_sync_mse(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=MSEMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestMSEMetric,
+            metric_name=MSEGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+        )

torchrec/metrics/tests/test_multiclass_recall.py

Lines changed: 25 additions & 0 deletions
@@ -19,7 +19,9 @@
 from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
 from torchrec.metrics.test_utils import (
     metric_test_helper,
+    rec_metric_gpu_sync_test_launcher,
     rec_metric_value_test_launcher,
+    sync_test_helper,
     TestMetric,
 )
 
@@ -113,3 +115,26 @@ def test_multiclass_recall_update_fused(self) -> None:
             batch_window_size=10,
             n_classes=N_CLASSES,
         )
+
+
+class MulticlassRecallGPUSyncTest(unittest.TestCase):
+    clazz: Type[RecMetric] = MulticlassRecallMetric
+    task_name: str = "multiclass_recall"
+
+    def test_sync_multiclass_recall(self) -> None:
+        rec_metric_gpu_sync_test_launcher(
+            target_clazz=MulticlassRecallMetric,
+            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
+            test_clazz=TestMulticlassRecallMetric,
+            metric_name=MulticlassRecallGPUSyncTest.task_name,
+            task_names=["t1"],
+            fused_update_limit=0,
+            compute_on_all_ranks=False,
+            should_validate_update=False,
+            world_size=2,
+            batch_size=5,
+            batch_window_size=20,
+            entry_point=sync_test_helper,
+            # pyre-ignore[6] Incompatible parameter type
+            n_classes=N_CLASSES,
+        )