File tree Expand file tree Collapse file tree 2 files changed +11
-8
lines changed
intel_pytorch_extension_py Expand file tree Collapse file tree 2 files changed +11
-8
lines changed Original file line number Diff line number Diff line change @@ -82,16 +82,19 @@ def enable_auto_optimization(mixed_dtype = None, train = False):
8282 """
8383 if mixed_dtype != None :
8484 core .enable_auto_dnnl ()
85+ enable_auto_mix_precision (mixed_dtype , train )
86+
87+ def enable_auto_mix_precision (mixed_dtype = torch .bfloat16 , train = False ):
8588 running_mode = 'training' if train else 'inference'
86- enable_auto_mix_precision (AmpConf (mixed_dtype ), running_mode ).__enter__ ()
89+ auto_mix_precision (AmpConf (mixed_dtype ), running_mode ).__enter__ ()
8790
8891def get_auto_optimization ():
8992 return get_auto_mix_precision
9093
9194def get_train ():
9295 return core .get_train ()
9396
94- class enable_auto_mix_precision (_DecoratorContextManager ):
97+ class auto_mix_precision (_DecoratorContextManager ):
9598 def __init__ (self , conf , running_mode = 'inference' ):
9699 self .pre_mixed_dtype = get_auto_mix_precision ()
97100 self .pre_running_mode = get_train ()
Original file line number Diff line number Diff line change 44class AutoMixPrecision (object ):
55 def __init__ (self , enable_or_not = False , train = False ):
66 self .old_value = ipex .get_auto_mix_precision ()
7- self .pre_running_mode = 'training' if ipex .get_train () else 'inference'
7+ self .train_old_value = ipex .get_train ()
88 self .enable_or_not = enable_or_not
9- self .running_mode = 'training' if train else 'inference'
9+ self .train = train
1010
1111 def __enter__ (self ):
1212 if self .enable_or_not :
13- ipex .enable_auto_mix_precision (ipex . AmpConf ( torch .bfloat16 ), self .running_mode ). __enter__ ( )
13+ ipex .enable_auto_mix_precision (mixed_dtype = torch .bfloat16 , train = self .train )
1414 else :
15- ipex .enable_auto_mix_precision (ipex . AmpConf ( None )). __enter__ ( )
15+ ipex .enable_auto_mix_precision (mixed_dtype = None )
1616
1717 def __exit__ (self , * args , ** kwargs ):
1818 if self .old_value :
19- ipex .enable_auto_mix_precision (ipex . AmpConf ( torch .bfloat16 ), self .pre_running_mode ). __enter__ ( )
19+ ipex .enable_auto_mix_precision (mixed_dtype = torch .bfloat16 , train = self .train_old_value )
2020 else :
21- ipex .enable_auto_mix_precision (ipex . AmpConf ( None )). __enter__ ( )
21+ ipex .enable_auto_mix_precision (mixed_dtype = None )
2222
2323class AutoDNNL (object ):
2424 def __init__ (self , enable_or_not = False ):
You can’t perform that action at this time.
0 commit comments