Skip to content

Commit 40e7248

Browse files
committed
Use build_for_cudagraph_capture for metadata
Signed-off-by: luka <[email protected]>
1 parent 9ebc984 commit 40e7248

File tree

3 files changed

+29
-31
lines changed

3 files changed

+29
-31
lines changed

vllm/v1/attention/backends/flash_attn.py

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -322,9 +322,16 @@ def reorder_batch(self, input_batch: "InputBatch",
322322
scheduler_output: "SchedulerOutput") -> bool:
323323
return False
324324

325-
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
326-
common_prefix_len: int,
327-
common_attn_metadata: CommonAttentionMetadata):
325+
def build_for_cudagraph_capture(
326+
self, num_tokens: int, common_attn_metadata: CommonAttentionMetadata
327+
) -> FlashAttentionMetadata:
328+
return self.build(num_tokens, num_tokens, num_tokens, 0,
329+
common_attn_metadata)
330+
331+
def build(
332+
self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
333+
common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata
334+
) -> FlashAttentionMetadata:
328335
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
329336
query_start_loc = common_attn_metadata.query_start_loc
330337
seq_lens = common_attn_metadata.seq_lens

vllm/v1/attention/backends/mla/common.py

Lines changed: 8 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -449,19 +449,18 @@ def _build_decode(self, block_table_tensor: torch.Tensor,
449449
seq_lens=seq_lens,
450450
)
451451

452-
# TODO maybe use this?
453452
def build_for_cudagraph_capture(
454-
self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
455-
common_prefix_len: int,
453+
self, num_tokens: int,
456454
common_attn_metadata: CommonAttentionMetadata) -> M:
457-
# decode-only cudagraph capture
458-
assert num_reqs == num_actual_tokens
459-
self._num_decodes = num_reqs
460-
self._num_decode_tokens = num_reqs
455+
"""
456+
This method builds the metadata for full cudagraph capture.
457+
Currently, only decode is supported for full cudagraphs with MLA.
458+
"""
459+
self._num_decodes = num_tokens
460+
self._num_decode_tokens = num_tokens
461461
self._num_prefills = 0
462462
self._num_prefill_tokens = 0
463-
return self.build(num_reqs, num_actual_tokens, max_query_len,
464-
common_prefix_len, common_attn_metadata)
463+
return self.build(num_tokens, num_tokens, 1, 0, common_attn_metadata)
465464

466465
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
467466
common_prefix_len: int,

vllm/v1/worker/gpu_model_runner.py

Lines changed: 11 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
from vllm.attention import AttentionType, get_attn_backend
1515
from vllm.attention.backends.abstract import (AttentionBackend,
16-
AttentionMetadataBuilder, AttentionMetadata)
16+
AttentionMetadataBuilder)
1717
from vllm.attention.layer import Attention
1818
from vllm.attention.utils.fa_utils import get_flash_attn_version
1919
from vllm.config import (CompilationLevel, VllmConfig,
@@ -191,8 +191,8 @@ def __init__(
191191
# The convention is different.
192192
# self.cudagraph_batch_sizes sorts in ascending order.
193193
# The batch sizes in the config are in descending order.
194-
self.cudagraph_batch_sizes = list(reversed(
195-
self.compilation_config.cudagraph_capture_sizes))
194+
self.cudagraph_batch_sizes = list(
195+
reversed(self.compilation_config.cudagraph_capture_sizes))
196196

197197
self.full_cuda_graph = self.compilation_config.full_cuda_graph
198198

@@ -1726,19 +1726,12 @@ def _dummy_run(
17261726
attn_metadata = {}
17271727
for kv_cache_group_id, kv_cache_group_spec in enumerate(
17281728
self.kv_cache_config.kv_cache_groups):
1729-
# hack for flashMLA state
1730-
self.attn_metadata_builders[kv_cache_group_id]._num_decodes = num_tokens
1731-
self.attn_metadata_builders[kv_cache_group_id]._num_decode_tokens = num_tokens
1732-
self.attn_metadata_builders[kv_cache_group_id]._num_prefills = 0
1733-
1734-
attn_metadata_i = (
1735-
self.attn_metadata_builders[kv_cache_group_id].build(
1736-
num_reqs=num_tokens,
1737-
num_actual_tokens=num_tokens,
1738-
max_query_len=num_tokens,
1739-
common_prefix_len=0,
1729+
1730+
attn_metadata_i = self.attn_metadata_builders[
1731+
kv_cache_group_id].build_for_cudagraph_capture(
1732+
num_tokens=num_tokens,
17401733
common_attn_metadata=common_attn_metadata,
1741-
))
1734+
)
17421735
for layer_name in kv_cache_group_spec.layer_names:
17431736
attn_metadata[layer_name] = attn_metadata_i
17441737

@@ -2095,10 +2088,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
20952088
# group
20962089
self.drafter.validate_same_kv_cache_group(kv_cache_config)
20972090

2098-
bind_kv_cache(
2099-
kv_caches,
2100-
self.compilation_config.static_forward_context,
2101-
self.kv_caches)
2091+
bind_kv_cache(kv_caches,
2092+
self.compilation_config.static_forward_context,
2093+
self.kv_caches)
21022094

21032095
if has_kv_transfer_group():
21042096
get_kv_transfer_group().register_kv_caches(kv_caches)

0 commit comments

Comments
 (0)