Skip to content

Commit 592fd66

Browse files
iamzainhuda and PaulZhang12
authored and committed
add metric to log model output kwargs (#2279)
Summary: Pull Request resolved: #2279 Differential Revision: D60986155 fbshipit-source-id: 25d389c7acbab81f43726e9991ff8ddd9a34ddb4
1 parent 4364e9a commit 592fd66

File tree

5 files changed

+137
-1
lines changed

5 files changed

+137
-1
lines changed

torchrec/metrics/metric_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from torchrec.metrics.multiclass_recall import MulticlassRecallMetric
4242
from torchrec.metrics.ndcg import NDCGMetric
4343
from torchrec.metrics.ne import NEMetric
44+
from torchrec.metrics.output import OutputMetric
4445
from torchrec.metrics.precision import PrecisionMetric
4546
from torchrec.metrics.rauc import RAUCMetric
4647
from torchrec.metrics.rec_metric import RecMetric, RecMetricList
@@ -80,6 +81,7 @@
8081
RecMetricEnum.RECALL: RecallMetric,
8182
RecMetricEnum.SERVING_NE: ServingNEMetric,
8283
RecMetricEnum.SERVING_CALIBRATION: ServingCalibrationMetric,
84+
RecMetricEnum.OUTPUT: OutputMetric,
8385
}
8486

8587

torchrec/metrics/metrics_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ class RecMetricEnum(RecMetricEnumBase):
4141
RECALL = "recall"
4242
SERVING_NE = "serving_ne"
4343
SERVING_CALIBRATION = "serving_calibration"
44+
OUTPUT = "output"
4445

4546

4647
@dataclass(unsafe_hash=True, eq=True)

torchrec/metrics/metrics_namespace.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class MetricName(MetricNameBase):
6363
NDCG = "ndcg"
6464
XAUC = "xauc"
6565
SCALAR = "scalar"
66+
OUTPUT = "output"
6667

6768
TOTAL_POSITIVE_EXAMPLES = "total_positive_examples"
6869
TOTAL_NEGATIVE_EXAMPLES = "total_negative_examples"
@@ -112,6 +113,8 @@ class MetricNamespace(MetricNamespaceBase):
112113
SERVING_NE = "serving_ne"
113114
SERVING_CALIBRATION = "serving_calibration"
114115

116+
OUTPUT = "output"
117+
115118

116119
class MetricPrefix(StrValueMixin, Enum):
117120
DEFAULT = ""

torchrec/metrics/output.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from typing import Any, Dict, List, Optional, Type
11+
12+
import torch
13+
from torch import distributed as dist
14+
15+
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
16+
from torchrec.metrics.rec_metric import (
17+
MetricComputationReport,
18+
RecComputeMode,
19+
RecMetric,
20+
RecMetricComputation,
21+
RecMetricException,
22+
RecTaskInfo,
23+
)
24+
25+
26+
class OutputMetricComputation(RecMetricComputation):
27+
"""
28+
Metric that logs whatever model outputs are given in kwargs
29+
TODO - make this generic metric that can be used for any model output tensor
30+
"""
31+
32+
def __init__(self, *args: Any, **kwargs: Any) -> None:
33+
super().__init__(*args, **kwargs)
34+
self._add_state(
35+
"latest_imp",
36+
torch.zeros(self._n_tasks, dtype=torch.double),
37+
add_window_state=False,
38+
dist_reduce_fx="sum",
39+
persistent=False,
40+
)
41+
self._add_state(
42+
"total_latest_imp",
43+
torch.zeros(self._n_tasks, dtype=torch.double),
44+
add_window_state=False,
45+
dist_reduce_fx="sum",
46+
persistent=False,
47+
)
48+
49+
def update(
50+
self,
51+
*,
52+
predictions: Optional[torch.Tensor],
53+
labels: torch.Tensor,
54+
weights: Optional[torch.Tensor],
55+
**kwargs: Dict[str, Any],
56+
) -> None:
57+
required_list = ["latest_imp", "total_latest_imp"]
58+
if "required_inputs" not in kwargs or not all(
59+
item in kwargs["required_inputs"] for item in required_list
60+
):
61+
raise RecMetricException(
62+
"OutputMetricComputation requires 'latest_imp' and 'total_latest_imp' in kwargs"
63+
)
64+
states = {
65+
"latest_imp": kwargs["required_inputs"]["latest_imp"]
66+
.float()
67+
.mean(dim=-1, dtype=torch.double),
68+
"total_latest_imp": kwargs["required_inputs"]["total_latest_imp"]
69+
.float()
70+
.mean(dim=-1, dtype=torch.double),
71+
}
72+
73+
for state_name, state_value in states.items():
74+
setattr(self, state_name, state_value)
75+
76+
def _compute(self) -> List[MetricComputationReport]:
77+
return [
78+
MetricComputationReport(
79+
name=MetricName.OUTPUT,
80+
metric_prefix=MetricPrefix.DEFAULT,
81+
value=self.latest_imp,
82+
description="_latest_imp",
83+
),
84+
MetricComputationReport(
85+
name=MetricName.OUTPUT,
86+
metric_prefix=MetricPrefix.DEFAULT,
87+
value=self.total_latest_imp,
88+
description="_total_latest_imp",
89+
),
90+
]
91+
92+
93+
class OutputMetric(RecMetric):
94+
_namespace: MetricNamespace = MetricNamespace.OUTPUT
95+
_computation_class: Type[RecMetricComputation] = OutputMetricComputation
96+
97+
def __init__(
98+
self,
99+
world_size: int,
100+
my_rank: int,
101+
batch_size: int,
102+
tasks: List[RecTaskInfo],
103+
compute_mode: RecComputeMode = RecComputeMode.UNFUSED_TASKS_COMPUTATION,
104+
window_size: int = 100,
105+
fused_update_limit: int = 0,
106+
compute_on_all_ranks: bool = False,
107+
should_validate_update: bool = False,
108+
process_group: Optional[dist.ProcessGroup] = None,
109+
**kwargs: Dict[str, Any],
110+
) -> None:
111+
super().__init__(
112+
world_size=world_size,
113+
my_rank=my_rank,
114+
batch_size=batch_size,
115+
tasks=tasks,
116+
compute_mode=compute_mode,
117+
window_size=window_size,
118+
fused_update_limit=fused_update_limit,
119+
compute_on_all_ranks=compute_on_all_ranks,
120+
should_validate_update=should_validate_update,
121+
process_group=process_group,
122+
**kwargs,
123+
)
124+
self._required_inputs.add("latest_imp")
125+
self._required_inputs.add("total_latest_imp")

torchrec/metrics/rec_metric.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,13 @@ def _update(
623623
else:
624624
continue
625625
if "required_inputs" in kwargs:
626+
# Expand scalars to match the shape of the predictions
626627
kwargs["required_inputs"] = {
627-
k: v.view(task_labels.size())
628+
k: (
629+
v.view(task_labels.size())
630+
if v.numel() > 1
631+
else v.expand(task_labels.size())
632+
)
628633
for k, v in kwargs["required_inputs"].items()
629634
}
630635
metric_.update(

0 commit comments

Comments (0)