Commit 7767f56

refine user API (#170)
* refine user API
* use torch.cpu.amp.autocast for bfloat16 autocast
* redesign int8 quantization path
* simplify autocast_kernel.cpp code
* redesign api v2
* change ipex.optimize api
* change ipex.optimize api v2
* code format change
1 parent 060ea58 commit 7767f56

20 files changed, +517 / -555 lines changed

intel_pytorch_extension_py/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@
 from .optimizer_utils import *
 from .weight_cast import *
 from .optim import *
+from .quantization import *
Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-from .autocast_mode import autocast, calibrate
+from .autocast_mode import *
 

intel_pytorch_extension_py/amp/autocast_mode.py

Lines changed: 8 additions & 66 deletions
@@ -2,80 +2,22 @@
 import functools
 import warnings
 import numpy as np
-#from torch._six import container_abcs, string_classes
 import _torch_ipex as core
-from .. import conf
-
-class autocast(object):
-    def __init__(self, enabled=True, configure=conf.AmpConf(torch.bfloat16)):
-        supported_dtype = [torch.bfloat16, torch.int8]
-        if configure.dtype not in supported_dtype :
-            warnings.warn("In CPU autocast, but the target dtype is not supported. Disable the autocast.")
-            warnings.warn("Supported dtype input is: torch.bfloat16, torch.int8.")
-            enabled = False
-            configure = conf.AmpConf(torch.bfloat16)
-        self._enabled = enabled
-        self._dtype = configure.dtype
 
+class _autocast_bf16(torch.cpu.amp.autocast):
     def __enter__(self):
-        self.prev = core.is_autocast_enabled()
-        self.prev_dtype = core.get_autocast_dtype()
-        self.pre_calibration_state = core.get_int8_calibration()
-        core.set_autocast_enabled(self._enabled)
+        self.prev = torch.is_autocast_cpu_enabled()
+        self.prev_dtype = torch.get_autocast_cpu_dtype()
+        torch.set_autocast_cpu_enabled(self._enabled)
         core.set_autocast_dtype(self._dtype)
-        core.autocast_increment_nesting()
-        if torch.int8 == self._dtype:
-            core.disable_int8_calibration()
-
-    def __exit__(self, *args):
-        # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if core.autocast_decrement_nesting() == 0:
-            core.clear_autocast_cache()
-            core.clear_autocast_cache_int8()
-        core.set_autocast_enabled(self.prev)
-        core.set_autocast_dtype(self.prev_dtype)
-        if torch.int8 == self._dtype:
-            if self.pre_calibration_state:
-                core.enable_int8_calibration()
-            else:
-                core.disable_int8_calibration()
-        return False
-
-    def __call__(self, func):
-        @functools.wraps(func)
-        def decorate_autocast(*args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
-        return decorate_autocast
-
-class calibrate(object):
-    def __init__(self):
-        self.pre_calibration_state = core.get_int8_calibration()
-
-    def __enter__(self):
-        self.prev = core.is_autocast_enabled()
-        self.prev_dtype = core.get_autocast_dtype()
-        core.set_autocast_enabled(True)
-        core.set_autocast_dtype(torch.int8)
-        core.autocast_increment_nesting()
-        core.enable_int8_calibration()
+        torch.autocast_increment_nesting()
 
     def __exit__(self, *args):
         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if core.autocast_decrement_nesting() == 0:
+        if torch.autocast_decrement_nesting() == 0:
             core.clear_autocast_cache()
-        core.set_autocast_enabled(self.prev)
+        torch.set_autocast_cpu_enabled(self.prev)
         core.set_autocast_dtype(self.prev_dtype)
-        core.calibration_reset()
-        if self.pre_calibration_state:
-            core.enable_int8_calibration()
-        else:
-            core.disable_int8_calibration()
         return False
 
-    def __call__(self, func):
-        @functools.wraps(func)
-        def decorate_autocast(*args, **kwargs):
-            with self:
-                return func(*args, **kwargs)
-        return decorate_autocast
+torch.cpu.amp.autocast = _autocast_bf16
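
The net effect of this rewrite is that bfloat16 autocast now runs through the stock torch.cpu.amp.autocast entry point, which the module swaps for _autocast_bf16 at import time. A minimal usage sketch under that assumption; the toy model, tensor shapes, and the commented-out extension import name are illustrative, not taken from this commit:

import torch
import torch.nn as nn

# Importing the extension is what installs the _autocast_bf16 patch shown above;
# the exact package name is an assumption and is left commented out here.
# import intel_pytorch_extension as ipex

model = nn.Sequential(nn.Linear(64, 128), nn.ReLU(), nn.Linear(128, 10)).eval()
x = torch.randn(8, 64)

# CPU autocast defaults to bfloat16, so no ipex-specific context manager is needed.
with torch.no_grad(), torch.cpu.amp.autocast():
    y = model(x)

print(y.dtype)  # expected torch.bfloat16 for autocast-eligible ops such as Linear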

intel_pytorch_extension_py/conf.py

Lines changed: 11 additions & 104 deletions
@@ -3,126 +3,33 @@
 import torch
 import _torch_ipex as core
 
-
 qscheme_dict ={torch.per_tensor_affine:0,
                torch.per_channel_affine:1,
                torch.per_tensor_symmetric:2,
                torch.per_channel_symmetric:3,
                torch.torch.per_channel_affine_float_qparams:4}
 
-class AmpConf(object):
-    def __init__(self, mixed_dtype=torch.bfloat16, configure_file=None, qscheme=torch.per_tensor_affine):
-        self.dtype = mixed_dtype
+class QuantConf(object):
+    def __init__(self, configure_file=None, qscheme=torch.per_tensor_affine):
         self.configure_file = configure_file
 
-        if self.dtype == torch.int8:
-            core.clear_indicators()
-            assert qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric], \
-                "qscheme is only support torch.per_tensor_affine and torch.per_tensor_symmetric now"
-            core.set_int8_qscheme(qscheme_dict[qscheme])
+        core.clear_indicators()
+        assert qscheme in [torch.per_tensor_affine, torch.per_tensor_symmetric], \
+            "qscheme is only support torch.per_tensor_affine and torch.per_tensor_symmetric now"
+        core.set_int8_qscheme(qscheme_dict[qscheme])
 
-        # for int8 path, if user give a exited configure file, load it.
-        if self.configure_file != None and self.dtype == torch.int8:
+        # if user provides an existing configuration file, load it
+        if self.configure_file != None:
            if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
                with open(self.configure_file, 'r') as f:
                    configures = json.load(f)
                    core.load_indicators_file(configures)
            else:
                assert False, 'Can not load a empty file or none existed file, plese first do calibartion step'
 
-    # for int8 quantization, will save the date after doing calibration step.
-    def save(self, configure_file, default_recipe=True):
-        core.add_indicators()
+    def save(self, configure_file):
         configures = core.get_int8_configures()
-        if default_recipe:
-            configures = self.get_default_recipe(configures)
         with open(configure_file, 'w') as fp:
             json.dump(configures, fp, indent = 4)
-
-    def get_default_recipe(self, configures):
-        elt_wise = ['relu', 'sigmoid', 'gelu']
-        inplace_ops = ['relu_', 'add_']
-        shape_ops = ['flatten']
-        # get default recipe,
-        # q+dq+conv+q+dq+relu => q+dq+conv+relu
-        # q+dq+op1+q+dq+q+dq+op2+q+dq => q+dq+op1+q+dq+op2+q+dq
-        default_configures = configures
-        num_ops = len(default_configures)
-        for cur_id in range(num_ops):
-            cur_op = default_configures[cur_id]['name']
-            if cur_op == 'dropout':
-                continue
-            inputs = default_configures[cur_id]['inputs_flow']
-            num_input = len(inputs)
-            pre_ops = {}
-            for i_num in range(num_input):
-                inp = inputs[i_num]
-                for pre_id in range(cur_id):
-                    pre_op = default_configures[pre_id]['name']
-                    pre_out = default_configures[pre_id]['outputs_flow']
-                    num_out= len(pre_out)
-                    for o_num in range(num_out):
-                        # pre_op+qu+dequ+qu+dequ+cur_op+qu+dequ -> pre_op+qu+dequ+cur_op+qu+dequ.
-                        # for relu, sigmoid or other elt_wise ops, id pre_op is conv, linear, then
-                        # remove qu+dequ between them for fusion: pre_op+cur_op+qu_dequ.
-                        if pre_out[o_num] == inp:
-                            if (cur_op not in inplace_ops) \
-                                    or (cur_op in inplace_ops and \
-                                        (pre_op == 'conv2d' or pre_op == 'conv3d' or pre_op == 'linear')):
-                                if pre_op not in inplace_ops and pre_op != 'dropout':
-                                    default_configures[pre_id]['outputs_quantized'][o_num] = False
-                            if cur_op in elt_wise \
-                                    and (pre_op == 'conv2d' or pre_op == 'conv3d' or pre_op == 'linear' or pre_op == 'add'):
-                                default_configures[cur_id]['inputs_quantized'][i_num] = False
-                            if cur_op == 'add':
-                                pre_ops[i_num] = pre_op
-                            if cur_op in shape_ops:
-                                # for pooling case, the input and output always has same scale and zero point,
-                                # if the pooling's post ops is flatten, need sync flatten's input and output's
-                                # scale and zero point to pooling.
-                                if pre_op in ['max_pool2d', 'adaptive_avg_pool2d']:
-                                    default_configures[cur_id]['input_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
-                                    default_configures[cur_id]['input_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
-                                    default_configures[cur_id]['output_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
-                                    default_configures[cur_id]['output_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
-                            if pre_op in shape_ops:
-                                # if pre op is flatten, sync the input's scale and zero point to flatten.
-                                default_configures[cur_id]['input_scales'][i_num] = default_configures[pre_id]['output_scales'][o_num]
-                                default_configures[cur_id]['input_zero_points'][i_num] = default_configures[pre_id]['output_zero_points'][o_num]
-            # conv   op          conv   op
-            #    \   /              \   /
-            #     q   q              \   q
-            #      \ /       =>       \ /
-            #      dq  dq              \  dq
-            #       \  /                \ /
-            #        add                add
-            if len(pre_ops) > 0:
-                for key, value in pre_ops.items():
-                    if value == 'conv2d' or value == 'conv3d' or value == 'linear':
-                        default_configures[cur_id]['inputs_quantized'][key] = False
-                        break
-
-            # if add pre_op hasn't conv and linear, not need add q, dq for accuracy.
-            pre_inputs = pre_ops.values()
-            if cur_op == 'add' and \
-                    ('conv2d' not in pre_inputs and 'conv3d' not in pre_inputs and 'linear' not in pre_inputs):
-                default_configures[cur_id]['inputs_quantized'][0] = False
-                default_configures[cur_id]['inputs_quantized'][1] = False
-
-        # post process for add, linear, if cur op hasn't post quantized op, i.e. 'outputs_quantized' is True,
-        # for good perfromance, the default recipe:
-        # int8_input -> op -> q -> dq will converted to int8_input -> op.
-        ops_remove_q_dq_after = ['add', 'linear', 'conv2d']
-        # post process for flatten, if flatten's pre-pop and post op are fp32 op, don't need add q and dq
-        # before and after it.
-        ops_remove_q_dq_before_after = ['flatten']
-        for cur_id in range(num_ops):
-            cur_op = default_configures[cur_id]['name']
-            if cur_op in ops_remove_q_dq_after and default_configures[cur_id]['outputs_quantized'][0]:
-                default_configures[cur_id]['outputs_quantized'][0] = False
-            if cur_op in ops_remove_q_dq_before_after and default_configures[cur_id]['inputs_quantized'][0] \
-                    and default_configures[cur_id]['outputs_quantized'][0]:
-                default_configures[cur_id]['inputs_quantized'][0] = False
-                default_configures[cur_id]['outputs_quantized'][0] = False
-
-        return default_configures
+        # clear indicators after saved
+        core.clear_indicators()
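
Compared with the old AmpConf, QuantConf is int8-only: the constructor always resets the indicators and registers the qscheme (per-tensor schemes only, per the assert), and save() simply dumps whatever the core collected during calibration, with the default-recipe pruning removed from this class. A rough sketch of the intended save/load round trip; the import path and the top-level QuantConf spelling are assumptions, since the package __init__ wiring is not shown in this excerpt:

import torch
# Assumed import name for the built extension; QuantConf itself is defined in conf.py.
import intel_pytorch_extension as ipex

# Calibration run: no configure_file, and only per-tensor qschemes pass the assert.
conf = ipex.QuantConf(qscheme=torch.per_tensor_affine)
# ... run the model under the calibration context (see the quantization package below) ...
conf.save("int8_config.json")  # json.dump of core.get_int8_configures(), then clear_indicators()

# Deployment run: a non-empty JSON file passed here is loaded via core.load_indicators_file.
conf = ipex.QuantConf(configure_file="int8_config.json")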
Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .quantization_utils import calibrate, convert
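
The new quantization package exposes calibrate and convert as the user-facing int8 entry points, but quantization_utils.py itself is not part of this excerpt, so the call pattern below is an assumed sketch built only from the imported names and the QuantConf class shown above; the signatures and the import name are assumptions, not confirmed by this commit:

import torch
import torch.nn as nn
import intel_pytorch_extension as ipex  # assumed import name

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Flatten(), nn.Linear(16 * 6 * 6, 10)).eval()
calib_batches = [torch.randn(1, 3, 8, 8) for _ in range(4)]

# 1) Calibration: record scales/zero points while running representative data.
conf = ipex.QuantConf(qscheme=torch.per_tensor_affine)
for x in calib_batches:
    with ipex.quantization.calibrate(conf):  # assumed signature
        model(x)
conf.save("int8_config.json")

# 2) Conversion: reload the saved configuration and produce an int8 model.
conf = ipex.QuantConf(configure_file="int8_config.json")
int8_model = ipex.quantization.convert(model, conf, torch.randn(1, 3, 8, 8))  # assumed signature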
