@@ -10,6 +10,8 @@
 # pyre-ignore-all-errors[56]

 import unittest
+from collections import OrderedDict
+from typing import Any, Dict
 from unittest.mock import Mock, patch

 import torch
@@ -213,3 +215,125 @@ def test_batch_size_schedule(self, time_mock: Mock) -> None:
                 "throughput-throughput|batch_size": 512,
             },
         )
+
+    def test_num_batch_without_batch_size_stages(self) -> None:
+        # Create the module without batch_size_stages
+        throughput_metric = ThroughputMetric(
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_seconds=100,
+            batch_size_stages=None,
+        )
+
+        # Make sure num_batch is not present as an attribute of the class
+        self.assertFalse(hasattr(throughput_metric, "num_batch"))
+
+        throughput_metric.update()
+        state_dict: Dict[str, Any] = throughput_metric.state_dict()
+        # Ensure num_batch is not included in the state_dict for the module without batch_size_stages
+        self.assertNotIn("num_batch", state_dict)
+
+    def test_state_dict_load_module_lifecycle(self) -> None:
+        """
+        A test to ensure that the load_state_dict and state_dict hooks correctly
+        handle the num_batch attribute through the module lifecycle.
+        """
+
+        throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+
+        self.assertTrue(hasattr(throughput_metric, "_num_batch"))
+
+        # Stage 1: create the metric and update it before persisting the state_dict
+        # Update the metric, expecting num_batch to be incremented to 1
+        throughput_metric.update()
+        # Ensure num_batch is 1
+        self.assertEqual(throughput_metric._num_batch, 1)
+        # Ensure num_batch is included in the state_dict and has the correct value
+        state_dict: Dict[str, Any] = throughput_metric.state_dict()
+        self.assertIn("num_batch", state_dict)
+        # Ensure num_batch was saved to the state_dict with the correct value
+        self.assertEqual(state_dict["num_batch"].item(), throughput_metric._num_batch)
+
+        # Stage 2: load the state_dict and ensure num_batch is restored correctly
+
+        # Create a new metric instance
+        new_throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        # Ensure num_batch is 0
+        self.assertEqual(new_throughput_metric._num_batch, 0)
+        # Load the state_dict
+        new_throughput_metric.load_state_dict(state_dict)
+        # Ensure num_batch is loaded from the state_dict with the correct value
+        self.assertEqual(new_throughput_metric._num_batch, 1)
+
+        # Stage 3: update the metric after loading the state and re-save the state_dict
+
+        # Save the state_dict
+        state_dict = new_throughput_metric.state_dict()
+        # Ensure num_batch is included in the state_dict
+        self.assertIn("num_batch", state_dict)
+        # Ensure num_batch was saved to the state_dict with the correct value
+        self.assertEqual(
+            state_dict["num_batch"].item(), new_throughput_metric._num_batch
+        )
+
+    def test_state_dict_hook_adds_key(self) -> None:
+        """
+        Ensures that the state_dict_hook adds the 'num_batch' key to the state_dict
+        when batch_size_stages is provided.
+        """
+        throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        for _ in range(5):
+            throughput_metric.update()
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {})
+        self.assertIn(f"{prefix}num_batch", state_dict)
+        self.assertEqual(state_dict[f"{prefix}num_batch"].item(), 5)
+
+    def test_state_dict_hook_no_batch_size_stages(self) -> None:
+        """
+        Verifies that the state_dict_hook does not add the 'num_batch' key when
+        batch_size_stages is None.
+        """
+        throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=None,
+        )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {})
+        self.assertNotIn(f"{prefix}num_batch", state_dict)
+
+    def test_load_state_dict_hook_restores_value(self) -> None:
+        """
+        Checks that load_state_dict_hook correctly restores the 'num_batch' value
+        from the state_dict.
+        """
+        throughput_metric = ThroughputMetric(
+            batch_size=32,
+            world_size=4,
+            window_seconds=100,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
+        throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
+        self.assertEqual(throughput_metric._num_batch, 10)
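
The tests above exercise persisting `num_batch` through module state-dict hooks. For orientation, here is a minimal, self-contained sketch of that hook pattern built on PyTorch's private `nn.Module._register_state_dict_hook` / `_register_load_state_dict_pre_hook` registration APIs; the `StagedMetric` class and its fields are hypothetical stand-ins for illustration only, not torchrec's actual `ThroughputMetric` implementation.

```python
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn


class StagedMetric(nn.Module):
    """Hypothetical sketch: persist a plain attribute via state_dict hooks."""

    def __init__(self, batch_size_stages: Optional[List[int]] = None) -> None:
        super().__init__()
        self._batch_size_stages = batch_size_stages
        if batch_size_stages is not None:
            # Plain int, not a registered buffer, so state_dict() won't save it by itself.
            self._num_batch: int = 0
        # Private PyTorch hook APIs; the hook signatures match the calls in the tests above.
        self._register_state_dict_hook(self.state_dict_hook)
        self._register_load_state_dict_pre_hook(self.load_state_dict_hook)

    def update(self) -> None:
        if self._batch_size_stages is not None:
            self._num_batch += 1

    @staticmethod
    def state_dict_hook(
        module: "StagedMetric",
        state_dict: Dict[str, torch.Tensor],
        prefix: str,
        local_metadata: Dict[str, Any],
    ) -> None:
        # Only persist num_batch when batch_size_stages is configured.
        if module._batch_size_stages is not None:
            state_dict[f"{prefix}num_batch"] = torch.tensor(
                module._num_batch, dtype=torch.long
            )

    def load_state_dict_hook(
        self,
        state_dict: Dict[str, torch.Tensor],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        # Pop the custom key so the default loading logic doesn't flag it as unexpected.
        key = f"{prefix}num_batch"
        if key in state_dict and self._batch_size_stages is not None:
            self._num_batch = int(state_dict.pop(key).item())
```

With this wiring, round-tripping the value reduces to `new_metric.load_state_dict(old_metric.state_dict())`, which is the lifecycle that `test_state_dict_load_module_lifecycle` walks through above.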