Skip to content

Commit 4c203eb

Browse files
ilyas409 authored and facebook-github-bot committed
Address 10% performance issue (~16 ms per call) when using batch_size_stages in TorchRec Metrics with GPU (#2814)
Summary: Pull Request resolved: #2814. When `num_batch` was stored **and used** as a tensor buffer registered on the module, it was moved to the GPU whenever the parent module was moved to the GPU. Reading the value of a GPU-resident tensor from CPU code causes a CUDA stream synchronization (cudaStreamSynchronize), because the tensor value has to be copied from HBM to host RAM via ATen. Instead, we now use a separate plain Python integer, `_num_batch`, for the calculation, and synchronize its value to and from the state dict as a tensor using state dict hooks. **From profiler:** *** Title torchrec/metrics/throughput.py(172): _get_batch_size Start 399.423 ms Wall Duration 16.001 ms Self Time 0.019 ms *** Title aten::is_nonzero Category cpu_op Start 399.464 ms Wall Duration 15.954 ms Self Time 0.003 ms *** Title aten::_local_scalar_dense Category cpu_op Start 399.468 ms Wall Duration 15.948 ms Self Time 0.015 ms *** Title cudaStreamSynchronize Category cuda_runtime Start 399.506 ms Wall Duration 15.904 ms *** {F1976014130} {F1976133311} Reviewed By: xunnanxu, iamzainhuda Differential Revision: D71079949 fbshipit-source-id: 5b21198e2579611904056f899a8c237e517ab22a
1 parent d954720 commit 4c203eb

File tree

3 files changed

+215
-8
lines changed

3 files changed

+215
-8
lines changed

torchrec/metrics/tests/test_metric_module.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,7 @@
2929
)
3030
from torchrec.metrics.metrics_config import (
3131
_DEFAULT_WINDOW_SIZE,
32+
BatchSizeStage,
3233
DefaultMetricsConfig,
3334
DefaultTaskInfo,
3435
EmptyMetricsConfig,
@@ -536,3 +537,52 @@ def test_adjust_compute_interval_1_30(self) -> None:
536537
min_interval=1.0,
537538
max_interval=30.0,
538539
)
540+
541+
def test_save_and_load_state_dict(self) -> None:
542+
# Test without batch_size_stages
543+
metric_module = generate_metric_module(
544+
TestMetricModule,
545+
metrics_config=DefaultMetricsConfig,
546+
batch_size=128,
547+
world_size=1,
548+
my_rank=0,
549+
state_metrics_mapping={},
550+
device=torch.device("cpu"),
551+
)
552+
metric_module.update(gen_test_batch(128))
553+
554+
state_dict_without_bss = metric_module.state_dict()
555+
# Make sure state loading works and doesn't throw an error
556+
metric_module.load_state_dict(state_dict_without_bss)
557+
# Make sure num_batch in the throughput module is not in state_dict
558+
self.assertFalse("throughput_metric.num_batch" in state_dict_without_bss)
559+
560+
# Test with batch_size_stages
561+
metric_module = generate_metric_module(
562+
TestMetricModule,
563+
metrics_config=DefaultMetricsConfig,
564+
batch_size=128,
565+
world_size=1,
566+
my_rank=0,
567+
state_metrics_mapping={},
568+
device=torch.device("cpu"),
569+
batch_size_stages=[BatchSizeStage(256, 100), BatchSizeStage(512, None)],
570+
)
571+
572+
# Update metric 100 times
573+
for _ in range(100):
574+
metric_module.update(gen_test_batch(128))
575+
576+
# Simulate a checkpoint save
577+
state_dict = metric_module.state_dict()
578+
# Make sure num_batch is updated correctly to 100
579+
self.assertEqual(state_dict["throughput_metric.num_batch"], 100)
580+
581+
# Simulate a checkpoint load
582+
metric_module.load_state_dict(state_dict)
583+
# Make sure num_batch is correctly restored
584+
throughput_metric = metric_module.throughput_metric
585+
self.assertIsNotNone(throughput_metric)
586+
self.assertEqual(throughput_metric._num_batch, 100)
587+
# Make sure num_batch is correctly synchronized
588+
self.assertEqual(throughput_metric._num_batch, 100)

torchrec/metrics/tests/test_throughput.py

Lines changed: 124 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,8 @@
1010
# pyre-ignore-all-errors[56]
1111

1212
import unittest
13+
from collections import OrderedDict
14+
from typing import Any, Dict
1315
from unittest.mock import Mock, patch
1416

1517
import torch
@@ -213,3 +215,125 @@ def test_batch_size_schedule(self, time_mock: Mock) -> None:
213215
"throughput-throughput|batch_size": 512,
214216
},
215217
)
218+
219+
def test_num_batch_without_batch_size_stages(self) -> None:
220+
# Create the module without the batch_size_stages
221+
throughput_metric = ThroughputMetric(
222+
batch_size=self.batch_size,
223+
world_size=self.world_size,
224+
window_seconds=100,
225+
batch_size_stages=None,
226+
)
227+
228+
# Make sure num_batch is not present as an attribute of the instance
229+
self.assertFalse(hasattr(throughput_metric, "num_batch"))
230+
231+
throughput_metric.update()
232+
state_dict: Dict[str, Any] = throughput_metric.state_dict()
233+
# Ensure num_batch is not included in the state_dict for the module without batch_size_stages
234+
self.assertNotIn("num_batch", state_dict)
235+
236+
def test_state_dict_load_module_lifecycle(self) -> None:
237+
"""
238+
A test to ensure that the load_state_dict and state_dict hooks correctly handle the num_batch attribute
239+
through the module lifecycle.
240+
"""
241+
242+
throughput_metric = ThroughputMetric(
243+
batch_size=32,
244+
world_size=4,
245+
window_seconds=100,
246+
batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
247+
)
248+
249+
self.assertTrue(hasattr(throughput_metric, "_num_batch"))
250+
251+
# Stage 1: create metric and update the state_dict before persisting it
252+
# Update metric, expecting num_batch to be incremented to 1
253+
throughput_metric.update()
254+
# Ensure num_batch is 1
255+
self.assertEqual(throughput_metric._num_batch, 1)
256+
# Ensure num_batch is included in the state_dict and has the correct value
257+
state_dict: Dict[str, Any] = throughput_metric.state_dict()
258+
self.assertIn("num_batch", state_dict)
259+
# Ensure num_batch was saved to state_dict with the correct value
260+
self.assertEqual(state_dict["num_batch"].item(), throughput_metric._num_batch)
261+
262+
# Stage 2: load the state_dict and ensure num_batch is loaded correctly
263+
264+
# Create a new metric instance
265+
new_throughput_metric = ThroughputMetric(
266+
batch_size=32,
267+
world_size=4,
268+
window_seconds=100,
269+
batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
270+
)
271+
# Ensure num_batch is 0
272+
self.assertEqual(new_throughput_metric._num_batch, 0)
273+
# Load the state_dict
274+
new_throughput_metric.load_state_dict(state_dict)
275+
# Ensure num_batch is loaded from the state_dict with the correct value
276+
self.assertEqual(new_throughput_metric._num_batch, 1)
277+
278+
# Stage 3: re-save the state_dict after loading and check that num_batch round-trips
279+
280+
# Save the state_dict
281+
state_dict = new_throughput_metric.state_dict()
282+
# Ensure num_batch is included in the state_dict
283+
self.assertIn("num_batch", state_dict)
284+
# Ensure num_batch was saved to state_dict with the correct value
285+
self.assertEqual(
286+
state_dict["num_batch"].item(), new_throughput_metric._num_batch
287+
)
288+
289+
def test_state_dict_hook_adds_key(self) -> None:
290+
"""
291+
Ensures that the state_dict_hook adds the 'num_batch' key to the state_dict
292+
when batch_size_stages is provided (not None).
293+
"""
294+
throughput_metric = ThroughputMetric(
295+
batch_size=32,
296+
world_size=4,
297+
window_seconds=100,
298+
batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
299+
)
300+
for _ in range(5):
301+
throughput_metric.update()
302+
state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
303+
prefix: str = "test_prefix_"
304+
ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {})
305+
self.assertIn(f"{prefix}num_batch", state_dict)
306+
self.assertEqual(state_dict[f"{prefix}num_batch"].item(), 5)
307+
308+
def test_state_dict_hook_no_batch_size_stages(self) -> None:
309+
"""
310+
Verifies that the state_dict_hook does not add the 'num_batch' key when
311+
batch_size_stages is None.
312+
"""
313+
throughput_metric = ThroughputMetric(
314+
batch_size=32,
315+
world_size=4,
316+
window_seconds=100,
317+
batch_size_stages=None,
318+
)
319+
state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
320+
prefix: str = "test_prefix_"
321+
ThroughputMetric.state_dict_hook(throughput_metric, state_dict, prefix, {})
322+
self.assertNotIn(f"{prefix}num_batch", state_dict)
323+
324+
def test_load_state_dict_hook_restores_value(self) -> None:
325+
"""
326+
Checks that load_state_dict_hook correctly restores the 'num_batch' value
327+
from the state_dict.
328+
"""
329+
throughput_metric = ThroughputMetric(
330+
batch_size=32,
331+
world_size=4,
332+
window_seconds=100,
333+
batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
334+
)
335+
state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
336+
prefix: str = "test_prefix_"
337+
state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
338+
throughput_metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
339+
self.assertEqual(throughput_metric._num_batch, 10)

torchrec/metrics/throughput.py

Lines changed: 41 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,8 @@
1313
import logging
1414
import math
1515
import time
16-
from collections import deque
17-
from typing import Deque, Dict, List, Optional
16+
from collections import deque, OrderedDict
17+
from typing import Any, Deque, Dict, List, Optional
1818

1919
import torch
2020
import torch.nn as nn
@@ -112,12 +112,14 @@ def __init__(
112112
batch_size_stages
113113
)
114114

115+
if self._batch_size_stages is not None:
116+
# Keep track of the number of batches if using batch_size_stages
117+
self._num_batch: int = 0
118+
self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
119+
self.register_state_dict_post_hook(self.state_dict_hook)
120+
115121
self.register_buffer("total_examples", torch.tensor(0, dtype=torch.long))
116122
self.register_buffer("warmup_examples", torch.tensor(0, dtype=torch.long))
117-
if batch_size_stages is not None:
118-
# only load num_batch when batch_size_stages is set.
119-
# So ckpt can be backward compatible -> non-existing key won't be loaded and crash
120-
self.register_buffer("num_batch", torch.tensor(0, dtype=torch.long))
121123
self.register_buffer(
122124
"time_lapse_after_warmup", torch.tensor(0, dtype=torch.double)
123125
)
@@ -181,6 +183,7 @@ def _get_batch_size(self) -> int:
181183
return self._batch_size
182184

183185
# Get batch size from batch_size_stages
186+
assert self._num_batch is not None, "num_batch should not be None"
184187
batch_size_stages = none_throws(self._batch_size_stages)
185188
while self._batch_size_stages:
186189
stage = self._batch_size_stages[0]
@@ -189,7 +192,7 @@ def _get_batch_size(self) -> int:
189192
assert len(batch_size_stages) == 1
190193
return stage.batch_size
191194
# This stage finished
192-
if stage.max_iters < self.num_batch:
195+
if stage.max_iters < self._num_batch:
193196
batch_size_stages.pop(0)
194197
# Move to the next stage
195198
continue
@@ -208,7 +211,7 @@ def update(self) -> None:
208211
ts = time.monotonic()
209212
self._steps += 1
210213
if self._batch_size_stages is not None:
211-
self.num_batch += 1
214+
self._num_batch += 1
212215
batch_examples = self._batch_examples()
213216
self.total_examples += batch_examples
214217
self.attempt_examples += batch_examples
@@ -276,3 +279,33 @@ def compute(self) -> Dict[str, torch.Tensor]:
276279
)
277280

278281
return ret
282+
283+
@staticmethod
284+
def state_dict_hook(
285+
module: nn.Module,
286+
state_dict: OrderedDict[str, torch.Tensor],
287+
prefix: str,
288+
local_metadata: Dict[str, Any],
289+
) -> None:
290+
if module._batch_size_stages is not None:
291+
# Save the number of batches used for the throughput calculation to the state dict
292+
num_batch_key = f"{prefix}num_batch"
293+
state_dict[num_batch_key] = torch.tensor(
294+
module._num_batch, dtype=torch.long
295+
)
296+
297+
def load_state_dict_hook(
298+
self,
299+
state_dict: OrderedDict[str, torch.Tensor],
300+
prefix: str,
301+
local_metadata: Dict[str, Any],
302+
strict: bool,
303+
missing_keys: List[str],
304+
unexpected_keys: List[str],
305+
error_msgs: List[str],
306+
) -> None:
307+
key = f"{prefix}num_batch"
308+
if key in state_dict and self._batch_size_stages is not None:
309+
# Restore the number of batches used for the throughput calculation from the state dict
310+
num_batch_tensor = state_dict.pop(key)
311+
self._num_batch = int(num_batch_tensor.item())

0 commit comments

Comments
 (0)