[CUDA] Enable full cudagraph for FlashMLA #18581
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Could you give some details on the speedup associated with this modification?
This pull request has merge conflicts that must be resolved before it can be merged.
vllm/v1/worker/gpu_model_runner.py (Outdated)
direct_call = has_prefill(attn_metadata) and self.full_cuda_graph
if direct_call:
    # Skip the outer model layer as inner model is compiled
    model_output = self.model.model.forward(
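For reference, a minimal sketch of what the has_prefill check above could look like; the num_prefill_tokens field is an assumption for illustration, not necessarily what the PR's attention metadata actually exposes:

```python
def has_prefill(attn_metadata) -> bool:
    # Hypothetical helper: the batch contains prefill work whenever the
    # attention metadata reports a non-zero number of prefill tokens.
    return attn_metadata is not None and getattr(
        attn_metadata, "num_prefill_tokens", 0) > 0
```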
I'm thinking about setting a context to bypass the Inductor backend directly here, but I guess we could also separately compile (and capture) the prefill stage - any thoughts?
cc @youkaichao
This is hacky and depends on the fact that self.model.model is the underlying nn.Module we compile (which might not be the case if we are running multi-modality models).

I think we can have some fields in the forward context, like not_use_cudagraph, and set it to True only when we capture full cudagraph but get prefill data in the current batch. Then, when we decide to replay the cudagraph, we can check this field.
It should just be changing this line:
    entry.cudagraph.replay()
to fall back to
    return entry.runnable(*args)
when that field is set.
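A minimal sketch of that suggestion, using hypothetical names (ForwardContext, not_use_cudagraph, and CUDAGraphEntry are illustrative stand-ins, not vLLM's actual forward-context or cudagraph-wrapper code):

```python
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class ForwardContext:
    # Set to True when full cudagraph is enabled but the current batch
    # contains prefill tokens, so the captured graph must not be replayed.
    not_use_cudagraph: bool = False


_forward_context = ForwardContext()


@contextmanager
def set_forward_context(not_use_cudagraph: bool):
    # Toggle the flag for the duration of one model invocation.
    prev = _forward_context.not_use_cudagraph
    _forward_context.not_use_cudagraph = not_use_cudagraph
    try:
        yield
    finally:
        _forward_context.not_use_cudagraph = prev


@dataclass
class CUDAGraphEntry:
    runnable: Callable[..., Any]      # eager/compiled callable to fall back to
    cudagraph: Optional[Any] = None   # captured torch.cuda.CUDAGraph
    output: Any = None                # static output buffer filled by replay


def call_or_replay(entry: CUDAGraphEntry, *args):
    # The one-line change suggested above: consult the flag before replaying.
    if _forward_context.not_use_cudagraph or entry.cudagraph is None:
        return entry.runnable(*args)
    entry.cudagraph.replay()
    return entry.output
```

The model runner would then wrap mixed prefill/decode batches in set_forward_context(not_use_cudagraph=True) instead of reaching into self.model.model directly.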
I haven't profiled this specifically, but it's meant to enable the double-batch-overlap optimization (prototype in #18415).
Force-pushed from d5c7a35 to c794889
Force-pushed from c794889 to 80f20ce
Force-pushed from 976e852 to 40e7248
This pull request has merge conflicts that must be resolved before it can be merged.
Hi, any further progress on this PR?
Signed-off-by: luka <[email protected]>
Signed-off-by: luka <[email protected]>
Almost ready for review!
Force-pushed from 5e3f7ab to 30562a2
Signed-off-by: luka <[email protected]>
Force-pushed from 30562a2 to 9ea599d
Currently experiencing some issues when batching (in the unit test); need to investigate further.
This pull request has merge conflicts that must be resolved before it can be merged.
Enable full cudagraph capture for the FlashMLA decode case.
Hacks:
Tested with:
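For reference, a minimal usage sketch; it assumes the pre-existing CompilationConfig.full_cuda_graph flag is what gates this path and uses a DeepSeek MLA model purely as an example, not the PR's actual test setup:

```python
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

# full_cuda_graph=True requests capture of the whole forward pass
# (including attention) as a single CUDA graph for decode batches.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",  # example MLA model, not from the PR
    compilation_config=CompilationConfig(full_cuda_graph=True),
)

out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```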