
Commit 0a8e007

SS-JIA authored and facebook-github-bot committed
Enable Vulkan 4-bit weight only quantization in export_llama (#6235)
Summary:
Pull Request resolved: #6235
As title.

ghstack-source-id: 248349849
exported-using-ghexport

Reviewed By: jorgep31415
Differential Revision: D64406456
fbshipit-source-id: 20890a7391f821b9b58063bef305909d34d48a18
1 parent 58ee33d commit 0a8e007

File tree

5 files changed: +9, -5 lines


backends/vulkan/_passes/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -62,6 +62,7 @@ runtime.python_library(
     ],
     visibility = [
         "//executorch/backends/...",
+        "//executorch/examples/...",
     ],
     deps = [
         ":int4_weight_only_quantizer",

examples/models/llama2/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ runtime.python_library(
     deps = [
         "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform",
         "//caffe2:torch",
+        "//executorch/backends/vulkan/_passes:vulkan_passes",
         "//executorch/examples/models:model_base",
         "//executorch/examples/models:models",
         "//executorch/extension/llm/custom_ops:custom_ops_aot_py",

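Taken together, the two TARGETS changes above make //executorch/backends/vulkan/_passes visible to //executorch/examples/... and add //executorch/backends/vulkan/_passes:vulkan_passes as a dependency of the llama2 example library, so the VkInt4WeightOnlyQuantizer import added to quantize.py below resolves at build time.
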
examples/models/llama2/export_llama_lib.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def build_args_parser() -> argparse.ArgumentParser:
         "--quantization_mode",
         type=str,
         default=None,
-        choices=["int8", "8da4w", "8da4w-gptq"],
+        choices=["int8", "8da4w", "8da4w-gptq", "vulkan_4w"],
         help="type of quantization",
     )

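With "vulkan_4w" accepted by --quantization_mode, an export can request Vulkan 4-bit weight-only quantization from the command line. A hypothetical invocation sketch: only --quantization_mode vulkan_4w comes from this diff; the entry point, the placeholder paths, and the other flags (including --vulkan to enable the Vulkan partitioner) are assumptions about the surrounding export_llama tooling.

python -m examples.models.llama2.export_llama \
    --checkpoint /path/to/consolidated.00.pth \
    --params /path/to/params.json \
    --vulkan \
    --quantization_mode vulkan_4w
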
examples/models/llama2/source_transformation/quantize.py

Lines changed: 6 additions & 1 deletion
@@ -12,6 +12,8 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer
+
 from executorch.extension.llm.export.builder import DType
 
 from sentencepiece import SentencePieceProcessor
@@ -31,7 +33,7 @@
 fsLinear = nn.Linear
 
 
-def quantize(
+def quantize( # noqa C901
     model: torch.nn.Module,
     qmode: str,
     activation_dtype: Optional[DType],
@@ -131,6 +133,9 @@ def quantize(
         )
         model = gptq_quantizer.quantize(model, inputs)
         return model
+    elif qmode == "vulkan_4w":
+        model = VkInt4WeightOnlyQuantizer().quantize(model)
+        return model
     else:
         raise Exception(f"Unrecognized quantize mode: {qmode}")

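The new branch applies the quantizer as a plain source transformation on the eager model before export. A minimal standalone sketch, assuming only the API visible in this diff (a no-argument constructor and a quantize(model) method); TinyMLP is a hypothetical toy module used for illustration.

import torch
import torch.nn as nn

from executorch.backends.vulkan._passes import VkInt4WeightOnlyQuantizer


class TinyMLP(nn.Module):
    # Hypothetical toy model; dimensions chosen to be divisible by
    # common 4-bit quantization group sizes.
    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(torch.relu(self.fc1(x)))


model = TinyMLP().eval()
# Swap eligible nn.Linear weights for 4-bit weight-only quantized
# versions that the Vulkan backend can consume.
model = VkInt4WeightOnlyQuantizer().quantize(model)
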
extension/llm/export/partitioner_lib.py

Lines changed: 0 additions & 3 deletions
@@ -37,9 +37,6 @@ def get_vulkan_partitioner(
     assert (
         dtype_override == "fp32" or dtype_override is None
     ), "Vulkan backend does not support non fp32 dtypes at the moment"
-    assert (
-        quantization_mode is None
-    ), "Vulkan backend does not support quantization at the moment"
     from executorch.backends.vulkan.partitioner.vulkan_partitioner import (
         VulkanPartitioner,
     )

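Removing the second assert means get_vulkan_partitioner no longer rejects exports that set a quantization mode, which is what allows --vulkan to be combined with --quantization_mode vulkan_4w. A sketch of the call under that assumption, using keyword arguments matching the parameter names referenced by the asserts (the full signature is not shown in this diff):

from executorch.extension.llm.export.partitioner_lib import get_vulkan_partitioner

# The remaining assert still restricts dtype_override to "fp32" or None;
# quantization_mode may now be set, e.g. to "vulkan_4w".
partitioner = get_vulkan_partitioner(
    dtype_override="fp32",
    quantization_mode="vulkan_4w",
)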