Commit c3d9a4d

Add a multimodal model profiling step and fix the OOM bug that occurred when profiling the multimodal model.
Signed-off-by: root <[email protected]>
1 parent 20767a0

File tree

vllm_ascend/models/qwen2_5_vl.py
vllm_ascend/worker/model_runner_v1.py

2 files changed: +71 −7 lines

vllm_ascend/models/qwen2_5_vl.py

Lines changed: 1 addition & 0 deletions
@@ -383,6 +383,7 @@ def get_window_index(self, grid_thw):
         window_index = torch.cat(window_index, dim=0)
         return window_index, cu_window_seqlens

+    @torch.inference_mode()
     def forward(
         self,
         x: torch.Tensor,
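
Why this helps with the OOM: torch.inference_mode() disables autograd tracking for the vision tower's forward pass, so no activation graph is retained while the encoder is profiled. A minimal standalone sketch of the effect in plain PyTorch (TinyEncoder is a made-up module standing in for the Qwen2.5-VL vision tower, not the repository's class):

import torch

# Minimal sketch, assuming a made-up module; not the repository's code.
class TinyEncoder(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    @torch.inference_mode()
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inside inference_mode no autograd graph or version counters are
        # recorded, so intermediate activations are not kept alive.
        return torch.relu(self.proj(x))

encoder = TinyEncoder()
out = encoder(torch.randn(8, 64))
assert not out.requires_grad  # outputs are plain inference tensors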

vllm_ascend/worker/model_runner_v1.py

Lines changed: 70 additions & 7 deletions
@@ -50,7 +50,7 @@
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv)
+                        LayerBlockType, LazyLoader, cdiv, is_pin_memory_available)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -143,7 +143,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         else:
             self.chunked_prefill_enabled = True
         self.device = device
-
+        self.pin_memory = is_pin_memory_available()
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.block_size = vllm_config.cache_config.block_size
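
The new pin_memory flag is consumed further down, where the dummy multimodal batch is built with MultiModalKwargs.batch(..., pin_memory=self.pin_memory). Pinned (page-locked) host memory is what allows the subsequent host-to-device copy to run asynchronously. A rough sketch of the underlying PyTorch pattern, assuming a CUDA-style device purely for illustration (the Ascend backend uses its own device type and helper):

import torch

# Sketch of the general pattern only; in vLLM, is_pin_memory_available()
# decides whether page-locked host allocations are worth requesting.
pin = torch.cuda.is_available()  # illustrative stand-in for that check

# Dummy "image batch" allocated in pinned host memory when possible.
host_batch = torch.randn(4, 3, 224, 224, pin_memory=pin)

if pin:
    # Copies from pinned memory can be issued non_blocking and overlap with
    # other device work; copies from pageable memory would synchronize.
    device_batch = host_batch.to("cuda", non_blocking=True)
else:
    device_batch = host_batch  # CPU-only fallback, just for the example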

@@ -1787,11 +1787,74 @@ def _dummy_run(
         return hidden_states

     def profile_run(self) -> None:
-        # FIXME Profile with multimodal encoder & encoder cache.
-        # current _profile_multimodal() using PyTorch SDPA backend method not
-        # support for window/full attn to reduce Memcpy operations, so will cause
-        # Out Of Memory problem, so we currently don't use self._profile_multimodal()
-        # self._profile_multimodal()
+        # Profile with multimodal encoder & encoder cache.
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):
+
+            # NOTE: Currently model is profiled with a single non-text
+            # modality with the max possible input tokens even when
+            # it supports multiple.
+            max_tokens_by_modality_dict = self.mm_registry \
+                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
+            dummy_data_modality, max_tokens_per_mm_item = max(
+                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
+
+            # Check how many items of this modality can be supported by
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)
+
+            # Check how many items of this modality can be supported by
+            # the decoder budget.
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]
+
+            # NOTE: We do not consider max_num_batched_tokens on purpose
+            # because the multimodal embeddings can be generated in advance
+            # and chunked prefilled.
+            max_num_mm_items_decoder_budget = self.max_num_reqs * \
+                max_mm_items_per_req
+
+            max_num_mm_items = min(max_num_mm_items_encoder_budget,
+                                   max_num_mm_items_decoder_budget)
+
+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_counts={
+                    dummy_data_modality: 1
+                },
+            ).multi_modal_data
+
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs] * max_num_mm_items,
+                pin_memory=self.pin_memory)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs,
+                device=self.device,
+            )
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+
+            sanity_check_mm_encoder_outputs(
+                dummy_encoder_outputs,
+                expected_num_items=max_num_mm_items,
+            )
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))

         # For profile, have maximum num_reqs and that collectively have
         # maximum num_tokens.
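
For reference, the item count used for the dummy encoder run is just the minimum of the encoder-side and decoder-side budgets. A worked example of that arithmetic with made-up numbers (none of these values come from a real config):

from math import ceil

# Illustrative numbers only.
max_num_encoder_input_tokens = 16384
encoder_cache_size = 8192
max_tokens_per_mm_item = 1296     # assumed cost of one item of the chosen modality
max_num_reqs = 64
max_mm_items_per_req = 2

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)  # 8192
items_by_encoder = ceil(encoder_budget / max_tokens_per_mm_item)        # cdiv -> 7
items_by_decoder = max_num_reqs * max_mm_items_per_req                  # 128
max_num_mm_items = min(items_by_encoder, items_by_decoder)              # 7

print(encoder_budget, max_num_mm_items)  # profiling would batch 7 dummy items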
