
Commit 292d018

Add a multimodal model profiling step and fix the OOM bug that occurs when profiling multimodal models.
Signed-off-by: root <[email protected]>
1 parent 5f5800b commit 292d018

File tree

3 files changed: +74 -96 lines

.github/workflows/accuracy_test.yaml

Lines changed: 1 addition & 4 deletions
@@ -115,10 +115,6 @@ jobs:
             contains(github.event.pull_request.labels.*.name, 'vl-accuracy-test') &&
             '["Qwen/Qwen2.5-VL-7B-Instruct"]'
           ) }}
-        # Remove exclude after https://github.com/vllm-project/vllm-ascend/issues/1044 resolved
-        exclude:
-          - model_name: Qwen/Qwen2.5-VL-7B-Instruct
-            vllm_use_version: 1
 
       fail-fast: false
     name: ${{ matrix.model_name }} accuracy V${{ matrix.vllm_use_version }}
@@ -225,6 +221,7 @@ jobs:
       env:
         PYTORCH_NPU_ALLOC_CONF: max_split_size_mb:256
         VLLM_USE_V1: ${{ matrix.vllm_use_version }}
+        VLLM_LOGGING_LEVEL: DEBUG
       run: |
         model_base_name=$(basename ${{ matrix.model_name }})
         markdown_name="${model_base_name}-V${{ matrix.vllm_use_version }}"
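For reference, both variables in the env block above are standard vLLM environment variables; a minimal sketch of what the job exports to the test process (values are illustrative, taken from the matrix and the newly added line):

import os

# VLLM_USE_V1 selects the V1 engine path and is filled from matrix.vllm_use_version;
# VLLM_LOGGING_LEVEL=DEBUG is the new line and makes the CI logs verbose enough
# to diagnose accuracy-test failures.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"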

benchmarks/scripts/run_accuracy.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def run_accuracy_unimodal(queue, model, dataset):
 
 def run_accuracy_multimodal(queue, model, dataset):
     try:
-        model_args = f"pretrained={model},max_model_len=8192,dtype=auto,tensor_parallel_size=4,max_images=2"
+        model_args = f"pretrained={model},max_model_len=8192,dtype=auto,tensor_parallel_size=4,max_images=2,gpu_memory_utilization=0.6"
         results = lm_eval.simple_evaluate(
             model="vllm-vlm",
             model_args=model_args,
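The added gpu_memory_utilization=0.6 is the evaluation-side OOM fix: lm-eval's vllm-vlm backend forwards the comma-separated model_args to the vLLM engine, and 0.6 caps how much device memory vLLM pre-allocates. A minimal sketch of the roughly equivalent direct engine construction (an illustration, not the harness's actual code; the limit_mm_per_prompt mapping for max_images=2 is an assumption):

from vllm import LLM

# gpu_memory_utilization=0.6 limits the fraction of device memory reserved for
# weights and KV cache, leaving headroom for the multimodal encoder.
llm = LLM(
    model="Qwen/Qwen2.5-VL-7B-Instruct",  # example model from the CI matrix
    max_model_len=8192,
    dtype="auto",
    tensor_parallel_size=4,
    gpu_memory_utilization=0.6,
    limit_mm_per_prompt={"image": 2},     # assumed counterpart of max_images=2
)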

vllm_ascend/worker/model_runner_v1.py

Lines changed: 72 additions & 91 deletions
@@ -49,8 +49,8 @@
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
-from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LayerBlockType, LazyLoader, cdiv)
+from vllm.utils import (DeviceMemoryProfiler, LazyLoader, cdiv,
+                        is_pin_memory_available)
 from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
@@ -143,7 +143,7 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         else:
             self.chunked_prefill_enabled = True
         self.device = device
-
+        self.pin_memory = is_pin_memory_available()
         self.is_multimodal_model = self.model_config.is_multimodal_model
         self.block_size = vllm_config.cache_config.block_size
 
@@ -1607,89 +1607,6 @@ def execute_model(
 
         return model_runner_output
 
-    def _profile_multimodal(self) -> None:
-        # TODO: handle encoder-decoder models once we support them.
-        # NOTE: Currently model is profiled with a single non-text
-        # modality with the max possible input tokens even when
-        # it supports multiple.
-
-        if (not self.is_multimodal_model
-                or self.max_num_encoder_input_tokens <= 0
-                or self.encoder_cache_size <= 0):
-            return
-
-        max_tokens_by_modality_dict = (
-            MULTIMODAL_REGISTRY.get_max_tokens_per_item_by_nonzero_modality(
-                self.model_config))
-        dummy_data_modality, max_tokens_per_mm_item = max(
-            max_tokens_by_modality_dict.items(), key=lambda item: item[1])
-
-        # Check how many items of this modality can be supported by
-        # the encoder budget.
-        encoder_budget = min(self.max_num_encoder_input_tokens,
-                             self.encoder_cache_size)
-
-        max_num_mm_items_encoder_budget = cdiv(encoder_budget,
-                                               max_tokens_per_mm_item)
-
-        # Check how many items of this modality can be supported by
-        # the decoder budget.
-        max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
-            self.model_config)[dummy_data_modality]
-
-        # NOTE: We do not consider max_num_batched_tokens on purpose
-        # because the multimodal embeddings can be generated in advance
-        # and chunked prefilled.
-        max_num_mm_items_decoder_budget = self.max_num_reqs * \
-            max_mm_items_per_req
-
-        max_num_mm_items = min(max_num_mm_items_encoder_budget,
-                               max_num_mm_items_decoder_budget)
-
-        logger.info(
-            "Encoder cache will be initialized with a budget of %s tokens,"
-            " and profiled with %s %s items of the maximum feature size.",
-            encoder_budget, max_num_mm_items, dummy_data_modality)
-
-        # Create dummy batch of multimodal inputs.
-        dummy_request_data = self.input_registry.dummy_data_for_profiling(
-            model_config=self.model_config,
-            seq_len=self.max_num_tokens,
-            mm_registry=self.mm_registry,
-        )
-        dummy_mm_data = dummy_request_data.multi_modal_data
-
-        if not isinstance(dummy_mm_data, MultiModalKwargs):
-            # TODO: Delete this check once input mapper is fully removed.
-            raise RuntimeError("Legacy input mapper is not supported in V1")
-
-        # Dummy data definition in V0 may contain multiple multimodal items
-        # (e.g, multiple images) for a single request, therefore here we
-        # always replicate first item by max_num_mm_items times since in V1
-        # they are scheduled to be processed separately.
-
-        dummy_mm_item = dummy_mm_data.get_item(modality=dummy_data_modality,
-                                               item_index=0)
-        dummy_mm_kwargs = MultiModalKwargs.from_items([dummy_mm_item])
-
-        batched_dummy_mm_inputs = MultiModalKwargs.batch([dummy_mm_kwargs] *
-                                                         max_num_mm_items)
-        batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
-            batched_dummy_mm_inputs, device=self.device)
-
-        # Run multimodal encoder.
-        dummy_encoder_outputs = self.model.get_multimodal_embeddings(
-            **batched_dummy_mm_inputs)
-        assert len(dummy_encoder_outputs) == max_num_mm_items, (
-            "Expected dimension 0 of encoder outputs to match the number "
-            f"of multimodal data items: {max_num_mm_items}, got "
-            f"{len(dummy_encoder_outputs)=} instead. This is most likely "
-            "due to the 'get_multimodal_embeddings' method of the model "
-            "not implemented correctly.")
-
-        # Cache the dummy encoder outputs.
-        self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
-
     @torch.inference_mode()
     def _dummy_run(
         self,
@@ -1796,12 +1713,76 @@ def _dummy_run(
             self.drafter.dummy_run(num_tokens)
         return hidden_states
 
+    @torch.inference_mode()
     def profile_run(self) -> None:
-        # FIXME Profile with multimodal encoder & encoder cache.
-        # current _profile_multimodal() using PyTorch SDPA backend method not
-        # support for window/full attn to reduce Memcpy operations, so will cause
-        # Out Of Memory problem, so we currently don't use self._profile_multimodal()
-        # self._profile_multimodal()
+        # Profile with multimodal encoder & encoder cache.
+        # TODO: handle encoder-decoder models once we support them.
+        if (self.is_multimodal_model and self.max_num_encoder_input_tokens > 0
+                and self.encoder_cache_size > 0):
+
+            # NOTE: Currently model is profiled with a single non-text
+            # modality with the max possible input tokens even when
+            # it supports multiple.
+            max_tokens_by_modality_dict = self.mm_registry \
+                .get_max_tokens_per_item_by_nonzero_modality(self.model_config)
+            dummy_data_modality, max_tokens_per_mm_item = max(
+                max_tokens_by_modality_dict.items(), key=lambda item: item[1])
+
+            # Check how many items of this modality can be supported by
+            # the encoder budget.
+            encoder_budget = min(self.max_num_encoder_input_tokens,
+                                 self.encoder_cache_size)
+
+            max_num_mm_items_encoder_budget = cdiv(encoder_budget,
+                                                   max_tokens_per_mm_item)
+
+            # Check how many items of this modality can be supported by
+            # the decoder budget.
+            max_mm_items_per_req = self.mm_registry.get_mm_limits_per_prompt(
+                self.model_config)[dummy_data_modality]
+
+            # NOTE: We do not consider max_num_batched_tokens on purpose
+            # because the multimodal embeddings can be generated in advance
+            # and chunked prefilled.
+            max_num_mm_items_decoder_budget = self.max_num_reqs * \
+                max_mm_items_per_req
+
+            max_num_mm_items = min(max_num_mm_items_encoder_budget,
+                                   max_num_mm_items_decoder_budget)
+
+            logger.info(
+                "Encoder cache will be initialized with a budget of %s tokens,"
+                " and profiled with %s %s items of the maximum feature size.",
+                encoder_budget, max_num_mm_items, dummy_data_modality)
+
+            # Create dummy batch of multimodal inputs.
+            dummy_mm_kwargs = self.mm_registry.get_decoder_dummy_data(
+                model_config=self.model_config,
+                seq_len=self.max_num_tokens,
+                mm_counts={
+                    dummy_data_modality: 1
+                },
+            ).multi_modal_data
+
+            batched_dummy_mm_inputs = MultiModalKwargs.batch(
+                [dummy_mm_kwargs] * max_num_mm_items,
+                pin_memory=self.pin_memory)
+            batched_dummy_mm_inputs = MultiModalKwargs.as_kwargs(
+                batched_dummy_mm_inputs,
+                device=self.device,
+            )
+
+            # Run multimodal encoder.
+            dummy_encoder_outputs = self.model.get_multimodal_embeddings(
+                **batched_dummy_mm_inputs)
+
+            sanity_check_mm_encoder_outputs(
+                dummy_encoder_outputs,
+                expected_num_items=max_num_mm_items,
+            )
+
+            # Cache the dummy encoder outputs.
+            self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs))
 
         # For profile, have maximum num_reqs and that collectively have
         # maximum num_tokens.
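To make the budgeting logic in the restored profiling step concrete, here is a small worked example with hypothetical numbers (the real values come from the scheduler config, the encoder cache budget, and the model's multimodal registry):

# Hypothetical inputs; only the arithmetic mirrors profile_run above.
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.cdiv.
    return -(-a // b)

max_num_encoder_input_tokens = 8192  # per-step encoder token budget (assumed)
encoder_cache_size = 8192            # encoder cache capacity in tokens (assumed)
max_tokens_per_mm_item = 1280        # largest single image's feature length (assumed)
max_num_reqs = 16                    # scheduler's max concurrent requests (assumed)
max_mm_items_per_req = 2             # mm limit per prompt, cf. max_images=2

encoder_budget = min(max_num_encoder_input_tokens, encoder_cache_size)    # 8192
items_by_encoder_budget = cdiv(encoder_budget, max_tokens_per_mm_item)    # 7
items_by_decoder_budget = max_num_reqs * max_mm_items_per_req             # 32
max_num_mm_items = min(items_by_encoder_budget, items_by_decoder_budget)  # 7
# profile_run then batches 7 maximum-size dummy items through the encoder
# and caches their outputs under the temporary "tmp" request id.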
