 
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.backends.abstract import (AttentionBackend,
-                                              AttentionMetadataBuilder, AttentionMetadata)
+                                              AttentionMetadataBuilder)
 from vllm.attention.layer import Attention
 from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.config import (CompilationLevel, VllmConfig,
@@ -191,8 +191,8 @@ def __init__(
         # The convention is different.
         # self.cudagraph_batch_sizes sorts in ascending order.
         # The batch sizes in the config are in descending order.
-        self.cudagraph_batch_sizes = list(reversed(
-            self.compilation_config.cudagraph_capture_sizes))
+        self.cudagraph_batch_sizes = list(
+            reversed(self.compilation_config.cudagraph_capture_sizes))
 
         self.full_cuda_graph = self.compilation_config.full_cuda_graph
 
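For illustration, the ordering convention the comment block describes, with made-up capture sizes rather than values from any real config:

    # Hypothetical capture sizes as they appear in the config (descending).
    cudagraph_capture_sizes = [512, 256, 128, 64]

    # The runner keeps them in ascending order.
    cudagraph_batch_sizes = list(reversed(cudagraph_capture_sizes))
    assert cudagraph_batch_sizes == [64, 128, 256, 512]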
@@ -1726,19 +1726,12 @@ def _dummy_run(
         attn_metadata = {}
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
-            # hack for flashMLA state
-            self.attn_metadata_builders[kv_cache_group_id]._num_decodes = num_tokens
-            self.attn_metadata_builders[kv_cache_group_id]._num_decode_tokens = num_tokens
-            self.attn_metadata_builders[kv_cache_group_id]._num_prefills = 0
-
-            attn_metadata_i = (
-                self.attn_metadata_builders[kv_cache_group_id].build(
-                    num_reqs=num_tokens,
-                    num_actual_tokens=num_tokens,
-                    max_query_len=num_tokens,
-                    common_prefix_len=0,
+
+            attn_metadata_i = self.attn_metadata_builders[
+                kv_cache_group_id].build_for_cudagraph_capture(
+                    num_tokens=num_tokens,
                     common_attn_metadata=common_attn_metadata,
-                ))
+                )
             for layer_name in kv_cache_group_spec.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
 
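The removed lines poked decode-only counts into the builder's private state before calling build(); the new build_for_cudagraph_capture() hook moves that setup into the builder itself. A minimal sketch of how such a hook could look, assuming a builder with the private fields and build() signature seen in the removed lines (the class below is illustrative, not this PR's implementation):

    class MLAMetadataBuilderSketch:
        # Stand-in for a real AttentionMetadataBuilder subclass.
        def build(self, num_reqs, num_actual_tokens, max_query_len,
                  common_prefix_len, common_attn_metadata):
            # The real build() assembles backend-specific attention metadata.
            return (num_reqs, num_actual_tokens, common_attn_metadata)

        def build_for_cudagraph_capture(self, num_tokens,
                                        common_attn_metadata):
            # A capture batch is treated as pure decode: one token per
            # request, no prefills (what the removed hack set by hand).
            self._num_decodes = num_tokens
            self._num_decode_tokens = num_tokens
            self._num_prefills = 0
            # Delegate to the ordinary build() path with the same dummy
            # shape the removed code passed explicitly.
            return self.build(num_reqs=num_tokens,
                              num_actual_tokens=num_tokens,
                              max_query_len=num_tokens,
                              common_prefix_len=0,
                              common_attn_metadata=common_attn_metadata)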
@@ -2095,10 +2088,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         # group
         self.drafter.validate_same_kv_cache_group(kv_cache_config)
 
-        bind_kv_cache(
-            kv_caches,
-            self.compilation_config.static_forward_context,
-            self.kv_caches)
+        bind_kv_cache(kv_caches,
+                      self.compilation_config.static_forward_context,
+                      self.kv_caches)
 
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(kv_caches)
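Judging from its arguments, bind_kv_cache() attaches the newly allocated KV caches both to the attention layers registered in the static forward context and to the runner's flat kv_caches list. A rough sketch of a helper with this call shape, assuming forward_context maps layer names to modules with a kv_cache attribute (the helper name and body are assumptions, not vLLM's actual code):

    from typing import Any

    def bind_kv_cache_sketch(kv_caches: dict[str, Any],
                             forward_context: dict[str, Any],
                             runner_kv_caches: list[Any]) -> None:
        # Attach each layer's KV-cache tensor to the module registered in
        # the forward context, and keep a flat copy of the references so
        # the runner and the layers share the same storage.
        for layer_name, kv_cache in kv_caches.items():
            forward_context[layer_name].kv_cache = [kv_cache]
            runner_kv_caches.append(kv_cache)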