Skip to content

Commit d759c85

Browse files
committed
redefine enable_auto_mix_precision api
1 parent 7a8b5aa commit d759c85

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

intel_pytorch_extension_py/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,19 @@ def enable_auto_optimization(mixed_dtype = None, train = False):
8282
"""
8383
if mixed_dtype != None:
8484
core.enable_auto_dnnl()
85+
enable_auto_mix_precision(mixed_dtype, train)
86+
87+
def enable_auto_mix_precision(mixed_dtype = torch.bfloat16, train = False):
8588
running_mode = 'training' if train else 'inference'
86-
enable_auto_mix_precision(AmpConf(mixed_dtype), running_mode).__enter__()
89+
auto_mix_precision(AmpConf(mixed_dtype), running_mode).__enter__()
8790

8891
def get_auto_optimization():
8992
return get_auto_mix_precision
9093

9194
def get_train():
9295
return core.get_train()
9396

94-
class enable_auto_mix_precision(_DecoratorContextManager):
97+
class auto_mix_precision(_DecoratorContextManager):
9598
def __init__(self, conf, running_mode = 'inference'):
9699
self.pre_mixed_dtype = get_auto_mix_precision()
97100
self.pre_running_mode = get_train()

tests/cpu/common_ipex_conf.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44
class AutoMixPrecision(object):
55
def __init__(self, enable_or_not = False, train = False):
66
self.old_value = ipex.get_auto_mix_precision()
7-
self.pre_running_mode = 'training' if ipex.get_train() else 'inference'
7+
self.train_old_value = ipex.get_train()
88
self.enable_or_not = enable_or_not
9-
self.running_mode = 'training' if train else 'inference'
9+
self.train = train
1010

1111
def __enter__(self):
1212
if self.enable_or_not:
13-
ipex.enable_auto_mix_precision(ipex.AmpConf(torch.bfloat16), self.running_mode).__enter__()
13+
ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=self.train)
1414
else:
15-
ipex.enable_auto_mix_precision(ipex.AmpConf(None)).__enter__()
15+
ipex.enable_auto_mix_precision(mixed_dtype=None)
1616

1717
def __exit__(self, *args, **kwargs):
1818
if self.old_value:
19-
ipex.enable_auto_mix_precision(ipex.AmpConf(torch.bfloat16), self.pre_running_mode).__enter__()
19+
ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=self.train_old_value)
2020
else:
21-
ipex.enable_auto_mix_precision(ipex.AmpConf(None)).__enter__()
21+
ipex.enable_auto_mix_precision(mixed_dtype=None)
2222

2323
class AutoDNNL(object):
2424
def __init__(self, enable_or_not = False):

0 commit comments

Comments
 (0)