Skip to content

Commit 1766a01

Browse files
committed
Cleanup some amp related behaviour to better support different (non-cuda) devices
1 parent a852318 commit 1766a01

File tree

6 files changed

+39
-47
lines changed

6 files changed

+39
-47
lines changed

benchmark.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@
3232
except ImportError:
3333
pass
3434

35-
has_native_amp = False
36-
try:
37-
if getattr(torch.cuda.amp, 'autocast') is not None:
38-
has_native_amp = True
39-
except AttributeError:
40-
pass
41-
4235
try:
4336
from deepspeed.profiling.flops_profiler import get_model_profile
4437
has_deepspeed_profiling = True
@@ -242,7 +235,7 @@ def __init__(
242235
self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
243236
self.channels_last = kwargs.pop('channels_last', False)
244237
if self.amp_dtype is not None:
245-
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
238+
self.amp_autocast = partial(torch.amp.autocast, device_type=device, dtype=self.amp_dtype)
246239
else:
247240
self.amp_autocast = suppress
248241

inference.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,6 @@
2828
except ImportError:
2929
has_apex = False
3030

31-
has_native_amp = False
32-
try:
33-
if getattr(torch.cuda.amp, 'autocast') is not None:
34-
has_native_amp = True
35-
except AttributeError:
36-
pass
37-
3831
try:
3932
from functorch.compile import memory_efficient_fusion
4033
has_functorch = True
@@ -170,7 +163,6 @@ def main():
170163
# resolve AMP arguments based on PyTorch / Apex availability
171164
amp_autocast = suppress
172165
if args.amp:
173-
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
174166
assert args.amp_dtype in ('float16', 'bfloat16')
175167
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
176168
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)

timm/layers/fast_norm.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,30 @@
2828
_USE_FAST_NORM = False # defaulting to False for now
2929

3030

31+
def get_autocast_dtype(device: str = 'cuda'):
32+
try:
33+
return torch.get_autocast_dtype(device)
34+
except (AttributeError, TypeError):
35+
# dispatch to older device specific fns, only covering cuda/cpu devices here
36+
if device == 'cpu':
37+
return torch.get_autocast_cpu_dtype()
38+
else:
39+
assert device == 'cuda'
40+
return torch.get_autocast_gpu_dtype()
41+
42+
43+
def is_autocast_enabled(device: str = 'cuda'):
44+
try:
45+
return torch.is_autocast_enabled(device)
46+
except TypeError:
47+
# dispatch to older device specific fns, only covering cuda/cpu devices here
48+
if device == 'cpu':
49+
return torch.is_autocast_cpu_enabled()
50+
else:
51+
assert device == 'cuda'
52+
return torch.is_autocast_enabled() # defaults cuda (only cuda on older pytorch)
53+
54+
3155
def is_fast_norm():
3256
return _USE_FAST_NORM
3357

@@ -48,14 +72,14 @@ def fast_group_norm(
4872
# currently cannot use is_autocast_enabled within torchscript
4973
return F.group_norm(x, num_groups, weight, bias, eps)
5074

51-
if torch.is_autocast_enabled():
75+
if is_autocast_enabled(x.device.type):
5276
# normally native AMP casts GN inputs to float32
5377
# here we use the low precision autocast dtype
5478
# FIXME what to do re CPU autocast?
55-
dt = torch.get_autocast_gpu_dtype()
79+
dt = get_autocast_dtype(x.device.type)
5680
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
5781

58-
with torch.cuda.amp.autocast(enabled=False):
82+
with torch.amp.autocast(device_type=x.device.type, enabled=False):
5983
return F.group_norm(x, num_groups, weight, bias, eps)
6084

6185

@@ -73,14 +97,14 @@ def fast_layer_norm(
7397
if has_apex:
7498
return fused_layer_norm_affine(x, weight, bias, normalized_shape, eps)
7599

76-
if torch.is_autocast_enabled():
100+
if is_autocast_enabled(x.device.type):
77101
# normally native AMP casts LN inputs to float32
78102
# apex LN does not, this is behaving like Apex
79-
dt = torch.get_autocast_gpu_dtype()
103+
dt = get_autocast_dtype(x.device.type)
80104
# FIXME what to do re CPU autocast?
81105
x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
82106

83-
with torch.cuda.amp.autocast(enabled=False):
107+
with torch.amp.autocast(device_type=x.device.type, enabled=False):
84108
return F.layer_norm(x, normalized_shape, weight, bias, eps)
85109

86110

timm/utils/cuda.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,11 @@ def load_state_dict(self, state_dict):
4646
class NativeScaler:
4747
state_dict_key = "amp_scaler"
4848

49-
def __init__(self):
50-
self._scaler = torch.cuda.amp.GradScaler()
49+
def __init__(self, device='cuda'):
50+
try:
51+
self._scaler = torch.amp.GradScaler(device=device)
52+
except (AttributeError, TypeError) as e:
53+
self._scaler = torch.cuda.amp.GradScaler()
5154

5255
def __call__(
5356
self,

train.py

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,6 @@
4848
except ImportError:
4949
has_apex = False
5050

51-
has_native_amp = False
52-
try:
53-
if getattr(torch.cuda.amp, 'autocast') is not None:
54-
has_native_amp = True
55-
except AttributeError:
56-
pass
5751

5852
try:
5953
import wandb
@@ -442,7 +436,6 @@ def main():
442436
use_amp = 'apex'
443437
assert args.amp_dtype == 'float16'
444438
else:
445-
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
446439
use_amp = 'native'
447440
assert args.amp_dtype in ('float16', 'bfloat16')
448441
if args.amp_dtype == 'bfloat16':
@@ -572,15 +565,10 @@ def main():
572565
if utils.is_primary(args):
573566
_logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
574567
elif use_amp == 'native':
575-
try:
576-
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
577-
except (AttributeError, TypeError):
578-
# fallback to CUDA only AMP for PyTorch < 1.10
579-
assert device.type == 'cuda'
580-
amp_autocast = torch.cuda.amp.autocast
581-
if device.type == 'cuda' and amp_dtype == torch.float16:
568+
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
569+
if device.type in ('cuda',) and amp_dtype == torch.float16:
582570
# loss scaler only used for float16 (half) dtype, bfloat16 does not need it
583-
loss_scaler = NativeScaler()
571+
loss_scaler = NativeScaler(device=device.type)
584572
if utils.is_primary(args):
585573
_logger.info('Using native Torch AMP. Training in mixed precision.')
586574
else:

validate.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,6 @@
3434
except ImportError:
3535
has_apex = False
3636

37-
has_native_amp = False
38-
try:
39-
if getattr(torch.cuda.amp, 'autocast') is not None:
40-
has_native_amp = True
41-
except AttributeError:
42-
pass
43-
4437
try:
4538
from functorch.compile import memory_efficient_fusion
4639
has_functorch = True
@@ -183,7 +176,6 @@ def validate(args):
183176
use_amp = 'apex'
184177
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
185178
else:
186-
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
187179
assert args.amp_dtype in ('float16', 'bfloat16')
188180
use_amp = 'native'
189181
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16

0 commit comments

Comments
 (0)