
Commit 800a544

Add profile execute duration observation
1 parent a93bed4 commit 800a544

3 files changed, 192 additions and 124 deletions

3 files changed

+192
-124
lines changed

vllm_ascend/envs.py

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_MODEL_EXECUTE_TIME_OBSERVE":
+    lambda: bool(int(os.getenv("VLLM_MODEL_EXECUTE_TIME_OBSERVE", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":

vllm_ascend/utils.py

Lines changed: 54 additions & 2 deletions
@@ -17,14 +17,18 @@
 # Adapted from vllm-project/vllm/vllm/worker/worker.py
 #
 
+import atexit
 import math
-from typing import TYPE_CHECKING
+from contextlib import contextmanager
+from threading import Lock
+from typing import TYPE_CHECKING, List, Tuple
 
 import torch
 from packaging.version import InvalidVersion, Version
-from vllm.logger import logger
+from torch_npu.npu.streams import Event
 
 import vllm_ascend.envs as envs
+from vllm.logger import logger
 
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
@@ -173,3 +177,51 @@ def update_aclgraph_sizes(vllm_config: VllmConfig) -> None:
 
 def dispose_tensor(x: torch.Tensor):
     x.set_(torch.empty((0, ), device=x.device, dtype=x.dtype))
+
+
+class ProfileExecuteDuration:
+    _instance = None
+    _observations: List[Tuple[str, Event, Event]] = []
+    _lock = Lock()
+
+    def __new__(cls):
+        with cls._lock:
+            if cls._instance is None:
+                cls._instance = super().__new__(cls)
+                atexit.register(cls._instance.destroy)
+            return cls._instance
+
+    def destroy(self):
+        with self._lock:
+            self._observations.clear()
+
+    @contextmanager
+    def capture_async(self, duration_tag: str):
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            yield
+            return
+
+        observe_start = Event(enable_timing=True)
+        observe_start.record()
+        try:
+            yield
+        finally:
+            observe_end = Event(enable_timing=True)
+            observe_end.record()
+            with self._lock:
+                self._observations.append(
+                    (duration_tag, observe_start, observe_end))
+
+    def pop_captured_sync(self, captured_name: str):
+        """Pop and synchronize all events in the observation list, print all duration"""
+        if not envs.VLLM_MODEL_EXECUTE_TIME_OBSERVE:
+            return
+
+        log = f"Profile execute duration [{captured_name}]:"
+        while self._observations:
+            with self._lock:
+                tag, observe_start, observe_end = self._observations.pop()
+                observe_end.synchronize()
+                duration = observe_start.elapsed_time(observe_end)
+                log += f" [{tag}]:{duration:.2f}ms"
+        print(log)
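
Read together, ProfileExecuteDuration is a process-wide singleton: capture_async brackets a code block with a pair of torch_npu timing Events without forcing a device sync, and pop_captured_sync later synchronizes the queued end events and prints one summary line. A rough usage sketch, assuming an Ascend device with torch_npu installed and VLLM_MODEL_EXECUTE_TIME_OBSERVE=1 (the matmul is just placeholder NPU work, not something this commit runs):

import torch
import torch_npu  # noqa: F401  # registers the "npu" device
from vllm_ascend.utils import ProfileExecuteDuration

profiler = ProfileExecuteDuration()  # __new__ always hands back the same instance

with profiler.capture_async("forward"):
    # Any work launched on the current NPU stream is bracketed by the two Events.
    x = torch.ones(1024, 1024, device="npu")
    y = x @ x

# Synchronizes each recorded end Event, then prints a single line starting with
# "Profile execute duration [Decode]:" followed by one "[tag]:<ms>" entry per capture.
profiler.pop_captured_sync("Decode")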

vllm_ascend/worker/model_runner_v1.py

Lines changed: 136 additions & 122 deletions
@@ -29,6 +29,7 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
@@ -56,14 +57,15 @@
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
-
 from vllm_ascend.attention.attention import AttentionMaskBuilder
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
+from vllm_ascend.utils import ProfileExecuteDuration
 
 if TYPE_CHECKING:
     import xgrammar as xgr  # type: ignore[import-untyped]
+
     from vllm.v1.core.sched.output import SchedulerOutput
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
@@ -628,36 +630,38 @@ def _process_reqs(
         with set_forward_context(attn_metadata,
                                  self.vllm_config,
                                  num_tokens=num_input_tokens):
-            model_kwargs = {}
-            if self.enable_torchair_graph_mode:
-                model_kwargs["kv_caches"] = self.kv_caches
-                model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
-                torch._dynamo.mark_static(input_ids)
-                torch._dynamo.mark_static(positions)
-                torch._dynamo.mark_static(attn_metadata.decode.block_table)
-                torch._dynamo.mark_static(attn_metadata.decode.input_positions)
-                torch._dynamo.mark_static(attn_metadata.slot_mapping)
-                for kv in self.kv_caches:
-                    if isinstance(kv, tuple):
-                        torch._dynamo.mark_static(kv[0])
-                        torch._dynamo.mark_static(kv[1])
-                hidden_states = self.compile_model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
-            else:
-                assert self.model is not None
-                hidden_states = self.model(
-                    input_ids=input_ids,
-                    positions=positions,
-                    intermediate_tensors=intermediate_tensors,
-                    inputs_embeds=None,
-                    **model_kwargs,
-                )
+            with ProfileExecuteDuration().capture_async("forward"):
+                model_kwargs = {}
+                if self.enable_torchair_graph_mode:
+                    model_kwargs["kv_caches"] = self.kv_caches
+                    model_kwargs["attn_metadata"] = attn_metadata
+                if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+                    torch._dynamo.mark_static(input_ids)
+                    torch._dynamo.mark_static(positions)
+                    torch._dynamo.mark_static(attn_metadata.decode.block_table)
+                    torch._dynamo.mark_static(
+                        attn_metadata.decode.input_positions)
+                    torch._dynamo.mark_static(attn_metadata.slot_mapping)
+                    for kv in self.kv_caches:
+                        if isinstance(kv, tuple):
+                            torch._dynamo.mark_static(kv[0])
+                            torch._dynamo.mark_static(kv[1])
+                    hidden_states = self.compile_model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
+                else:
+                    assert self.model is not None
+                    hidden_states = self.model(
+                        input_ids=input_ids,
+                        positions=positions,
+                        intermediate_tensors=intermediate_tensors,
+                        inputs_embeds=None,
+                        **model_kwargs,
+                    )
 
         use_spec_decode = len(
             scheduler_output.scheduled_spec_decode_tokens) > 0
@@ -844,103 +848,113 @@ def execute_model(
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, torch.Tensor]:
-        self._update_states(scheduler_output)
-        if not scheduler_output.total_num_scheduled_tokens:
-            # Return empty ModelRunnerOuptut if there's no work to do.
-            return EMPTY_MODEL_RUNNER_OUTPUT
-        (attn_metadata, hidden_states, spec_decode_metadata, positions,
-         num_scheduled_tokens,
-         sample_indices) = (self._process_reqs(scheduler_output,
-                                               intermediate_tensors))
-        logits = self.model.compute_logits(hidden_states[sample_indices], None)
-
-        # Apply structured output bitmasks if present
-        if scheduler_output.grammar_bitmask is not None:
-            logits = self.apply_grammar_bitmask(scheduler_output, logits)
-
-        # Sample the next token and get logprobs if needed.
-        sampling_metadata = self.input_batch.sampling_metadata
-        if spec_decode_metadata is None:
-            sampler_output = self.sampler(
-                logits=logits,
-                sampling_metadata=sampling_metadata,
-            )
-        else:
-            # When indexing with a tensor (bonus_logits_indices), PyTorch
-            # creates a new tensor with separate storage from the original
-            # logits tensor. This means any in-place operations on bonus_logits
-            # won't affect the original logits tensor.
-            bonus_logits = logits[spec_decode_metadata.bonus_logits_indices]
-            sampler_output = self.sampler(
-                logits=bonus_logits,
-                sampling_metadata=sampling_metadata,
-            )
-            bonus_token_ids = sampler_output.sampled_token_ids
+        with ProfileExecuteDuration().capture_async(
+                "prepare input and forward"):
+            self._update_states(scheduler_output)
+            if not scheduler_output.total_num_scheduled_tokens:
+                # Return empty ModelRunnerOuptut if there's no work to do.
+                return EMPTY_MODEL_RUNNER_OUTPUT
+            (attn_metadata, hidden_states, spec_decode_metadata, positions,
+             num_scheduled_tokens,
+             sample_indices) = (self._process_reqs(scheduler_output,
+                                                   intermediate_tensors))
+
+        with ProfileExecuteDuration().capture_async("post process"):
+            logits = self.model.compute_logits(hidden_states[sample_indices],
+                                               None)
+
+            # Apply structured output bitmasks if present
+            if scheduler_output.grammar_bitmask is not None:
+                logits = self.apply_grammar_bitmask(scheduler_output, logits)
+
+            # Sample the next token and get logprobs if needed.
+            sampling_metadata = self.input_batch.sampling_metadata
+            if spec_decode_metadata is None:
+                sampler_output = self.sampler(
+                    logits=logits,
+                    sampling_metadata=sampling_metadata,
+                )
+            else:
+                # When indexing with a tensor (bonus_logits_indices), PyTorch
+                # creates a new tensor with separate storage from the original
+                # logits tensor. This means any in-place operations on bonus_logits
+                # won't affect the original logits tensor.
+                bonus_logits = logits[
+                    spec_decode_metadata.bonus_logits_indices]
+                sampler_output = self.sampler(
+                    logits=bonus_logits,
+                    sampling_metadata=sampling_metadata,
+                )
+                bonus_token_ids = sampler_output.sampled_token_ids
+
+                # Just like `bonus_logits`, `target_logits` is a new tensor with
+                # separate storage from the original `logits` tensor. Therefore,
+                # it is safe to update `target_logits` in place.
+                target_logits = logits[
+                    spec_decode_metadata.target_logits_indices]
+                output_token_ids = self.rejection_sampler(
+                    spec_decode_metadata,
+                    None,  # draft_probs
+                    target_logits,
+                    bonus_token_ids,
+                    sampling_metadata,
+                )
+                sampler_output.sampled_token_ids = output_token_ids
+
+            # TODO(woosuk): The following loop can be slow since it iterates over
+            # the requests one by one. Optimize.
+            for i, req_id in enumerate(self.input_batch.req_ids):
+                req_state = self.requests[req_id]
+                seq_len = (req_state.num_computed_tokens +
+                           scheduler_output.num_scheduled_tokens[req_id])
+                if seq_len < req_state.num_tokens:
+                    # Ignore the sampled token.
+                    # Rewind the generator state as if the token was not sampled.
+                    generator = self.input_batch.generators.get(i)
+                    if generator is not None:
+                        generator.set_offset(generator.get_offset() - 4)
+
+            # NOTE: NPU -> CPU Sync happens here.
+            # Move as many CPU operations as possible before this sync point.
+            logprobs_tensors = sampler_output.logprobs_tensors
+            logprobs_lists = logprobs_tensors.tolists() \
+                if logprobs_tensors is not None else None
+
+            # Get the valid generated tokens.
+            sampled_token_ids = sampler_output.sampled_token_ids
+            max_gen_len = sampled_token_ids.shape[-1]
+            if max_gen_len == 1:
+                # No spec decode tokens.
+                valid_sampled_token_ids = sampled_token_ids.tolist()
+            else:
+                # Includes spec decode tokens.
+                valid_sampled_token_ids = self.rejection_sampler.parse_output(
+                    sampled_token_ids,
+                    self.input_batch.vocab_size,
+                )
 
-            # Just like `bonus_logits`, `target_logits` is a new tensor with
-            # separate storage from the original `logits` tensor. Therefore,
-            # it is safe to update `target_logits` in place.
-            target_logits = logits[spec_decode_metadata.target_logits_indices]
-            output_token_ids = self.rejection_sampler(
-                spec_decode_metadata,
-                None,  # draft_probs
-                target_logits,
-                bonus_token_ids,
+            spec_token_ids = self._get_spec_token_ids(
+                valid_sampled_token_ids,
                 sampling_metadata,
+                scheduler_output,
+                spec_decode_metadata,
+                positions,
+                num_scheduled_tokens,
+                hidden_states,
+                attn_metadata,
             )
-            sampler_output.sampled_token_ids = output_token_ids
 
-        # TODO(woosuk): The following loop can be slow since it iterates over
-        # the requests one by one. Optimize.
-        for i, req_id in enumerate(self.input_batch.req_ids):
-            req_state = self.requests[req_id]
-            seq_len = (req_state.num_computed_tokens +
-                       scheduler_output.num_scheduled_tokens[req_id])
-            if seq_len < req_state.num_tokens:
-                # Ignore the sampled token.
-                # Rewind the generator state as if the token was not sampled.
-                generator = self.input_batch.generators.get(i)
-                if generator is not None:
-                    generator.set_offset(generator.get_offset() - 4)
-
-        # NOTE: NPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
-        logprobs_tensors = sampler_output.logprobs_tensors
-        logprobs_lists = logprobs_tensors.tolists() \
-            if logprobs_tensors is not None else None
-
-        # Get the valid generated tokens.
-        sampled_token_ids = sampler_output.sampled_token_ids
-        max_gen_len = sampled_token_ids.shape[-1]
-        if max_gen_len == 1:
-            # No spec decode tokens.
-            valid_sampled_token_ids = sampled_token_ids.tolist()
-        else:
-            # Includes spec decode tokens.
-            valid_sampled_token_ids = self.rejection_sampler.parse_output(
-                sampled_token_ids,
-                self.input_batch.vocab_size,
+            model_runner_output = ModelRunnerOutput(
+                req_ids=self.input_batch.req_ids,
+                req_id_to_index=self.input_batch.req_id_to_index,
+                sampled_token_ids=valid_sampled_token_ids,
+                spec_token_ids=spec_token_ids,
+                logprobs=logprobs_lists,
+                prompt_logprobs_dict={},
             )
 
-        spec_token_ids = self._get_spec_token_ids(
-            valid_sampled_token_ids,
-            sampling_metadata,
-            scheduler_output,
-            spec_decode_metadata,
-            positions,
-            num_scheduled_tokens,
-            hidden_states,
-            attn_metadata,
-        )
-
-        model_runner_output = ModelRunnerOutput(
-            req_ids=self.input_batch.req_ids,
-            req_id_to_index=self.input_batch.req_id_to_index,
-            sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
-            logprobs=logprobs_lists,
-            prompt_logprobs_dict={},
-        )
+        capture_name = "Decode" if self.attn_state == AscendAttentionState.DecodeOnly else "Prefill"
+        ProfileExecuteDuration().pop_captured_sync(capture_name)
         return model_runner_output
 
     def _profile_multimodal(self) -> None:
