Skip to content

Commit 0d827ea

Browse files
Yihang Yang authored and facebook-github-bot committed
Refactor Tensorboard logging code to avoid external dependency issues (#2713)
Summary: Pull Request resolved: #2713 # Context D68737755 caused issues in the torchrec repo on GitHub. # Code Decision Revert the changes from D68737755 using ``` $ hg backout D68737755 ``` Reviewed By: aliafzal Differential Revision: D68856792 fbshipit-source-id: 9567fb007041138a46aa94d957f706e4a82ccbe6
1 parent 487fbb7 commit 0d827ea

File tree

1 file changed

+3
-168
lines changed

1 file changed

+3
-168
lines changed

torchrec/modules/mc_modules.py

Lines changed: 3 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -15,155 +15,12 @@
1515

1616
import torch
1717

18-
from tensorboard.adhoc import Adhoc
19-
from torch import distributed as dist, nn
18+
from torch import nn
2019
from torchrec.modules.embedding_configs import BaseEmbeddingConfig
2120
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
2221

23-
logger: Logger = getLogger(__name__)
24-
25-
26-
class ScalarLogger(torch.nn.Module):
27-
"""
28-
A logger to report various metrics related to ZCH.
29-
This module is adapted from ScalarLogger for multi-probe ZCH.
30-
31-
Args:
32-
name (str): Name of the embedding table.
33-
frequency (int): Frequency of reporting metrics in number of iterations.
34-
35-
Example::
36-
scalar_logger = ScalarLogger(
37-
name=name,
38-
frequency=tb_logging_frequency,
39-
)
40-
"""
41-
42-
def __init__(
43-
self,
44-
name: str,
45-
frequency: int,
46-
) -> None:
47-
"""
48-
Initializes the logger.
49-
50-
Args:
51-
name (str): Name of the embedding table.
52-
frequency (int): Frequency of reporting metrics in number of iterations.
53-
54-
Returns:
55-
None
56-
"""
57-
super().__init__()
58-
self._name: str = name
59-
self._frequency: int = frequency
60-
61-
# size related metrics
62-
self._unused_size: int = 0
63-
self._active_size: int = 0
64-
self._total_size: int = 0
65-
66-
# scalar step
67-
self._scalar_logger_steps: int = 0
68-
69-
def should_report(self) -> bool:
70-
"""
71-
Returns whether the logger should report metrics.
72-
This function only returns True for rank 0 and every self._frequency steps.
73-
"""
74-
if self._scalar_logger_steps % self._frequency != 0:
75-
return False
76-
rank: int = -1
77-
if dist.is_available() and dist.is_initialized():
78-
rank = dist.get_rank()
79-
return rank == 0
80-
81-
def build_metric_name(
82-
self,
83-
metric: str,
84-
run_type: str,
85-
) -> str:
86-
"""
87-
Builds the metric name for reporting.
88-
89-
Args:
90-
metric (str): Name of the metric.
91-
run_type (str): Run type of the model, e.g. train, eval, etc.
92-
93-
Returns:
94-
str: Metric name.
95-
"""
96-
return f"mc_zch_stats/{self._name}/{metric}/{run_type}"
9722

98-
def update_size(
99-
self,
100-
counts: torch.Tensor,
101-
) -> None:
102-
"""
103-
Updates the size related metrics.
104-
105-
Args:
106-
counts (torch.Tensor): Counts of each id in the embedding table.
107-
108-
Returns:
109-
None
110-
"""
111-
zero_counts = counts == 0
112-
self._unused_size = int(torch.sum(zero_counts).item())
113-
114-
self._total_size = counts.shape[0]
115-
self._active_size = self._total_size - self._unused_size
116-
117-
def forward(
118-
self,
119-
run_type: str,
120-
) -> None:
121-
"""
122-
Reports various metrics related to ZCH.
123-
124-
Args:
125-
run_type (str): Run type of the model, e.g. train, eval, etc.
126-
127-
Returns:
128-
None
129-
"""
130-
if self.should_report():
131-
total_size = self._total_size + 0.001
132-
usused_ratio = round(self._unused_size / total_size, 3)
133-
active_ratio = round(self._active_size / total_size, 3)
134-
135-
Adhoc.writer().add_scalar(
136-
self.build_metric_name("unused_size", run_type),
137-
self._unused_size,
138-
self._scalar_logger_steps,
139-
)
140-
Adhoc.writer().add_scalar(
141-
self.build_metric_name("usused_ratio", run_type),
142-
usused_ratio,
143-
self._scalar_logger_steps,
144-
)
145-
Adhoc.writer().add_scalar(
146-
self.build_metric_name("active_size", run_type),
147-
self._active_size,
148-
self._scalar_logger_steps,
149-
)
150-
Adhoc.writer().add_scalar(
151-
self.build_metric_name("active_ratio", run_type),
152-
active_ratio,
153-
self._scalar_logger_steps,
154-
)
155-
156-
logger.info(f"{self._name=}, {run_type=}")
157-
logger.info(f"{self._total_size=}")
158-
logger.info(f"{self._unused_size=}, {usused_ratio=}")
159-
logger.info(f"{self._active_size=}, {active_ratio=}")
160-
161-
# reset after reporting
162-
self._unused_size = 0
163-
self._active_size = 0
164-
self._total_size = 0
165-
166-
self._scalar_logger_steps += 1
23+
logger: Logger = getLogger(__name__)
16724

16825

16926
@torch.fx.wrap
@@ -1126,7 +983,6 @@ def __init__(
1126983
output_global_offset: int = 0, # typically not provided by user
1127984
output_segments: Optional[List[int]] = None, # typically not provided by user
1128985
buckets: int = 1,
1129-
tb_logging_frequency: int = 0,
1130986
) -> None:
1131987
if output_segments is None:
1132988
output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1164,16 +1020,6 @@ def __init__(
11641020
self._evicted: bool = False
11651021
self._last_eviction_iter: int = -1
11661022

1167-
## ------ logging ------
1168-
self._tb_logging_frequency = tb_logging_frequency
1169-
self._scalar_logger: Optional[ScalarLogger] = None
1170-
if self._tb_logging_frequency > 0:
1171-
assert self._name is not None, "name must be provided for logging"
1172-
self._scalar_logger = ScalarLogger(
1173-
name=self._name,
1174-
frequency=self._tb_logging_frequency,
1175-
)
1176-
11771023
def _init_buffers(self) -> None:
11781024
self.register_buffer(
11791025
"_mch_sorted_raw_ids",
@@ -1459,26 +1305,16 @@ def profile(
14591305
self._coalesce_history()
14601306
self._last_eviction_iter = self._current_iter
14611307

1462-
if self._scalar_logger is not None:
1463-
self._scalar_logger.update_size(counts=self._mch_metadata["counts"])
1464-
14651308
return features
14661309

14671310
def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
1468-
remapped_features = _mch_remap(
1311+
return _mch_remap(
14691312
features,
14701313
self._mch_sorted_raw_ids,
14711314
self._mch_remapped_ids_mapping,
14721315
self._output_global_offset + self._zch_size - 1,
14731316
)
14741317

1475-
if self._scalar_logger is not None:
1476-
self._scalar_logger(
1477-
run_type="train" if self.training else "eval",
1478-
)
1479-
1480-
return remapped_features
1481-
14821318
def forward(
14831319
self,
14841320
features: Dict[str, JaggedTensor],
@@ -1557,5 +1393,4 @@ def rebuild_with_output_id_range(
15571393
output_global_offset=output_id_range[0],
15581394
output_segments=output_segments,
15591395
buckets=len(output_segments) - 1,
1560-
tb_logging_frequency=self._tb_logging_frequency,
15611396
)

0 commit comments

Comments (0)