Commit 0410a33

Refine Script and args for Cpp Graph (intel#320)
1 parent 67d67ff commit 0410a33

15 files changed
+87 −81 lines changed

.github/workflows/script/models/cpp_graph_inference.sh

Lines changed: 10 additions & 10 deletions
@@ -110,25 +110,25 @@ function main() {
 quantized_model="${model}-${precision}.bin"
 if [[ ! -e ${quantized_model} ]]; then
 if [[ ${precision} == "q4_j_vnni_b128" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 128 --scale_dtype fp32 --compute_type int8 --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 128 --scale_dtype fp32 --compute_type int8 --alg sym
 elif [[ ${precision} == "q4_j_vnni_bf16_b32" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 32 --scale_dtype bf16 --compute_type int8 --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 32 --scale_dtype bf16 --compute_type int8 --alg sym
 elif [[ ${precision} == "q4_j_vnni_b32" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 32 --scale_dtype fp32 --compute_type int8 --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 32 --scale_dtype fp32 --compute_type int8 --alg sym
 elif [[ ${precision} == "q4_j_b32" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 32 --scale_dtype fp32 --compute_type fp32 --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 32 --scale_dtype fp32 --compute_type fp32 --alg sym
 elif [[ ${precision} == "q4_j_b128" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 128 --scale_dtype fp32 --compute_type fp32 --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 128 --scale_dtype fp32 --compute_type fp32 --alg sym
 elif [[ ${precision} == "q4_j_b128_asym" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 128 --scale_dtype fp32 --compute_type fp32 --alg asym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 128 --scale_dtype fp32 --compute_type fp32 --alg asym
 elif [[ ${precision} == "q4_0" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 32 --compute_type ggml --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 32 --compute_type ggml --alg sym
 elif [[ ${precision} == "q4_1" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4 --block_size 32 --compute_type ggml --alg asym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4 --group_size 32 --compute_type ggml --alg asym
 elif [[ ${precision} == "q8_0" ]]; then
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 8 --block_size 32 --compute_type ggml --alg sym
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int8 --group_size 32 --compute_type ggml --alg sym
 else
-${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --bits 4
+${quant_script} --model_file ${working_dir}/${model}-fp32.bin --out_file ${working_dir}/${model}-${precision}.bin --weight_dtype int4
 fi
 fi
 ## run inference
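
The flag rename above follows a fixed per-precision convention. As a quick illustration (not part of the commit), the Python sketch below rebuilds the same argument lists with the new `--weight_dtype`/`--group_size` flags; the quantize binary path, model file names, and the helper itself are hypothetical.

```python
# Hypothetical helper mirroring the shell branches above: each precision tag maps to
# the renamed flags (--weight_dtype/--group_size replace --bits/--block_size).
import subprocess

PRECISION_ARGS = {
    "q4_j_vnni_b128": ["--weight_dtype", "int4", "--group_size", "128",
                       "--scale_dtype", "fp32", "--compute_type", "int8", "--alg", "sym"],
    "q4_j_b32":       ["--weight_dtype", "int4", "--group_size", "32",
                       "--scale_dtype", "fp32", "--compute_type", "fp32", "--alg", "sym"],
    "q4_0":           ["--weight_dtype", "int4", "--group_size", "32",
                       "--compute_type", "ggml", "--alg", "sym"],
    "q8_0":           ["--weight_dtype", "int8", "--group_size", "32",
                       "--compute_type", "ggml", "--alg", "sym"],
}

def quantize_cmd(quant_script, model_file, out_file, precision):
    """Build the argv list for one precision; unknown tags fall back to plain int4."""
    extra = PRECISION_ARGS.get(precision, ["--weight_dtype", "int4"])
    return [quant_script, "--model_file", model_file, "--out_file", out_file, *extra]

# e.g. subprocess.run(quantize_cmd("./quant_llama", "llama2-fp32.bin",
#                                  "llama2-q4_0.bin", "q4_0"), check=True)
```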

intel_extension_for_transformers/llm/quantization/utils.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,6 @@
 from accelerate import init_empty_weights
 from neural_compressor import quantization
 from neural_compressor.config import PostTrainingQuantConfig
-from .nn import QuantizedLinearQBits # TODO: QuantizedLinearINT4, QuantizedLinearINT8


 logger = logging.getLogger(__name__)
@@ -108,6 +107,7 @@ def _replace_linear(

 if isinstance(module, torch.nn.Linear) and name not in modules_to_not_convert:
 # Check if the current key is not in the `modules_to_not_convert`
+from .nn import QuantizedLinearQBits # TODO: QuantizedLinearINT4, QuantizedLinearINT8
 if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
 with init_empty_weights():
 in_features = module.in_features

intel_extension_for_transformers/llm/runtime/graph/README.md

Lines changed: 3 additions & 4 deletions
@@ -82,7 +82,7 @@ LLM one-click running script args explanations:
 | -p / --prompt | prompt to start generation with (default: empty) |
 | -n / --n_predict | number of tokens to predict (default: -1, -1 = infinity) |
 | -t / --threads | number of threads to use during computation (default: 56) |
-| -b / --batch_size | batch size for prompt processing (default: 512) |
+| -b / --batch_size_truncate | batch size for prompt processing (default: 512) |
 | -c / --ctx_size | size of the prompt context (default: 512, can not be larger than specific model's context window length) |
 | -s / --seed | NG seed (default: -1, use random seed for < 0) |
 | --repeat_penalty | penalize repeat sequence of tokens (default: 1.1, 1.0 = disabled) |
@@ -106,12 +106,12 @@ python scripts/convert.py --outtype f32 --outfile ne-f32.bin model_path

 # quantize weights of fp32 ggml bin
 # model_name: llama, llama2, mpt, falcon, gptj, starcoder, dolly
-# to neuarl engine graph optimized q4_j with 128 block_size format (recommended)
+# optimized INT4 model with group size 128 (recommended)
 python scripts/quantize.py --model_name llama2 --model_file ne-f32.bin --out_file ne-q4_j.bin --weight_dtype int4 --block_size 128 --compute_type int8

 # Alternativly you could run ggml q4_0 format like following
 python scripts/quantize.py --model_name llama2 --model_file ne-f32.bin --out_file ne-q4_0.bin --weight_dtype int4
-# or ues neuarl engine graph optimized q4_j with 32 block_size format
+# optimized INT4 model with group size 32
 python scripts/quantize.py --model_name llama2 --model_file ne-f32.bin --out_file ne-q4_j.bin --weight_dtype int4 --block_size 32 --compute_type int8

 ```
@@ -164,4 +164,3 @@ LLM running script args explanations:
 ### 3. Tensor Parallelism cross nodes/sockets

 We support tensor parallelism strategy for distributed inference/training on multi-node and multi-socket. You can refer to [tensor_parallelism.md](./tensor_parallelism.md) to enable this feature.
-
intel_extension_for_transformers/llm/runtime/graph/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@
 # limitations under the License.
 import os
 from transformers import AutoConfig
-from intel_extension_for_transformers.llm.runtime.graph.scripts.convert_model import convert_model
+from intel_extension_for_transformers.llm.runtime.graph.scripts.convert import convert_model

 model_maps = {"gpt_neox": "gptneox", "RefinedWebModel": "falcon"}

intel_extension_for_transformers/llm/runtime/graph/application/common.cpp

Lines changed: 12 additions & 12 deletions
@@ -677,9 +677,9 @@ void quant_print_usage(int argc, char** argv, const quant_params& params) {
 " --config path to the configuration file (default: "
 ")\n");
 fprintf(stderr, " --nthread N number of threads to use (default: 1)\n");
-fprintf(stderr, " --bits N number of bits to use for quantization (default: 4)\n");
+fprintf(stderr, " --weight_dtype N number of bits to use for quantization (default: 4)\n");
 fprintf(stderr, " --alg qquantization algorithm to use: sym/asym (default: sym)\n");
-fprintf(stderr, " --block_size N block size (default: 32)\n");
+fprintf(stderr, " --group_size N group size (default: 32)\n");
 fprintf(stderr, " --scale_dtype dtype fp32/bf16 type for scales (default: fp32)\n");
 fprintf(stderr,
 " --compute_type Gemm computation data type: int8/fp32/ggml (default: "
@@ -701,12 +701,12 @@ bool quant_params_parse(int argc, char** argv, quant_params& params) {
 params.config = argv[++i];
 } else if (arg == "--nthread") {
 params.nthread = std::stoi(argv[++i]);
-} else if (arg == "--bits") {
-params.bits = std::stoi(argv[++i]);
+} else if (arg == "--weight_dtype") {
+params.weight_dtype = argv[++i];
 } else if (arg == "--alg") {
 params.alg = argv[++i];
-} else if (arg == "--block_size") {
-params.block_size = std::stoi(argv[++i]);
+} else if (arg == "--group_size") {
+params.group_size = std::stoi(argv[++i]);
 } else if (arg == "--scale_dtype") {
 params.scale_dtype = argv[++i];
 } else if (arg == "--compute_type") {
@@ -734,19 +734,19 @@ bool quant_params_parse(int argc, char** argv, quant_params& params) {

 ne_ftype quant_params_to_ftype(const quant_params& params) {
 if (params.compute_type == "ggml") {
-if (params.bits == 4) {
+if (params.weight_dtype == "int4") {
 if (params.alg == "sym") {
 return NE_FTYPE_MOSTLY_Q4_0;
 } else {
 return NE_FTYPE_MOSTLY_Q4_1;
 }
-} else if (params.bits == 5) {
+} else if (params.weight_dtype == "int5") {
 if (params.alg == "sym") {
 return NE_FTYPE_MOSTLY_Q5_0;
 } else {
 return NE_FTYPE_MOSTLY_Q5_1;
 }
-} else if (params.bits == 8) {
+} else if (params.weight_dtype == "int8") {
 return NE_FTYPE_MOSTLY_Q8_0;
 }
 } else {
@@ -757,19 +757,19 @@ ne_ftype quant_params_to_ftype(const quant_params& params) {

 ne_type quant_params_to_type(const quant_params& params) {
 if (params.compute_type == "ggml") {
-if (params.bits == 4) {
+if (params.weight_dtype == "int4") {
 if (params.alg == "sym") {
 return NE_TYPE_Q4_0;
 } else {
 return NE_TYPE_Q4_1;
 }
-} else if (params.bits == 5) {
+} else if (params.weight_dtype == "int5") {
 if (params.alg == "sym") {
 return NE_TYPE_Q5_0;
 } else {
 return NE_TYPE_Q5_1;
 }
-} else if (params.bits == 8) {
+} else if (params.weight_dtype == "int8") {
 return NE_TYPE_Q8_0;
 }
 } else {
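
For readers tracking the ggml path, the (weight_dtype, alg) to ftype mapping implemented above reduces to a small table. A minimal sketch of that mapping, with the constant names quoted only as labels; anything not listed falls through to the non-ggml path in the C++ code above.

```python
# Sketch of quant_params_to_ftype's ggml branch with the new string-valued weight_dtype.
GGML_FTYPE = {
    ("int4", "sym"):  "NE_FTYPE_MOSTLY_Q4_0",
    ("int4", "asym"): "NE_FTYPE_MOSTLY_Q4_1",
    ("int5", "sym"):  "NE_FTYPE_MOSTLY_Q5_0",
    ("int5", "asym"): "NE_FTYPE_MOSTLY_Q5_1",
    ("int8", "sym"):  "NE_FTYPE_MOSTLY_Q8_0",
    ("int8", "asym"): "NE_FTYPE_MOSTLY_Q8_0",  # alg is not checked for int8 in the code above
}
```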

intel_extension_for_transformers/llm/runtime/graph/application/common.h

Lines changed: 2 additions & 2 deletions
@@ -148,9 +148,9 @@ struct quant_params {
 std::string config = "";
 int nthread = 1;

-int32_t bits = 4;
+std::string weight_dtype = "int4";
 std::string alg = "sym";
-int32_t block_size = 32;
+int32_t group_size = 32;
 std::string scale_dtype = "fp32";
 std::string compute_type = "ggml";
 std::string model_name = "unknown";

intel_extension_for_transformers/llm/runtime/graph/application/main_pybind.cpp

Lines changed: 9 additions & 7 deletions
@@ -59,8 +59,9 @@ class Model {
 void reinit();
 std::string generate(const std::string& prompt, bool sentence_mode = true);
 bool is_token_end() { return token_eos; }
-static int quant_model(const std::string& model_path, const std::string& out_path, int bits, const std::string& alg,
-int block_size, const std::string& scale_dtype, const std::string& compute_type);
+static int quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
+const std::string& alg, int group_size, const std::string& scale_dtype,
+const std::string& compute_type);

 private:
 model_context* ctx = nullptr;
@@ -212,8 +213,9 @@ int Model::post_process(float* logits) {
 return id;
 }

-int Model::quant_model(const std::string& model_path, const std::string& out_path, int bits, const std::string& alg,
-int block_size, const std::string& scale_dtype, const std::string& compute_type) {
+int Model::quant_model(const std::string& model_path, const std::string& out_path, const std::string& weight_dtype,
+const std::string& alg, int group_size, const std::string& scale_dtype,
+const std::string& compute_type) {
 quant_params q_params;
 #ifdef MODEL_NAME
 q_params.model_name = MODEL_NAME;
@@ -226,9 +228,9 @@ int Model::quant_model(const std::string& model_path, const std::string& out_pat
 q_params.model_arch = mt;
 q_params.model_file = model_path;
 q_params.out_file = out_path;
-q_params.bits = bits;
+q_params.weight_dtype = weight_dtype;
 q_params.alg = alg;
-q_params.block_size = block_size;
+q_params.group_size = group_size;
 q_params.scale_dtype = scale_dtype;
 q_params.compute_type = compute_type;

@@ -300,7 +302,7 @@ PYBIND11_MODULE(chatglm_cpp, m)
 .def("generate", &Model::generate, "Generate tokens with prompt", py::arg("prompt"),
 py::arg("sentence_mode") = true)
 .def_static("quant_model", &Model::quant_model, "Quantize model", py::arg("model_path"), py::arg("out_path"),
-py::arg("bits") = 4, py::arg("alg") = "sym", py::arg("block_size") = 32,
+py::arg("weight_dtype") = "int4", py::arg("alg") = "sym", py::arg("group_size") = 32,
 py::arg("scale_dtype") = "fp32", py::arg("compute_type") = "ggml")
 .def("is_token_end", &Model::is_token_end)
 .def("reinit", &Model::reinit);
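
With the binding updated as above, a Python caller passes the renamed keyword arguments. A minimal sketch, assuming the built extension exposes the module name used in the PYBIND11_MODULE macro shown here (actual module names vary per model/build) and using placeholder file paths.

```python
# Hedged example: quantize an fp32 ne graph model through the updated pybind API.
# Keyword names follow the py::arg(...) list above; paths are placeholders.
import chatglm_cpp  # module name per PYBIND11_MODULE(chatglm_cpp, m); may differ per build

chatglm_cpp.Model.quant_model(
    model_path="ne-f32.bin",   # placeholder input
    out_path="ne-q4_j.bin",    # placeholder output
    weight_dtype="int4",       # was: bits=4
    alg="sym",
    group_size=32,             # was: block_size=32
    scale_dtype="fp32",
    compute_type="int8",
)
```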

intel_extension_for_transformers/llm/runtime/graph/models/model_utils/model_utils.cpp

Lines changed: 9 additions & 9 deletions
@@ -782,7 +782,7 @@ model_token model_sample_token(struct model_context* ctx, model_token_data_array
 // quantization
 //
 quant_params_internal quant_params_to_internal(const quant_params& params) {
-return quant_params_internal{parse_bits(params.bits), parse_alg(params.alg), params.block_size,
+return quant_params_internal{parse_bits(params.weight_dtype), parse_alg(params.alg), params.group_size,
 parse_scale_dtype(params.scale_dtype), parse_compute_type(params.compute_type)};
 }

@@ -799,7 +799,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 if (params.alg != quant_alg::sym) {
 printf("Current not support asymmetric int8 computation, reset to symmetric\n");
 }
-if (params.block_size == -1) {
+if (params.group_size == -1) {
 using Kernel = WeiS4ClipFp32PerN<GcCompInt8, JblasAVX512F>;
 using KernelRef = WeiS4ClipFp32PerN<GcCompInt8, JblasNoSIMD>;
 static Kernel kernel;
@@ -815,7 +815,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 using KernelRef = WeiS4ClipFp32<GcCompInt8KBlock, JblasNoSIMD>;
 static Kernel kernel;
 static KernelRef kernelref;
-packedw = kernel.createStorage(n, k, params.block_size);
+packedw = kernel.createStorage(n, k, params.group_size);
 if (cd->AVX512F()) {
 kernel.packTransposeWeight(n, k, f32ptr, k, packedw);
 } else {
@@ -827,7 +827,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 using KernelRef = WeiS4ClipFp32<GcCompFp32, JblasNoSIMD>;
 static Kernel kernel;
 static Kernel kernelref;
-packedw = kernel.createStorage(n, k, params.block_size, params.alg == quant_alg::sym);
+packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym);
 if (cd->AVX512_FP16()) {
 kernel.packTransposeWeight(n, k, f32ptr, k, packedw);
 } else {
@@ -838,7 +838,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 using KernelRef = WeiS4ClipFp32<GcCompBf16, JblasNoSIMD>;
 static Kernel kernel;
 static Kernel kernelref;
-packedw = kernel.createStorage(n, k, params.block_size, params.alg == quant_alg::sym);
+packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym);
 if (cd->AMX_BF16()) {
 kernel.packTransposeWeight(n, k, f32ptr, k, packedw);
 } else {
@@ -854,7 +854,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 if (params.alg != quant_alg::sym) {
 printf("Current not support asymmetric int8 computation, reset to symmetric\n");
 }
-if (params.block_size == -1) {
+if (params.group_size == -1) {
 using Kernel = WeiS8Fp32PerN<GcCompInt8, JblasAVX512F>;
 using KernelRef = WeiS8Fp32PerN<GcCompInt8, JblasNoSIMD>;
 static Kernel kernel;
@@ -870,7 +870,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 using KernelRef = WeiS8Fp32<GcCompInt8KBlock, JblasNoSIMD>;
 static Kernel kernel;
 static Kernel kernelref;
-packedw = kernel.createStorage(n, k, params.block_size);
+packedw = kernel.createStorage(n, k, params.group_size);
 if (cd->AVX512F()) {
 kernel.packTransposeWeight(n, k, f32ptr, k, packedw);
 } else {
@@ -882,7 +882,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 using KernelRef = WeiS8Fp32<GcCompFp32, JblasNoSIMD>;
 static Kernel kernel;
 static Kernel kernelref;
-packedw = kernel.createStorage(n, k, params.block_size, params.alg == quant_alg::sym);
+packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym);
 if (cd->AVX512_FP16()) {
 kernel.packTransposeWeight(n, k, f32ptr, k, packedw);
 } else {
@@ -893,7 +893,7 @@ size_t jblas_quantize(const float* f32ptr, void* dstpr, const quant_params_inter
 using KernelRef = WeiS8Fp32<GcCompBf16, JblasNoSIMD>;
 static Kernel kernel;
 static Kernel kernelref;
-packedw = kernel.createStorage(n, k, params.block_size, params.alg == quant_alg::sym);
+packedw = kernel.createStorage(n, k, params.group_size, params.alg == quant_alg::sym);
 if (cd->AMX_BF16()) {
 kernel.packTransposeWeight(n, k, f32ptr, k, packedw);
 } else {
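
In jblas_quantize above, the renamed group_size keeps its sentinel semantics: within the int8 compute-type branches, -1 selects the PerN kernel variants, while any positive value selects the KBlock variants built with that group size. A compact restatement of that dispatch follows; the kernel names are quoted only as labels and the helper itself is illustrative, not part of the commit.

```python
def pick_int8_compute_kernel(weight_dtype: str, group_size: int):
    """Mirror of the int8-compute kernel choice in jblas_quantize (labels only)."""
    per_n = group_size == -1
    if weight_dtype == "int4":
        return "WeiS4ClipFp32PerN<GcCompInt8>" if per_n else "WeiS4ClipFp32<GcCompInt8KBlock>"
    if weight_dtype == "int8":
        return "WeiS8Fp32PerN<GcCompInt8>" if per_n else "WeiS8Fp32<GcCompInt8KBlock>"
    return None  # other dtypes are handled by the fp32/bf16 compute branches above
```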

intel_extension_for_transformers/llm/runtime/graph/models/model_utils/quant_config.h

Lines changed: 5 additions & 5 deletions
@@ -18,11 +18,11 @@
 #include "core/data_types.h"

 enum class quant_bits : int { q4 = 0, q8, count };
-static inline quant_bits parse_bits(int bits) {
-if (bits == 4) {
+static inline quant_bits parse_bits(const std::string& bits) {
+if (bits == "int4") {
 return quant_bits::q4;
 }
-if (bits == 8) {
+if (bits == "int8") {
 return quant_bits::q8;
 }
 return quant_bits::count;
@@ -88,15 +88,15 @@ static inline quant_comp parse_compute_type(std::string arg) {
 struct quant_params_internal {
 quant_bits bits = quant_bits::q4;
 quant_alg alg = quant_alg::sym;
-int32_t block_size = 32;
+int32_t group_size = 32;
 quant_sdtype scale_dtype = quant_sdtype::fp16;
 quant_comp compute_type = quant_comp::ggml;
 bool valid() const {
 return bits != quant_bits::count && alg != quant_alg::count && scale_dtype != quant_sdtype::count &&
 compute_type != quant_comp::count;
 }
 std::string getstr() {
-return std::to_string(int(bits)) + "_" + std::to_string(int(alg)) + "_" + std::to_string(block_size) + "_" +
+return std::to_string(int(bits)) + "_" + std::to_string(int(alg)) + "_" + std::to_string(group_size) + "_" +
 std::to_string(int(scale_dtype)) + "_" + std::to_string(int(compute_type));
 }
 };
