@@ -273,6 +273,25 @@ def fp8_paged_mqa_logits_torch(
273273 return logits
274274
275275
276+ @functools .lru_cache
277+ def paged_mqa_logits_module ():
278+ paged_mqa_logits_module_path = None
279+ if importlib .util .find_spec ("aiter.ops.triton.pa_mqa_logits" ) is not None :
280+ paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
281+ elif (
282+ importlib .util .find_spec ("aiter.ops.triton.attention.pa_mqa_logits" ) is not None
283+ ):
284+ paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
285+
286+ if paged_mqa_logits_module_path is not None :
287+ try :
288+ module = importlib .import_module (paged_mqa_logits_module_path )
289+ return module
290+ except ImportError :
291+ return None
292+ return None
293+
294+
276295def rocm_fp8_paged_mqa_logits (
277296 q_fp8 : torch .Tensor ,
278297 kv_cache_fp8 : torch .Tensor ,
@@ -305,25 +324,6 @@ def rocm_fp8_paged_mqa_logits(
305324 """
306325 from vllm ._aiter_ops import rocm_aiter_ops
307326
308- @functools .lru_cache
309- def paged_mqa_logits_module ():
310- paged_mqa_logits_module_path = None
311- if importlib .util .find_spec ("aiter.ops.triton.pa_mqa_logits" ) is not None :
312- paged_mqa_logits_module_path = "aiter.ops.triton.pa_mqa_logits"
313- elif (
314- importlib .util .find_spec ("aiter.ops.triton.attention.pa_mqa_logits" )
315- is not None
316- ):
317- paged_mqa_logits_module_path = "aiter.ops.triton.attention.pa_mqa_logits"
318-
319- if paged_mqa_logits_module_path is not None :
320- try :
321- module = importlib .import_module (paged_mqa_logits_module_path )
322- return module
323- except ImportError :
324- return None
325- return None
326-
327327 aiter_paged_mqa_logits_module = None
328328 if rocm_aiter_ops .is_enabled ():
329329 aiter_paged_mqa_logits_module = paged_mqa_logits_module ()
@@ -400,6 +400,26 @@ def fp8_mqa_logits_torch(
400400 return logits
401401
402402
403+ @functools .lru_cache
404+ def mqa_logits_module ():
405+ mqa_logits_module_path = None
406+ if importlib .util .find_spec ("aiter.ops.triton.fp8_mqa_logits" ) is not None :
407+ mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
408+ elif (
409+ importlib .util .find_spec ("aiter.ops.triton.attention.fp8_mqa_logits" )
410+ is not None
411+ ):
412+ mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
413+
414+ if mqa_logits_module_path is not None :
415+ try :
416+ module = importlib .import_module (mqa_logits_module_path )
417+ return module
418+ except ImportError :
419+ return None
420+ return None
421+
422+
403423def rocm_fp8_mqa_logits (
404424 q : torch .Tensor ,
405425 kv : tuple [torch .Tensor , torch .Tensor ],
@@ -429,25 +449,6 @@ def rocm_fp8_mqa_logits(
429449 # path after aiter merge this kernel into main
430450 from vllm ._aiter_ops import rocm_aiter_ops
431451
432- @functools .lru_cache
433- def mqa_logits_module ():
434- mqa_logits_module_path = None
435- if importlib .util .find_spec ("aiter.ops.triton.fp8_mqa_logits" ) is not None :
436- mqa_logits_module_path = "aiter.ops.triton.fp8_mqa_logits"
437- elif (
438- importlib .util .find_spec ("aiter.ops.triton.attention.fp8_mqa_logits" )
439- is not None
440- ):
441- mqa_logits_module_path = "aiter.ops.triton.attention.fp8_mqa_logits"
442-
443- if mqa_logits_module_path is not None :
444- try :
445- module = importlib .import_module (mqa_logits_module_path )
446- return module
447- except ImportError :
448- return None
449- return None
450-
451452 aiter_mqa_logits_module = None
452453 if rocm_aiter_ops .is_enabled ():
453454 aiter_mqa_logits_module = mqa_logits_module ()
0 commit comments