 To obtain the input of batch norm, which is necessary to backward through
 it, we recompute convolution forward again during the backward pass.
 
+It is important to note that this optimization is situational: although we
+always reduce the memory allocated at the end of the forward pass (one fewer
+buffer is saved), the *peak* memory allocated may not actually be reduced.
+See the final section for more details.
+
 For simplicity, in this tutorial we hardcode `bias=False`, `stride=1`, `padding=0`, `dilation=1`,
 and `groups=1` for Conv2D. For BatchNorm2D, we hardcode `eps=1e-3`, `momentum=0.1`,
 `affine=False`, and `track_running_stats=False`. Another small difference
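
As background for the recomputation trick described above, here is a minimal, self-contained sketch of the idea: save only the conv *input* and weight for backward, and rebuild the conv output (and the batch-norm statistics that depend on it) when gradients are needed. This is illustrative only, using autograd to redo the local computation rather than the tutorial's hand-written backward; all names are assumptions, not the tutorial's API.

```python
import torch
import torch.nn.functional as F

class RecomputeConvBN(torch.autograd.Function):
    """Illustrative sketch: the intermediate buffer between conv and batch
    norm is never kept alive between forward and backward."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)  # conv output is deliberately NOT saved
        out = F.conv2d(x, weight)         # bias=False, stride=1, padding=0
        mean = out.mean(dim=(0, 2, 3), keepdim=True)
        var = out.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
        return (out - mean) / torch.sqrt(var + 1e-3)  # eps hardcoded as in the text

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        with torch.enable_grad():
            x_ = x.detach().requires_grad_()
            w_ = weight.detach().requires_grad_()
            out = F.conv2d(x_, w_)        # convolution forward recomputed here
            mean = out.mean(dim=(0, 2, 3), keepdim=True)
            var = out.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
            y = (out - mean) / torch.sqrt(var + 1e-3)
        return torch.autograd.grad(y, (x_, w_), grad_out)
```

A quick validation for sketches like this is `torch.autograd.gradcheck` on double-precision inputs, e.g. `torch.autograd.gradcheck(RecomputeConvBN.apply, (x, w))`.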
@@ -238,6 +243,10 @@ def reset_parameters(self) -> None:
 from torchvision import datasets, transforms
 from torch.optim.lr_scheduler import StepLR
 
+# Record the memory allocated at the end of the forward pass
+memory_allocated = [[], []]
+
 class Net(nn.Module):
     def __init__(self, fused=True):
         super(Net, self).__init__()
+        self.fused = fused  # keep the flag on the module so forward() does not rely on a global
@@ -275,6 +284,10 @@ def forward(self, x):
         F.relu_(x)
         x = self.fc2(x)
         output = F.log_softmax(x, dim=1)
+        if self.fused:
+            memory_allocated[0].append(torch.cuda.memory_allocated())
+        else:
+            memory_allocated[1].append(torch.cuda.memory_allocated())
         return output
 
 def train(model, device, train_loader, optimizer, epoch):
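
As a side note on the instrumentation added to `forward` above: the same end-of-forward reading can also be taken without touching the model, via a forward hook. This is a hedged alternative sketch, not what this change does; `make_memory_hook` is a hypothetical helper.

```python
# Alternative sketch: take the end-of-forward measurement with a forward hook
# instead of instrumenting forward() itself.
def make_memory_hook(bucket):
    # ``bucket`` is one of the two lists in ``memory_allocated`` above
    def hook(module, inputs, output):
        bucket.append(torch.cuda.memory_allocated())
    return hook

# Usage sketch:
#   model = Net(fused=fused).to(device)
#   model.register_forward_hook(make_memory_hook(memory_allocated[0 if fused else 1]))
```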
@@ -339,24 +352,44 @@ def test(model, device, test_loader):
 # A Comparison of Memory Usage
 # -------------------------------------------------------------------
 # If cuda is enabled, print out memory usage for both `fused=True` and `fused=False`
+# For an example run on an RTX 3070 with CuDNN 8.0.5: fused peak memory: 1.56GB,
+# unfused peak memory: 2.68GB
+#
+# It is important to note that the *peak* memory usage for this model may vary
+# depending on the specific CuDNN convolution algorithm used. For shallower
+# models, the peak memory allocated by the fused model may even exceed that of
+# the unfused model! This is because the memory allocated to compute certain
+# CuDNN convolution algorithms can be high enough to "hide" the typical peak
+# you would expect near the start of the backward pass.
+#
+# For this reason, we also record and display the memory allocated at the end
+# of the forward pass as an approximation, and to demonstrate that we indeed
+# allocate one fewer buffer per fused conv-bn pair.
+from statistics import mean
+
+torch.backends.cudnn.enabled = True
+
 if use_cuda:
-    mems = []
+    peak_memory_allocated = []
+
     for fused in (True, False):
         torch.manual_seed(123456)
-        torch.cuda.reset_peak_memory_stats()
 
         model = Net(fused=fused).to(device)
         optimizer = optim.Adadelta(model.parameters(), lr=1.0)
         scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
 
-        for epoch in range(2):
+        for epoch in range(1):
             train(model, device, train_loader, optimizer, epoch)
             test(model, device, test_loader)
             scheduler.step()
+        peak_memory_allocated.append(torch.cuda.max_memory_allocated())
+        torch.cuda.reset_peak_memory_stats()
+    print("CuDNN version:", torch.backends.cudnn.version())
+    print()
+    print("Peak memory allocated:")
+    print(f"fused: {peak_memory_allocated[0]/1024**3:.2f}GB, unfused: {peak_memory_allocated[1]/1024**3:.2f}GB")
+    print("Memory allocated at end of forward pass:")
+    print(f"fused: {mean(memory_allocated[0])/1024**3:.2f}GB, unfused: {mean(memory_allocated[1])/1024**3:.2f}GB")
+
 
-        mems.append(torch.cuda.max_memory_allocated(device=None) / 1024**3)
-    # Example run: fused peak memory: 1.56GB, unfused peak memory: 2.68GB
-    #
-    # NOTE: Actual memory usage may vary depending the specific CuDNN convolution algorithm used
-    print(f"CuDNN version: {torch.backends.cudnn.version()}")
-    print(f"fused peak memory: {mems[0]:.2f}GB, unfused peak memory: {mems[0]:.2f}GB")