
Commit 6b094a2

[ModelRunner] Add profile execute duration observation (#1013)

### What this PR does / why we need it?

We need to **observe the time consumed in each stage of inference (including pre-processing, model forward, etc.) without any performance loss**. Therefore, we use the NPU's event timestamp mechanism to mark any stage during execution on the NPU device; this marking operation runs asynchronously, with no performance loss. Additionally, we provide a blocking synchronization API, `pop_captured_sync`, to be called at an appropriate time to print the time consumed in all observed stages.

**`model_runner_v1.py` changes only 5 lines, all of which are `ProfileExecuteDuration()` calls; the remaining lines shown as changed are re-indentation only.**

### Does this PR introduce _any_ user-facing change?

Set the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.

### How was this patch tested?

Tested with the DeepSeek model; the output looks like this:

```
5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms
5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms
5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms
5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms
5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms
5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms
5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms
5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms
5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms
5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms
5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms
5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms
5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms
5737:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms
5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms
5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms
5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms
5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms
```

---------

Signed-off-by: depeng1994 <[email protected]>
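The `model_runner_v1.py` changes themselves are not included in this view. As a rough illustration only, the two calls compose along these lines; `_prepare_inputs` and `_sample` are hypothetical placeholder names, not the runner's actual internals:

```python
from vllm_ascend.utils import ProfileExecuteDuration


def execute_model(self, scheduler_output):
    # Non-blocking: capture_async only records NPU events on the stream.
    with ProfileExecuteDuration().capture_async("prepare input and forward"):
        inputs = self._prepare_inputs(scheduler_output)  # placeholder name
        with ProfileExecuteDuration().capture_async("forward"):
            hidden_states = self.model(**inputs)
    with ProfileExecuteDuration().capture_async("post process"):
        output = self._sample(hidden_states)  # placeholder name

    # Blocking: waits on the recorded events and reports every stage at once.
    durations = ProfileExecuteDuration().pop_captured_sync()
    if durations:
        stages = " ".join(f"[{tag}]:{ms:.2f}ms" for tag, ms in durations.items())
        print(f"Profile execute duration [Decode]: {stages}")
    return output
```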
1 parent 78431b3 commit 6b094a2

File tree

6 files changed · +284 −113 lines changed

docs/source/developer_guide/evaluation/index.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -12,4 +12,5 @@ using_evalscope
 :caption: Performance
 :maxdepth: 1
 performance_benchmark
+profile_execute_duration
 :::
```
Lines changed: 34 additions & 0 deletions
# Profile Execute Duration

The execution duration of each stage (including pre/post-processing, model forward, etc.) usually needs to be captured during a complete inference process. Typically, this is done by using `torch.npu.synchronize()` and obtaining CPU timestamps, which increases the performance overhead of host/device synchronization.

**To reduce this overhead, this feature uses the NPU event timestamp mechanism to observe the device execution time asynchronously.**

## Usage

* Use the environment variable `VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE` to enable this feature.
* Use the non-blocking API `ProfileExecuteDuration().capture_async` to set observation points asynchronously wherever you need to observe the execution duration.
* Use the blocking API `ProfileExecuteDuration().pop_captured_sync` at an appropriate time to get and print the execution durations of all observed stages (a minimal end-to-end sketch follows this list).
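As a concrete illustration (not part of the committed docs), a minimal standalone run might look like this, assuming an Ascend device with `torch_npu` and vllm-ascend installed:

```python
import os

os.environ["VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE"] = "1"  # enable before use

import torch
import torch_npu  # noqa: F401  (registers the torch.npu backend)

from vllm_ascend.utils import ProfileExecuteDuration

x = torch.randn(4096, 4096).npu()
# Non-blocking: only NPU events are recorded around the kernel launch.
with ProfileExecuteDuration().capture_async("matmul"):
    y = torch.matmul(x, x)

# Blocking: waits on the end events, then returns {tag: milliseconds}.
for tag, ms in ProfileExecuteDuration().pop_captured_sync().items():
    print(f"[{tag}]:{ms:.2f}ms")
```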
## Example Output

```
5691:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.17ms [prepare input and forward]:9.57ms [forward]:4.14ms
5695:(IntegratedWorker pid=1502285) Profile execute duration [Decode]: [post process]:14.29ms [prepare input and forward]:10.19ms [forward]:4.14ms
5697:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.81ms [prepare input and forward]:10.29ms [forward]:3.99ms
5701:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.10ms [prepare input and forward]:10.62ms [forward]:4.33ms
5705:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.65ms [prepare input and forward]:9.58ms [forward]:4.20ms
5709:(IntegratedWorker pid=1502343) Profile execute duration [Decode]: [post process]:14.43ms [prepare input and forward]:9.88ms [forward]:4.20ms
5711:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.89ms [prepare input and forward]:10.49ms [forward]:4.19ms
5715:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.14ms [prepare input and forward]:11.21ms [forward]:4.18ms
5719:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.71ms [prepare input and forward]:10.15ms [forward]:4.42ms
5723:(IntegratedWorker pid=1502401) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.31ms [forward]:4.25ms
5725:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.12ms [prepare input and forward]:10.33ms [forward]:4.24ms
5729:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.58ms [prepare input and forward]:10.85ms [forward]:4.32ms
5733:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:14.32ms [prepare input and forward]:9.79ms [forward]:4.28ms
5737:(IntegratedWorker pid=1502462) Profile execute duration [Decode]: [post process]:15.06ms [prepare input and forward]:9.89ms [forward]:4.32ms
5739:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.62ms [prepare input and forward]:10.48ms [forward]:4.27ms
5743:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.60ms [prepare input and forward]:10.71ms [forward]:4.61ms
5747:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:14.21ms [prepare input and forward]:10.10ms [forward]:4.52ms
5751:(IntegratedWorker pid=1502524) Profile execute duration [Decode]: [post process]:15.03ms [prepare input and forward]:10.00ms [forward]:4.42ms
```
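For contrast, the synchronize-based timing that the document's opening paragraph argues against looks roughly like the sketch below (`timed_sync` is a hypothetical helper, not part of the patch); every probe forces a host/device sync, which is exactly the overhead the event-based API avoids:

```python
import time

import torch
import torch_npu  # noqa: F401


def timed_sync(tag, fn):
    # Naive probe: drain the stream before and after, time on the CPU clock.
    # Each call stalls the host, unlike the asynchronous capture_async.
    torch.npu.synchronize()
    start = time.perf_counter()
    out = fn()
    torch.npu.synchronize()
    print(f"[{tag}]:{(time.perf_counter() - start) * 1000:.2f}ms")
    return out
```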
Lines changed: 62 additions & 0 deletions
```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import time
from unittest.mock import patch

import torch
import vllm  # noqa: F401

from vllm_ascend.utils import ProfileExecuteDuration


@patch.dict(os.environ, {"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE": "1"})
def test_execue_duration_enabled_discrepancy():
    a = torch.randn(10000, 10000).npu()
    b = torch.randn(10000, 10000).npu()

    # warmup
    torch.matmul(a, b)
    torch.npu.synchronize()

    cpu_start = time.perf_counter()
    with ProfileExecuteDuration().capture_async("forward"):
        torch.matmul(a, b)
        torch.npu.synchronize()
    cpu_duration = (time.perf_counter() - cpu_start) * 1000
    npu_durations = ProfileExecuteDuration().pop_captured_sync()
    assert npu_durations and 'forward' in npu_durations
    # pop_captured_sync must drain the shared observation list.
    assert not ProfileExecuteDuration._observations

    # Assert discrepancy between CPU and NPU duration is within 50% roughly
    diff = abs(cpu_duration - npu_durations['forward']) / max(
        cpu_duration, npu_durations['forward'])
    assert diff <= 0.5, (
        f"CPU={cpu_duration:.2f}ms, NPU={npu_durations['forward']:.2f}ms")


def test_execue_duration_disabled():
    a = torch.randn(100, 100).npu()
    b = torch.randn(100, 100).npu()

    with ProfileExecuteDuration().capture_async("forward"):
        torch.matmul(a, b)
    torch.npu.synchronize()
    npu_durations = ProfileExecuteDuration().pop_captured_sync()
    # With the env flag unset, capture_async is a no-op and nothing is returned.
    assert not npu_durations
```

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -70,6 +70,9 @@
     lambda: os.getenv("VLLM_VERSION", None),
     "VLLM_ASCEND_TRACE_RECOMPILES":
     lambda: bool(int(os.getenv("VLLM_ASCEND_TRACE_RECOMPILES", '0'))),
+    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
+                 ),
 }

 # end-env-vars-definition
```
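The entry is a zero-argument lambda rather than a plain value. Assuming this module follows vLLM's usual lazy-env pattern (a module-level `__getattr__` that calls the lambda on each access; the sketch below is an assumption, not shown in this diff), the flag is re-read from the environment every time, which is why `@patch.dict(os.environ, ...)` in the tests takes effect without re-importing anything:

```python
import os

# Registry maps names to thunks; nothing is read at import time.
env_variables = {
    "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
    lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
}


def __getattr__(name: str):
    # Re-reads the environment on every `envs.<NAME>` attribute access.
    if name in env_variables:
        return env_variables[name]()
    raise AttributeError(f"module has no attribute {name!r}")
```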

vllm_ascend/utils.py

Lines changed: 53 additions & 1 deletion
```diff
@@ -17,11 +17,15 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #

+import atexit
 import math
-from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, List, Tuple

 import torch
 from packaging.version import InvalidVersion, Version
+from torch_npu.npu.streams import Event
 from vllm.logger import logger

 import vllm_ascend.envs as envs
@@ -175,3 +179,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:

 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
+
+
+class ProfileExecuteDuration:
+    _instance = None
+    _observations: List[Tuple[str, Event, Event]] = []
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                atexit.register(cls._instance.destroy)
+            return cls._instance
+
+    def destroy(self):
+        with self._lock:
+            self._observations.clear()
+
+    @contextmanager
+    def capture_async(self, duration_tag: str):
+        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
+            yield
+            return
+
+        observe_start = Event(enable_timing=True)
+        observe_start.record()
+        try:
+            yield
+        finally:
+            observe_end = Event(enable_timing=True)
+            observe_end.record()
+            with self._lock:
+                self._observations.append(
+                    (duration_tag, observe_start, observe_end))
+
+    def pop_captured_sync(self) -> dict:
+        """Pop and synchronize all events in the observation list"""
+        durations: dict[str, float] = {}
+        if not envs.VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE:
+            return durations
+
+        while self._observations:
+            with self._lock:
+                tag, observe_start, observe_end = self._observations.pop()
+            observe_end.synchronize()
+            durations[tag] = observe_start.elapsed_time(observe_end)
+
+        return durations
```
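Under the hood, each observation is just a pair of timing events; the sketch below spells out the lifecycle using only the `Event` calls the class itself relies on (`torch_npu`'s stream events mirror `torch.cuda.Event` semantics here):

```python
from torch_npu.npu.streams import Event

start = Event(enable_timing=True)
end = Event(enable_timing=True)

start.record()  # enqueued on the current NPU stream; returns immediately
# ... device work, e.g. a model forward ...
end.record()    # likewise asynchronous

end.synchronize()                     # the only host block, deferred to pop time
elapsed_ms = start.elapsed_time(end)  # device-side duration in milliseconds
```

Note the design choice in `pop_captured_sync`: the lock is held only while popping an observation, and `synchronize()` runs outside it, so a concurrent `capture_async` is never blocked behind an event wait. `destroy` is registered with `atexit` so any unread observations are cleared at shutdown.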
