Skip to content

Commit 76ebbe5

Browse files
committed
Add MPS GGUF dequantization support
Add Metal kernel path for GGUF quantized models on MPS (Apple Metal). Implements dequant+matmul for Q4_0, Q8_0, and Q4_K types via the dequant_gguf kernel package, with a numpy-based fallback using the gguf Python library.

Changes:
- gguf.py: Add MPS branch in _fused_mul_mat_gguf and _apply_gguf_embedding to route through gguf_dequant_on_mps instead of CUDA ops
- gguf.py: Fix get_supported_act_dtypes and get_min_capability for MPS
- mps_dequant.py: Add GGUF section with Metal kernel import, numpy fallback, and gguf_dequant_on_mps entry point

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
Signed-off-by: Rob Taylor <rob.taylor@chipflow.io>
1 parent 2f8681e commit 76ebbe5

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

vllm/model_executor/layers/quantization/gguf.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,17 @@ def get_name(self) -> QuantizationMethods:
6262
def get_supported_act_dtypes(self) -> list[torch.dtype]:
6363
# GGUF dequantization kernels use half precision (fp16) internally.
6464
# bfloat16 has precision issues on Blackwell devices.
65+
if current_platform.is_mps():
66+
return [torch.half, torch.float32]
6567
if current_platform.has_device_capability(100):
6668
logger.warning_once("GGUF has precision issues with bfloat16 on Blackwell.")
6769
return [torch.half, torch.float32]
6870
return [torch.half, torch.bfloat16, torch.float32]
6971

7072
@classmethod
7173
def get_min_capability(cls) -> int:
74+
if current_platform.is_mps():
75+
return -1 # MPS has no CUDA compute capability
7276
return 60
7377

7478
@classmethod
@@ -188,17 +192,34 @@ def is_layer_skipped_gguf(
188192
def _fused_mul_mat_gguf(
189193
x: torch.Tensor, qweight: torch.Tensor, qweight_type: int
190194
) -> torch.Tensor:
191-
if qweight_type in IMATRIX_QUANT_TYPES:
192-
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
193-
else:
194-
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
195195
# HACK: when doing chunked prefill we don't generate output tokens
196196
# so input to logits generator is empty which causes invalid parameter
197197
if x.shape[0] == 0:
198198
return torch.empty(x.shape[0], qweight.shape[0], dtype=x.dtype, device=x.device)
199199
# there is no need to call any kernel for fp16/bf16
200200
if qweight_type in UNQUANTIZED_TYPES:
201201
return x @ qweight.T
202+
203+
# MPS path: dequantize then matmul (no fused CUDA kernels available)
204+
if current_platform.is_mps():
205+
if qweight_type in DEQUANT_TYPES:
206+
from vllm.model_executor.layers.quantization.utils.mps_dequant import (
207+
gguf_dequant_on_mps,
208+
)
209+
210+
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
211+
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
212+
weight = gguf_dequant_on_mps(qweight, qweight_type, *shape, x.dtype)
213+
return x @ weight.T
214+
qweight_type = WeightType(qweight_type)
215+
raise NotImplementedError(
216+
f"Unsupported GGUF quantization type on MPS: {qweight_type}"
217+
)
218+
219+
if qweight_type in IMATRIX_QUANT_TYPES:
220+
mmvq_safe = 8 if qweight.shape[0] > 5120 else 16
221+
else:
222+
mmvq_safe = 2 if qweight.shape[0] > 5120 else 6
202223
# enable MMVQ in contiguous batching with batch_size=1
203224
if x.shape[0] <= mmvq_safe and qweight_type in MMVQ_QUANT_TYPES:
204225
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
@@ -385,9 +406,18 @@ def _apply_gguf_embedding(
385406
x_flat = x.flatten()
386407
assert hidden_size == qweight.shape[1] // type_size * block_size
387408
quant = torch.index_select(qweight, dim=0, index=x_flat)
388-
dequant = ops.ggml_dequantize(
389-
quant, qweight_type, hidden_size, x_flat.shape[0], dtype
390-
)
409+
if current_platform.is_mps():
410+
from vllm.model_executor.layers.quantization.utils.mps_dequant import (
411+
gguf_dequant_on_mps,
412+
)
413+
414+
dequant = gguf_dequant_on_mps(
415+
quant, qweight_type, x_flat.shape[0], hidden_size, dtype
416+
)
417+
else:
418+
dequant = ops.ggml_dequantize(
419+
quant, qweight_type, hidden_size, x_flat.shape[0], dtype
420+
)
391421
return dequant.view(*x.shape, hidden_size)
392422
else:
393423
qweight_type = WeightType(qweight_type)

vllm/model_executor/layers/quantization/utils/mps_dequant.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3-
"""MPS (Metal) dequantization utilities for AWQ and GPTQ int4 models.
3+
"""MPS (Metal) dequantization utilities for AWQ, GPTQ, and GGUF models.
44
5-
Uses the dequant_int4 Metal kernel package when available, with a pure
6-
PyTorch fallback for environments where the kernel isn't installed.
5+
Uses Metal kernel packages when available, with pure PyTorch/numpy
6+
fallbacks for environments where the kernels aren't installed.
77
"""
88

99
from typing import Any
@@ -17,6 +17,10 @@
1717
_metal_dequant = None
1818
_metal_import_attempted = False
1919

20+
# Metal kernel types: Q4_0=2, Q4_1=3, Q5_0=6, Q5_1=7, Q8_0=8,
21+
# Q2_K=10, Q3_K=11, Q4_K=12, Q5_K=13, Q6_K=14
22+
_METAL_GGUF_TYPES = {2, 3, 6, 7, 8, 10, 11, 12, 13, 14}
23+
2024

2125
def _get_metal_dequant():
2226
"""Try to import Metal dequant kernel package (cached)."""
@@ -223,3 +227,66 @@ def gptq_dequant_matmul(
223227
if bias is not None:
224228
out.add_(bias)
225229
return out.reshape(out_shape)
230+
231+
232+
# ── GGUF ──
233+
234+
_metal_dequant_gguf = None
235+
_metal_gguf_import_attempted = False
236+
237+
238+
def _get_metal_dequant_gguf():
239+
"""Try to import Metal dequant_gguf kernel package (cached)."""
240+
global _metal_dequant_gguf, _metal_gguf_import_attempted
241+
if not _metal_gguf_import_attempted:
242+
_metal_gguf_import_attempted = True
243+
try:
244+
import dequant_gguf
245+
246+
_metal_dequant_gguf = dequant_gguf
247+
logger.info("Using Metal dequant_gguf kernel for GGUF dequantization")
248+
except ImportError:
249+
logger.info(
250+
"dequant_gguf Metal kernel not found, "
251+
"falling back to numpy-based GGUF dequantization"
252+
)
253+
return _metal_dequant_gguf
254+
255+
256+
def _pytorch_dequant_gguf(
257+
W: torch.Tensor,
258+
quant_type: int,
259+
m: int,
260+
n: int,
261+
dtype: torch.dtype | None = None,
262+
) -> torch.Tensor:
263+
"""Fallback GGUF dequantization using the gguf Python library.
264+
265+
This does a GPU→CPU→GPU round-trip via numpy, so it's slow but correct.
266+
"""
267+
import numpy as np
268+
from gguf import GGMLQuantizationType, dequantize
269+
270+
qt = GGMLQuantizationType(quant_type)
271+
w_np = W.cpu().numpy().view(np.uint8)
272+
result = dequantize(w_np, qt)
273+
out_dtype = dtype if dtype is not None else torch.float16
274+
return torch.tensor(result, dtype=out_dtype, device=W.device).reshape(m, n)
275+
276+
277+
def gguf_dequant_on_mps(
278+
W: torch.Tensor,
279+
quant_type: int,
280+
m: int,
281+
n: int,
282+
dtype: torch.dtype | None = None,
283+
) -> torch.Tensor:
284+
"""Dequantize GGUF weights on MPS.
285+
286+
Uses Metal kernel if available for all standard GGUF types,
287+
falls back to gguf library (numpy) for unsupported types (IQ*).
288+
"""
289+
metal = _get_metal_dequant_gguf()
290+
if metal is not None and quant_type in _METAL_GGUF_TYPES:
291+
return metal.dequantize_gguf(W, quant_type, m, n, dtype)
292+
return _pytorch_dequant_gguf(W, quant_type, m, n, dtype)

0 commit comments

Comments (0)