 
 import torch
 
-from tensorboard.adhoc import Adhoc
-from torch import distributed as dist, nn
+from torch import nn
 from torchrec.modules.embedding_configs import BaseEmbeddingConfig
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 
-logger: Logger = getLogger(__name__)
-
-
-class ScalarLogger(torch.nn.Module):
-    """
-    A logger to report various metrics related to ZCH.
-    This module is adapted from ScalarLogger for multi-probe ZCH.
-
-    Args:
-        name (str): Name of the embedding table.
-        frequency (int): Frequency of reporting metrics in number of iterations.
-
-    Example::
-        scalar_logger = ScalarLogger(
-            name=name,
-            frequency=tb_logging_frequency,
-        )
-    """
-
-    def __init__(
-        self,
-        name: str,
-        frequency: int,
-    ) -> None:
-        """
-        Initializes the logger.
-
-        Args:
-            name (str): Name of the embedding table.
-            frequency (int): Frequency of reporting metrics in number of iterations.
-
-        Returns:
-            None
-        """
-        super().__init__()
-        self._name: str = name
-        self._frequency: int = frequency
-
-        # size related metrics
-        self._unused_size: int = 0
-        self._active_size: int = 0
-        self._total_size: int = 0
-
-        # scalar step
-        self._scalar_logger_steps: int = 0
-
-    def should_report(self) -> bool:
-        """
-        Returns whether the logger should report metrics.
-        This function only returns True for rank 0 and every self._frequency steps.
-        """
-        if self._scalar_logger_steps % self._frequency != 0:
-            return False
-        rank: int = -1
-        if dist.is_available() and dist.is_initialized():
-            rank = dist.get_rank()
-        return rank == 0
-
-    def build_metric_name(
-        self,
-        metric: str,
-        run_type: str,
-    ) -> str:
-        """
-        Builds the metric name for reporting.
-
-        Args:
-            metric (str): Name of the metric.
-            run_type (str): Run type of the model, e.g. train, eval, etc.
-
-        Returns:
-            str: Metric name.
-        """
-        return f"mc_zch_stats/{self._name}/{metric}/{run_type}"
 
-    def update_size(
-        self,
-        counts: torch.Tensor,
-    ) -> None:
-        """
-        Updates the size related metrics.
-
-        Args:
-            counts (torch.Tensor): Counts of each id in the embedding table.
-
-        Returns:
-            None
-        """
-        zero_counts = counts == 0
-        self._unused_size = int(torch.sum(zero_counts).item())
-
-        self._total_size = counts.shape[0]
-        self._active_size = self._total_size - self._unused_size
-
-    def forward(
-        self,
-        run_type: str,
-    ) -> None:
-        """
-        Reports various metrics related to ZCH.
-
-        Args:
-            run_type (str): Run type of the model, e.g. train, eval, etc.
-
-        Returns:
-            None
-        """
-        if self.should_report():
-            total_size = self._total_size + 0.001
-            usused_ratio = round(self._unused_size / total_size, 3)
-            active_ratio = round(self._active_size / total_size, 3)
-
-            Adhoc.writer().add_scalar(
-                self.build_metric_name("unused_size", run_type),
-                self._unused_size,
-                self._scalar_logger_steps,
-            )
-            Adhoc.writer().add_scalar(
-                self.build_metric_name("usused_ratio", run_type),
-                usused_ratio,
-                self._scalar_logger_steps,
-            )
-            Adhoc.writer().add_scalar(
-                self.build_metric_name("active_size", run_type),
-                self._active_size,
-                self._scalar_logger_steps,
-            )
-            Adhoc.writer().add_scalar(
-                self.build_metric_name("active_ratio", run_type),
-                active_ratio,
-                self._scalar_logger_steps,
-            )
-
-            logger.info(f"{self._name=}, {run_type=}")
-            logger.info(f"{self._total_size=}")
-            logger.info(f"{self._unused_size=}, {usused_ratio=}")
-            logger.info(f"{self._active_size=}, {active_ratio=}")
-
-            # reset after reporting
-            self._unused_size = 0
-            self._active_size = 0
-            self._total_size = 0
-
-        self._scalar_logger_steps += 1
+logger: Logger = getLogger(__name__)
 
 
 @torch.fx.wrap
@@ -1126,7 +983,6 @@ def __init__(
         output_global_offset: int = 0,  # typically not provided by user
         output_segments: Optional[List[int]] = None,  # typically not provided by user
         buckets: int = 1,
-        tb_logging_frequency: int = 0,
    ) -> None:
        if output_segments is None:
            output_segments = [output_global_offset, output_global_offset + zch_size]
@@ -1164,16 +1020,6 @@ def __init__(
         self._evicted: bool = False
         self._last_eviction_iter: int = -1
 
-        ## ------ logging ------
-        self._tb_logging_frequency = tb_logging_frequency
-        self._scalar_logger: Optional[ScalarLogger] = None
-        if self._tb_logging_frequency > 0:
-            assert self._name is not None, "name must be provided for logging"
-            self._scalar_logger = ScalarLogger(
-                name=self._name,
-                frequency=self._tb_logging_frequency,
-            )
-
     def _init_buffers(self) -> None:
         self.register_buffer(
             "_mch_sorted_raw_ids",
@@ -1459,26 +1305,16 @@ def profile(
             self._coalesce_history()
             self._last_eviction_iter = self._current_iter
 
-        if self._scalar_logger is not None:
-            self._scalar_logger.update_size(counts=self._mch_metadata["counts"])
-
         return features
 
     def remap(self, features: Dict[str, JaggedTensor]) -> Dict[str, JaggedTensor]:
-        remapped_features = _mch_remap(
+        return _mch_remap(
             features,
             self._mch_sorted_raw_ids,
             self._mch_remapped_ids_mapping,
             self._output_global_offset + self._zch_size - 1,
         )
 
-        if self._scalar_logger is not None:
-            self._scalar_logger(
-                run_type="train" if self.training else "eval",
-            )
-
-        return remapped_features
-
     def forward(
         self,
         features: Dict[str, JaggedTensor],
@@ -1557,5 +1393,4 @@ def rebuild_with_output_id_range(
             output_global_offset=output_id_range[0],
             output_segments=output_segments,
             buckets=len(output_segments) - 1,
-            tb_logging_frequency=self._tb_logging_frequency,
         )
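
For reference, the reporting that the deleted ScalarLogger performed (emit only from rank 0, only every `frequency` steps, under mc_zch_stats/{table}/{metric}/{run_type} tags) can still be reproduced outside the module. The sketch below is illustrative and not part of this change: it assumes the public torch.utils.tensorboard.SummaryWriter in place of the internal tensorboard.adhoc writer, and the helper name report_zch_size_stats is invented for the example.

# Illustrative sketch only -- not part of this commit. Recomputes the
# size stats ScalarLogger derived from the ZCH counts buffer, using the
# public SummaryWriter instead of the internal Adhoc writer.
import torch
from torch import distributed as dist
from torch.utils.tensorboard import SummaryWriter


def report_zch_size_stats(
    writer: SummaryWriter,
    name: str,  # embedding table name
    counts: torch.Tensor,  # per-slot id counts, e.g. the module's mch_metadata["counts"]
    step: int,
    run_type: str = "train",
    frequency: int = 100,
) -> None:
    # Gate the way the removed should_report() did: report only when
    # distributed is initialized, this is rank 0, and step % frequency == 0.
    if step % frequency != 0:
        return
    if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() != 0:
        return

    total_size = counts.numel()
    unused_size = int((counts == 0).sum().item())
    active_size = total_size - unused_size
    for metric, value in (
        ("unused_size", unused_size),
        ("unused_ratio", round(unused_size / max(total_size, 1), 3)),
        ("active_size", active_size),
        ("active_ratio", round(active_size / max(total_size, 1), 3)),
    ):
        # Same tag scheme as the removed code: mc_zch_stats/{table}/{metric}/{run_type}
        writer.add_scalar(f"mc_zch_stats/{name}/{metric}/{run_type}", value, step)

A caller could invoke this from the training loop right after profile(), passing the module's counts buffer and the current iteration as step, which recovers the metrics the module previously emitted internally.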