Skip to content

Commit f26fcdf

Browse files
authored
[Bugfix][ROCm] Fix lru_cache on paged_mqa_logits_module (vllm-project#37547)
Signed-off-by: Stig-Arne Grönroos <stig-arne.gronroos@amd.com>
1 parent bc9c6fb commit f26fcdf

File tree

1 file changed

+39
-38
lines changed

1 file changed

+39
-38
lines changed

vllm/v1/attention/ops/rocm_aiter_mla_sparse.py

Lines changed: 39 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch(
273273
return logits
274274

275275

276+
@functools.lru_cache
def paged_mqa_logits_module():
    """Locate and import the aiter Triton paged-MQA-logits kernel module.

    The kernel moved between aiter releases, so both the old and the new
    dotted path are probed in order. The function takes no arguments, so
    ``lru_cache`` holds a single cached result.

    Returns:
        The imported module, or ``None`` when no candidate module is
        available (aiter missing, or the kernel not present in the
        installed aiter version).
    """
    # Candidate locations, current layout first.
    candidates = (
        "aiter.ops.triton.pa_mqa_logits",
        "aiter.ops.triton.attention.pa_mqa_logits",
    )
    for path in candidates:
        try:
            # find_spec imports parent packages, so it *raises*
            # ModuleNotFoundError (rather than returning None) when e.g.
            # ``aiter`` itself is not installed — treat that as "not found".
            if importlib.util.find_spec(path) is None:
                continue
        except ModuleNotFoundError:
            continue
        try:
            return importlib.import_module(path)
        except ImportError:
            # Spec existed but the import itself failed; give up rather
            # than fall through to an older layout with a broken install.
            return None
    return None
293+
294+
276295
def rocm_fp8_paged_mqa_logits(
277296
q_fp8: torch.Tensor,
278297
kv_cache_fp8: torch.Tensor,
@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits(
305324
"""
306325
from vllm._aiter_ops import rocm_aiter_ops
307326

308-
@functools.lru_cache
309-
def paged_mqa_logits_module():
310-
paged_mqa_logits_module_path = None
311-
if importlib.util.find_spec("aiter.ops.triton.pa_mqa_logits") is not None:
312-
paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
313-
elif (
314-
importlib.util.find_spec("aiter.ops.triton.attention.pa_mqa_logits")
315-
is not None
316-
):
317-
paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
318-
319-
if paged_mqa_logits_module_path is not None:
320-
try:
321-
module = importlib.import_module(paged_mqa_logits_module_path)
322-
return module
323-
except ImportError:
324-
return None
325-
return None
326-
327327
aiter_paged_mqa_logits_module = None
328328
if rocm_aiter_ops.is_enabled():
329329
aiter_paged_mqa_logits_module = paged_mqa_logits_module()
@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch(
400400
return logits
401401

402402

403+
@functools.lru_cache
def mqa_logits_module():
    """Locate and import the aiter Triton fp8-MQA-logits kernel module.

    The kernel moved between aiter releases, so both the old and the new
    dotted path are probed in order. The function takes no arguments, so
    ``lru_cache`` holds a single cached result.

    Returns:
        The imported module, or ``None`` when no candidate module is
        available (aiter missing, or the kernel not present in the
        installed aiter version).
    """
    # Candidate locations, current layout first.
    candidates = (
        "aiter.ops.triton.fp8_mqa_logits",
        "aiter.ops.triton.attention.fp8_mqa_logits",
    )
    for path in candidates:
        try:
            # find_spec imports parent packages, so it *raises*
            # ModuleNotFoundError (rather than returning None) when e.g.
            # ``aiter`` itself is not installed — treat that as "not found".
            if importlib.util.find_spec(path) is None:
                continue
        except ModuleNotFoundError:
            continue
        try:
            return importlib.import_module(path)
        except ImportError:
            # Spec existed but the import itself failed; give up rather
            # than fall through to an older layout with a broken install.
            return None
    return None
421+
422+
403423
def rocm_fp8_mqa_logits(
404424
q: torch.Tensor,
405425
kv: tuple[torch.Tensor, torch.Tensor],
@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits(
429449
# path after aiter merge this kernel into main
430450
from vllm._aiter_ops import rocm_aiter_ops
431451

432-
@functools.lru_cache
433-
def mqa_logits_module():
434-
mqa_logits_module_path = None
435-
if importlib.util.find_spec("aiter.ops.triton.fp8_mqa_logits") is not None:
436-
mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
437-
elif (
438-
importlib.util.find_spec("aiter.ops.triton.attention.fp8_mqa_logits")
439-
is not None
440-
):
441-
mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
442-
443-
if mqa_logits_module_path is not None:
444-
try:
445-
module = importlib.import_module(mqa_logits_module_path)
446-
return module
447-
except ImportError:
448-
return None
449-
return None
450-
451452
aiter_mqa_logits_module = None
452453
if rocm_aiter_ops.is_enabled():
453454
aiter_mqa_logits_module = mqa_logits_module()

0 commit comments

Comments (0)