Skip to content

Commit 6afc0ff

Browse files
authored
[Model Runner V2] Add sample/ directory and reorganize files (#29719)
Signed-off-by: Woosuk Kwon <[email protected]>
1 parent 39e63de commit 6afc0ff

File tree

10 files changed

+587
-570
lines changed

10 files changed

+587
-570
lines changed

vllm/v1/worker/gpu/model_runner.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,18 @@
4747
prepare_pos_seq_lens,
4848
prepare_prefill_inputs,
4949
)
50-
from vllm.v1.worker.gpu.sampler import Sampler, compute_prompt_logprobs
50+
from vllm.v1.worker.gpu.sample.logprob import compute_prompt_logprobs
51+
from vllm.v1.worker.gpu.sample.metadata import (
52+
SamplingMetadata,
53+
expand_sampling_metadata,
54+
)
55+
from vllm.v1.worker.gpu.sample.sampler import Sampler
5156
from vllm.v1.worker.gpu.spec_decode import init_speculator
5257
from vllm.v1.worker.gpu.spec_decode.rejection_sample import (
5358
get_num_rejected,
5459
rejection_sample,
5560
)
56-
from vllm.v1.worker.gpu.states import RequestState, SamplingMetadata
61+
from vllm.v1.worker.gpu.states import RequestState
5762
from vllm.v1.worker.gpu.structured_outputs import apply_grammar_bitmask
5863
from vllm.v1.worker.kv_connector_model_runner_mixin import KVConnectorModelRunnerMixin
5964
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -890,8 +895,10 @@ def execute_model(
890895
input_batch.idx_mapping, input_batch.idx_mapping_np, pos
891896
)
892897
if input_batch.num_draft_tokens > 0:
893-
sampling_metadata = self.req_states.expand_sampling_metadata(
894-
sampling_metadata, input_batch.cu_num_logits
898+
sampling_metadata = expand_sampling_metadata(
899+
sampling_metadata,
900+
input_batch.cu_num_logits,
901+
max_expand_len=self.num_speculative_steps + 1,
895902
)
896903

897904
if self.lora_config:

vllm/v1/worker/gpu/sample/__init__.py

Whitespace-only changes.
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
import torch
4+
5+
from vllm.triton_utils import tl, triton
6+
7+
8+
@triton.jit
def _gumbel_sample_kernel(
    local_argmax_ptr,
    local_argmax_stride,
    local_max_ptr,
    local_max_stride,
    logits_ptr,
    logits_stride,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    # Launch grid is (request, vocab block). Each program reduces one
    # BLOCK_SIZE-wide slice of one logits row to a single (argmax, max)
    # pair; the caller combines the per-block pairs afterwards.
    row = tl.program_id(0)
    tile = tl.program_id(1)
    offsets = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_vocab = offsets < vocab_size

    # Out-of-vocabulary lanes read as -inf so they can never win the argmax.
    scores = tl.load(
        logits_ptr + row * logits_stride + offsets,
        mask=in_vocab,
        other=float("-inf"),
    )
    scores = scores.to(tl.float32)

    temperature = tl.load(temp_ptr + row).to(tl.float32)
    if temperature != 0.0:
        # A zero temperature skips this branch entirely, so the row is
        # reduced with a plain (noise-free) argmax.

        # Derive the per-request noise seed from the request seed and position.
        base_seed = tl.load(seeds_ptr + row)
        position = tl.load(pos_ptr + row)
        noise_seed = tl.randint(base_seed, position)

        # Draw uniform samples and map them to Gumbel noise. The epsilons
        # keep both logarithms away from log(0).
        uniform = tl.rand(noise_seed, offsets).to(tl.float64)
        noise = -tl.log(-tl.log(uniform + 1e-20) + 1e-20)
        noise = noise.to(tl.float32)

        # Temperature scaling is optional at compile time.
        if APPLY_TEMPERATURE:
            # NOTE(woosuk): Use div_rn to match the behavior of torch.
            scores = tl.div_rn(scores, temperature)

        # Perturb the in-vocabulary lanes; keep the padding lanes at -inf.
        scores = tl.where(in_vocab, scores + noise, float("-inf"))

    # Block-local winner, translated to a vocabulary-wide token id.
    local_idx = tl.argmax(scores, axis=0)
    winner_token = tile * BLOCK_SIZE + local_idx
    winner_score = tl.max(scores, axis=0)
    tl.store(local_argmax_ptr + row * local_argmax_stride + tile, winner_token)
    tl.store(local_max_ptr + row * local_max_stride + tile, winner_score)
60+
61+
def gumbel_sample(
    logits: torch.Tensor,  # [num_reqs, vocab_size]
    temperature: torch.Tensor,  # [num_reqs]
    seed: torch.Tensor,  # [num_reqs]
    pos: torch.Tensor,  # [num_reqs]
    apply_temperature: bool,
) -> torch.Tensor:
    """Pick one token id per request with `_gumbel_sample_kernel`.

    The vocabulary is split into fixed-size blocks. The kernel writes the
    best (token id, score) pair of every block, and the overall winner of
    each row is then selected from those per-block results.

    Args:
        logits: Scores to sample from, one row per request.
        temperature: Per-request temperature read by the kernel.
        seed: Per-request seed read by the kernel.
        pos: Per-request position, mixed with the seed inside the kernel.
        apply_temperature: Forwarded to the kernel as the compile-time flag
            `APPLY_TEMPERATURE`.

    Returns:
        A 1-D int64 tensor with one sampled token id per request.
    """
    num_reqs, vocab_size = logits.shape
    block_size = 1024
    num_blocks = triton.cdiv(vocab_size, block_size)
    grid = (num_reqs, num_blocks)

    # Per-block partial results, filled in by the kernel.
    # NOTE(woosuk): Use int64 for later indexing.
    local_argmax = torch.empty(grid, dtype=torch.int64, device=logits.device)
    local_max = torch.empty(grid, dtype=torch.float32, device=logits.device)

    _gumbel_sample_kernel[grid](
        local_argmax,
        local_argmax.stride(0),
        local_max,
        local_max.stride(0),
        logits,
        logits.stride(0),
        seed,
        pos,
        temperature,
        vocab_size,
        BLOCK_SIZE=block_size,
        APPLY_TEMPERATURE=apply_temperature,
    )

    # For every row, find the block holding the largest score and return
    # the token id that block reported.
    best_block = local_max.argmax(dim=-1, keepdim=True)
    return local_argmax.gather(dim=-1, index=best_block).view(-1)
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
from collections.abc import Callable
4+
5+
import torch
6+
7+
from vllm.triton_utils import tl, triton
8+
from vllm.v1.outputs import LogprobsTensors
9+
10+
11+
@triton.jit
def _topk_log_softmax_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    topk_ids_ptr,
    topk,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    PADDED_TOPK: tl.constexpr,
):
    # One program per row: compute log-softmax values for only the `topk`
    # token ids listed for that row, without writing the full distribution.
    row = tl.program_id(0)
    row_ptr = logits_ptr + row * logits_stride

    # Pass 1: running maximum over the row (for a stable exponent).
    max_val = float("-inf")
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        chunk = tl.load(
            row_ptr + offsets, mask=offsets < vocab_size, other=float("-inf")
        )
        max_val = tl.max(tl.maximum(chunk, max_val))
    max_val = max_val.to(tl.float32)  # type: ignore

    # Pass 2: sum of exp(logit - max) over the valid vocabulary entries.
    sum_exp = 0.0
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        in_vocab = offsets < vocab_size
        chunk = tl.load(row_ptr + offsets, mask=in_vocab, other=0.0)
        # NOTE(woosuk): Make sure that logits and all following operations use FP32.
        chunk = chunk.to(tl.float32)
        exp_chunk = tl.exp(chunk - max_val)
        # Padding lanes were loaded as 0.0; zero their contribution explicitly.
        exp_chunk = tl.where(in_vocab, exp_chunk, 0.0)
        sum_exp += tl.sum(exp_chunk)
    log_sum_exp = tl.log(sum_exp)

    # Gather the requested ids. PADDED_TOPK is the compile-time width of the
    # id vector; lanes at or beyond `topk` are masked off.
    lane = tl.arange(0, PADDED_TOPK)
    lane_mask = lane < topk
    ids = tl.load(topk_ids_ptr + row * topk + lane, mask=lane_mask, other=0)

    picked = tl.load(row_ptr + ids, mask=lane_mask)
    picked = picked.to(tl.float32)
    result = picked - max_val - log_sum_exp
    tl.store(output_ptr + row * topk + lane, result, mask=lane_mask)
51+
52+
53+
@triton.jit
def _ranks_kernel(
    output_ptr,
    logits_ptr,
    logits_stride,
    token_ids_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per row: count how many vocabulary entries have a logit
    # strictly greater than the logit of the row's selected token.
    row = tl.program_id(0)
    row_ptr = logits_ptr + row * logits_stride

    selected_id = tl.load(token_ids_ptr + row)
    selected_logit = tl.load(row_ptr + selected_id)

    num_greater = 0
    for start in range(0, vocab_size, BLOCK_SIZE):
        offsets = start + tl.arange(0, BLOCK_SIZE)
        # Padding lanes read as -inf, so they never count as "greater".
        chunk = tl.load(
            row_ptr + offsets, mask=offsets < vocab_size, other=float("-inf")
        )
        num_greater += tl.sum((chunk > selected_logit).to(tl.int32))
    tl.store(output_ptr + row, num_greater)
74+
75+
76+
def compute_token_logprobs(
    logits: torch.Tensor,
    token_ids: torch.Tensor,
) -> torch.Tensor:
    """Compute log-softmax values for selected token ids only.

    Launches `_topk_log_softmax_kernel` with one program per row of
    `logits`, so the full log-probability matrix is never written out.

    Args:
        logits: 2-D tensor; dimension 0 is the batch, dimension 1 the
            vocabulary. The kernel reads each row with unit stride along
            dimension 1.
        token_ids: 2-D integer tensor of token ids to evaluate, one row per
            row of `logits`.

    Returns:
        A float32 tensor with the same shape as `token_ids`, holding the
        log-softmax value of each requested id.
    """
    batch_size = logits.shape[0]
    vocab_size = logits.shape[1]
    # The kernel addresses ids as `ptr + row * num_logprobs + col`, i.e. as a
    # dense row-major buffer. `.to(int64)` returns the input unchanged when it
    # is already int64, so a strided view would otherwise be read incorrectly;
    # `.contiguous()` is a no-op for tensors that are already dense.
    token_ids = token_ids.to(torch.int64).contiguous()
    num_logprobs = token_ids.shape[1]
    logprobs = torch.empty(
        batch_size,
        num_logprobs,
        dtype=torch.float32,
        device=logits.device,
    )
    _topk_log_softmax_kernel[(batch_size,)](
        logprobs,
        logits,
        logits.stride(0),
        token_ids,
        num_logprobs,
        vocab_size,
        BLOCK_SIZE=1024,  # type: ignore
        # The kernel needs a power-of-two lane count for its id vector.
        PADDED_TOPK=triton.next_power_of_2(num_logprobs),
    )
    return logprobs
101+
102+
103+
def compute_topk_logprobs(
    logits: torch.Tensor,
    num_logprobs: int,
    sampled_token_ids: torch.Tensor,
) -> LogprobsTensors:
    """Return logprobs for the sampled token plus the top-k tokens.

    Args:
        logits: 2-D tensor of scores, one row per batch entry.
        num_logprobs: Number of top-k entries to report in addition to the
            sampled token; must be non-negative.
        sampled_token_ids: One token id per row of `logits`.

    Returns:
        A `LogprobsTensors` whose first id column is the sampled token,
        followed by the top-k ids (if any), with the matching logprobs and
        the rank of the sampled token as computed by `_ranks_kernel`.
    """
    assert num_logprobs >= 0
    batch_size, vocab_size = logits.shape

    # Column 0 is always the sampled token; top-k ids are appended after it.
    sampled_column = sampled_token_ids.unsqueeze(-1)
    if num_logprobs != 0:
        topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices
        logprob_token_ids = torch.cat((sampled_column, topk_indices), dim=1)
    else:
        logprob_token_ids = sampled_column

    # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full
    # logprobs tensor. Instead, we only compute and return the logprobs of
    # the topk + 1 tokens.
    logprobs = compute_token_logprobs(logits, logprob_token_ids)

    # The kernel fills one rank per row for the sampled token.
    token_ranks = torch.empty(batch_size, dtype=torch.int64, device=logits.device)
    _ranks_kernel[(batch_size,)](
        token_ranks,
        logits,
        logits.stride(0),
        sampled_token_ids,
        vocab_size,
        BLOCK_SIZE=8192,  # type: ignore
    )

    return LogprobsTensors(
        logprob_token_ids=logprob_token_ids,
        logprobs=logprobs,
        selected_token_ranks=token_ranks,
    )
140+
141+
142+
def compute_prompt_logprobs(
    prompt_token_ids: torch.Tensor,
    prompt_hidden_states: torch.Tensor,
    logits_fn: Callable[[torch.Tensor], torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute the logprob and rank of every prompt token, chunk by chunk.

    Args:
        prompt_token_ids: 1-D tensor of prompt token ids; element i is scored
            against the logits produced from row i of `prompt_hidden_states`.
        prompt_hidden_states: Hidden states, sliced along dimension 0 with
            the same indices as `prompt_token_ids`.
        logits_fn: Maps a chunk of hidden states to that chunk's logits.

    Returns:
        A `(logprobs, ranks)` pair. `logprobs` is float32 with shape
        `[num_tokens, 1]` and `ranks` is int64 with shape `[num_tokens]`.
        For an empty prompt both tensors are empty and `logits_fn` is never
        called.
    """
    # Since materializing the full prompt logits can take too much memory,
    # we compute it in chunks.
    CHUNK_SIZE = 1024
    prompt_token_ids = prompt_token_ids.to(torch.int64)
    num_tokens = prompt_token_ids.shape[0]

    if num_tokens == 0:
        # Without this guard the loop below never runs and indexing the empty
        # chunk list raises IndexError. Return empty results with the same
        # dtypes and trailing shape as the normal path.
        # NOTE(review): assumes the logits would live on the hidden states'
        # device - confirm against logits_fn.
        device = prompt_hidden_states.device
        return (
            torch.empty(0, 1, dtype=torch.float32, device=device),
            torch.empty(0, dtype=torch.int64, device=device),
        )

    logprob_chunks = []
    rank_chunks = []
    for start_idx in range(0, num_tokens, CHUNK_SIZE):
        end_idx = start_idx + CHUNK_SIZE
        # NOTE(woosuk): logits_fn can be slow because it involves all-gather.
        prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx])
        prompt_logprobs = compute_topk_logprobs(
            prompt_logits,
            0,  # num_logprobs
            prompt_token_ids[start_idx:end_idx],
        )
        logprob_chunks.append(prompt_logprobs.logprobs)
        rank_chunks.append(prompt_logprobs.selected_token_ranks)

    # Skip the concatenation (and its copy) when there is a single chunk.
    if len(logprob_chunks) > 1:
        logprobs = torch.cat(logprob_chunks, dim=0)
        ranks = torch.cat(rank_chunks, dim=0)
    else:
        logprobs = logprob_chunks[0]
        ranks = rank_chunks[0]
    return logprobs, ranks

0 commit comments

Comments
 (0)