[AMD][Kernel][BugFix] Use correct scale in concat_and_cache_ds_mla_kernel when on gfx942 #32976
Purpose
This PR updates `concat_and_cache_ds_mla_kernel` to use the correct scale divisor when running on `gfx942` architectures on ROCm. The `tile_size` divisor was `448.0`, which does not work on AMD platforms with the `gfx942` arch, e.g. MI300 and MI325. This PR changes the divisor to `224.0` on `gfx942`. Additionally, I consolidated the scale divisor into a `constexpr float`.
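For illustration, below is a minimal sketch of what the consolidated constant could look like. It assumes a compile-time architecture check via the `__gfx942__` macro that the AMDGPU compiler defines for that target; the constant name (`FP8_SCALE_DIVISOR`) and the exact gating used in the actual patch may differ.

```cpp
// Sketch only: consolidate the FP8 scale divisor into a single constexpr,
// selected per target architecture. The name FP8_SCALE_DIVISOR is
// illustrative, not necessarily the one used in the PR.
#if defined(__HIP_PLATFORM_AMD__) && defined(__gfx942__)
// gfx942 (MI300/MI325) uses a different FP8 E4M3 variant than the OCP
// format, so the 448.0 divisor is not valid there; the PR uses 224.0.
constexpr float FP8_SCALE_DIVISOR = 224.0f;
#else
constexpr float FP8_SCALE_DIVISOR = 448.0f;
#endif

// Inside concat_and_cache_ds_mla_kernel, the per-tile scale would then be
// derived against this constant rather than a hard-coded literal, e.g.:
//   float scale = amax / FP8_SCALE_DIVISOR;
```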
Test Plan
Use `lm_eval` to check accuracy on the `DeepSeek-R1` model. Command to run:

`lm_eval --model vllm --model_args pretrained=/models/DeepSeek-R1,max_length=8192,tensor_parallel_size=8 --batch_size auto --tasks gsm8k --num_fewshot 5`

Test Result
I did a comparison using `lm_eval` with and without this PR. I ran:

`lm_eval --model vllm --model_args pretrained=/models/DeepSeek-R1,max_length=8192,tensor_parallel_size=8 --batch_size auto --tasks gsm8k --num_fewshot 5`

on MI300 and got the following results:
with this PR:
without this PR: