Skip to content

Commit 6c7451c

Browse files
Merge pull request #11 from character-tech/tanuj/accum
add support for accumulate in vllm
2 parents 0642536 + 4f19a06 commit 6c7451c

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

vllm/entrypoints/openai/protocol.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,12 @@ class CompletionRequest(OpenAIBaseModel):
855855
" as strings of the form 'token_id:{token_id}' so that tokens "
856856
"that are not JSON-encodable can be identified."))
857857

858+
accumulate: Optional[bool] = Field(
859+
default=None,
860+
description=(
861+
"Special kind of echo where in the response instead of delta we return the accumulated text"
862+
)
863+
)
858864
# doc: end-completion-extra-params
859865

860866
# Default sampling parameters for completion requests

vllm/entrypoints/openai/serving_completion.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,9 @@ async def completion_stream_generator(
262262
previous_num_tokens = [0] * num_choices * num_prompts
263263
has_echoed = [False] * num_choices * num_prompts
264264
num_prompt_tokens = [0] * num_prompts
265+
accumulated_text = [""] * num_choices * num_prompts
266+
accumulated_tokens = [[] * num_choices * num_prompts]
267+
accumulated_logprobs = [[] * num_choices * num_prompts]
265268

266269
stream_options = request.stream_options
267270
if stream_options:
@@ -309,6 +312,16 @@ async def completion_stream_generator(
309312
*(output.logprobs or []),
310313
]
311314
has_echoed[i] = True
315+
elif request.accumulate:
316+
i = output.index + prompt_idx * num_choices
317+
# return the accumulated response
318+
accumulated_text[i] += output.text
319+
accumulated_tokens[i].extend(output.token_ids)
320+
accumulated_logprobs[i].extend(output.logprobs or [])
321+
322+
delta_text = accumulated_text[i]
323+
delta_token_ids = accumulated_tokens[i]
324+
out_logprobs = accumulated_logprobs[i]
312325
else:
313326
# return just the delta
314327
delta_text = output.text

0 commit comments

Comments
 (0)