|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project |
| 3 | +from collections.abc import Callable |
| 4 | + |
| 5 | +import torch |
| 6 | + |
| 7 | +from vllm.triton_utils import tl, triton |
| 8 | +from vllm.v1.outputs import LogprobsTensors |
| 9 | + |
| 10 | + |
| 11 | +@triton.jit |
| 12 | +def _topk_log_softmax_kernel( |
| 13 | + output_ptr, |
| 14 | + logits_ptr, |
| 15 | + logits_stride, |
| 16 | + topk_ids_ptr, |
| 17 | + topk, |
| 18 | + vocab_size, |
| 19 | + BLOCK_SIZE: tl.constexpr, |
| 20 | + PADDED_TOPK: tl.constexpr, |
| 21 | +): |
| 22 | + req_idx = tl.program_id(0) |
| 23 | + row_ptr = logits_ptr + req_idx * logits_stride |
| 24 | + |
| 25 | + max_val = float("-inf") |
| 26 | + for i in range(0, vocab_size, BLOCK_SIZE): |
| 27 | + block = i + tl.arange(0, BLOCK_SIZE) |
| 28 | + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) |
| 29 | + max_val = tl.max(tl.maximum(logits, max_val)) |
| 30 | + max_val = max_val.to(tl.float32) # type: ignore |
| 31 | + |
| 32 | + se = 0.0 |
| 33 | + for i in range(0, vocab_size, BLOCK_SIZE): |
| 34 | + block = i + tl.arange(0, BLOCK_SIZE) |
| 35 | + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=0.0) |
| 36 | + # NOTE(woosuk): Make sure that logits and all following operations use FP32. |
| 37 | + logits = logits.to(tl.float32) |
| 38 | + e = tl.exp(logits - max_val) |
| 39 | + e = tl.where(block < vocab_size, e, 0.0) |
| 40 | + se += tl.sum(e) |
| 41 | + lse = tl.log(se) |
| 42 | + |
| 43 | + k_offset = tl.arange(0, PADDED_TOPK) |
| 44 | + k_mask = k_offset < topk |
| 45 | + topk_ids = tl.load(topk_ids_ptr + req_idx * topk + k_offset, mask=k_mask, other=0) |
| 46 | + |
| 47 | + logits = tl.load(row_ptr + topk_ids, mask=k_mask) |
| 48 | + logits = logits.to(tl.float32) |
| 49 | + o = logits - max_val - lse |
| 50 | + tl.store(output_ptr + req_idx * topk + k_offset, o, mask=k_mask) |
| 51 | + |
| 52 | + |
| 53 | +@triton.jit |
| 54 | +def _ranks_kernel( |
| 55 | + output_ptr, |
| 56 | + logits_ptr, |
| 57 | + logits_stride, |
| 58 | + token_ids_ptr, |
| 59 | + vocab_size, |
| 60 | + BLOCK_SIZE: tl.constexpr, |
| 61 | +): |
| 62 | + req_idx = tl.program_id(0) |
| 63 | + row_ptr = logits_ptr + req_idx * logits_stride |
| 64 | + |
| 65 | + token_id = tl.load(token_ids_ptr + req_idx) |
| 66 | + x = tl.load(row_ptr + token_id) |
| 67 | + |
| 68 | + n = 0 |
| 69 | + for i in range(0, vocab_size, BLOCK_SIZE): |
| 70 | + block = i + tl.arange(0, BLOCK_SIZE) |
| 71 | + logits = tl.load(row_ptr + block, mask=block < vocab_size, other=float("-inf")) |
| 72 | + n += tl.sum((logits > x).to(tl.int32)) |
| 73 | + tl.store(output_ptr + req_idx, n) |
| 74 | + |
| 75 | + |
| 76 | +def compute_token_logprobs( |
| 77 | + logits: torch.Tensor, |
| 78 | + token_ids: torch.Tensor, |
| 79 | +) -> torch.Tensor: |
| 80 | + batch_size = logits.shape[0] |
| 81 | + vocab_size = logits.shape[1] |
| 82 | + token_ids = token_ids.to(torch.int64) |
| 83 | + num_logprobs = token_ids.shape[1] |
| 84 | + logprobs = torch.empty( |
| 85 | + batch_size, |
| 86 | + num_logprobs, |
| 87 | + dtype=torch.float32, |
| 88 | + device=logits.device, |
| 89 | + ) |
| 90 | + _topk_log_softmax_kernel[(batch_size,)]( |
| 91 | + logprobs, |
| 92 | + logits, |
| 93 | + logits.stride(0), |
| 94 | + token_ids, |
| 95 | + num_logprobs, |
| 96 | + vocab_size, |
| 97 | + BLOCK_SIZE=1024, # type: ignore |
| 98 | + PADDED_TOPK=triton.next_power_of_2(num_logprobs), |
| 99 | + ) |
| 100 | + return logprobs |
| 101 | + |
| 102 | + |
| 103 | +def compute_topk_logprobs( |
| 104 | + logits: torch.Tensor, |
| 105 | + num_logprobs: int, |
| 106 | + sampled_token_ids: torch.Tensor, |
| 107 | +) -> LogprobsTensors: |
| 108 | + assert num_logprobs >= 0 |
| 109 | + batch_size, vocab_size = logits.shape |
| 110 | + if num_logprobs == 0: |
| 111 | + logprob_token_ids = sampled_token_ids.unsqueeze(-1) |
| 112 | + else: |
| 113 | + topk_indices = torch.topk(logits, num_logprobs, dim=-1).indices |
| 114 | + logprob_token_ids = torch.cat( |
| 115 | + (sampled_token_ids.unsqueeze(-1), topk_indices), dim=1 |
| 116 | + ) |
| 117 | + |
| 118 | + # NOTE(woosuk): Here, to save GPU memory, we do not materialize the full |
| 119 | + # logprobs tensor. Instead, we only compute and return the logprobs of |
| 120 | + # the topk + 1 tokens. |
| 121 | + logprobs = compute_token_logprobs(logits, logprob_token_ids) |
| 122 | + token_ranks = torch.empty( |
| 123 | + batch_size, |
| 124 | + dtype=torch.int64, |
| 125 | + device=logits.device, |
| 126 | + ) |
| 127 | + _ranks_kernel[(batch_size,)]( |
| 128 | + token_ranks, |
| 129 | + logits, |
| 130 | + logits.stride(0), |
| 131 | + sampled_token_ids, |
| 132 | + vocab_size, |
| 133 | + BLOCK_SIZE=8192, # type: ignore |
| 134 | + ) |
| 135 | + return LogprobsTensors( |
| 136 | + logprob_token_ids=logprob_token_ids, |
| 137 | + logprobs=logprobs, |
| 138 | + selected_token_ranks=token_ranks, |
| 139 | + ) |
| 140 | + |
| 141 | + |
| 142 | +def compute_prompt_logprobs( |
| 143 | + prompt_token_ids: torch.Tensor, |
| 144 | + prompt_hidden_states: torch.Tensor, |
| 145 | + logits_fn: Callable[[torch.Tensor], torch.Tensor], |
| 146 | +) -> tuple[torch.Tensor, torch.Tensor]: |
| 147 | + # Since materializing the full prompt logits can take too much memory, |
| 148 | + # we compute it in chunks. |
| 149 | + CHUNK_SIZE = 1024 |
| 150 | + logprobs = [] |
| 151 | + ranks = [] |
| 152 | + prompt_token_ids = prompt_token_ids.to(torch.int64) |
| 153 | + for start_idx in range(0, prompt_token_ids.shape[0], CHUNK_SIZE): |
| 154 | + end_idx = start_idx + CHUNK_SIZE |
| 155 | + # NOTE(woosuk): logits_fn can be slow because it involves all-gather. |
| 156 | + prompt_logits = logits_fn(prompt_hidden_states[start_idx:end_idx]) |
| 157 | + prompt_logprobs = compute_topk_logprobs( |
| 158 | + prompt_logits, |
| 159 | + 0, # num_logprobs |
| 160 | + prompt_token_ids[start_idx:end_idx], |
| 161 | + ) |
| 162 | + logprobs.append(prompt_logprobs.logprobs) |
| 163 | + ranks.append(prompt_logprobs.selected_token_ranks) |
| 164 | + |
| 165 | + logprobs = torch.cat(logprobs, dim=0) if len(logprobs) > 1 else logprobs[0] |
| 166 | + ranks = torch.cat(ranks, dim=0) if len(ranks) > 1 else ranks[0] |
| 167 | + return logprobs, ranks |
0 commit comments