Skip to content

Commit 367adf9

Browse files
committed
Add BatchScheduler to prevent microbatch imbalance and GPU bubbles
Reformatting
1 parent 4719460 commit 367adf9

File tree

4 files changed

+526
-2
lines changed

4 files changed

+526
-2
lines changed

vllm/config.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,6 +2119,20 @@ class SchedulerConfig:
21192119
default scheduler. Can be a class directly or the path to a class of form
21202120
"mod.custom_class"."""
21212121

2122+
use_batch_scheduler: bool = False
2123+
"""Whether to use the BatchScheduler instead of the default scheduler.
2124+
2125+
If set to True, the engine will use
2126+
"vllm.v1.core.sched.scheduler.BatchScheduler" as the scheduler class unless
2127+
a custom `scheduler_cls` is explicitly provided.
2128+
2129+
If both `use_batch_scheduler=True` and a non-default `scheduler_cls` are
2130+
specified, the `scheduler_cls` will take precedence and
2131+
`use_batch_scheduler` will be ignored.
2132+
2133+
Default is False.
2134+
"""
2135+
21222136
disable_hybrid_kv_cache_manager: bool = False
21232137
"""If set to True, KV cache manager will allocate the same size of KV cache
21242138
for all attention layers even if there are multiple type of attention layers

vllm/engine/arg_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ class EngineArgs:
411411
disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
412412
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
413413
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
414+
use_batch_scheduler: bool = SchedulerConfig.use_batch_scheduler
414415

415416
override_neuron_config: dict[str, Any] = \
416417
get_field(ModelConfig, "override_neuron_config")
@@ -855,6 +856,8 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
855856
**scheduler_kwargs["disable_chunked_mm_input"])
856857
scheduler_group.add_argument("--scheduler-cls",
857858
**scheduler_kwargs["scheduler_cls"])
859+
scheduler_group.add_argument("--use-batch-scheduler",
860+
**scheduler_kwargs["use_batch_scheduler"])
858861
scheduler_group.add_argument(
859862
"--disable-hybrid-kv-cache-manager",
860863
**scheduler_kwargs["disable_hybrid_kv_cache_manager"])
@@ -1182,6 +1185,7 @@ def create_engine_config(
11821185
and parallel_config.use_ray),
11831186
policy=self.scheduling_policy,
11841187
scheduler_cls=self.scheduler_cls,
1188+
use_batch_scheduler=self.use_batch_scheduler,
11851189
max_num_partial_prefills=self.max_num_partial_prefills,
11861190
max_long_partial_prefills=self.max_long_partial_prefills,
11871191
long_prefill_token_threshold=self.long_prefill_token_threshold,
@@ -1550,6 +1554,18 @@ def _set_default_args_v1(self, usage_context: UsageContext,
15501554
if not self.enable_chunked_prefill:
15511555
self.max_num_batched_tokens = model_config.max_model_len
15521556

1557+
if self.use_batch_scheduler:
1558+
if self.scheduler_cls == EngineArgs.scheduler_cls:
1559+
self.scheduler_cls = \
1560+
"vllm.v1.core.sched.scheduler.BatchScheduler"
1561+
else:
1562+
logger.warning(
1563+
"use_batch_scheduler is set to True, "
1564+
"but a custom scheduler_cls is also provided. "
1565+
"The specified scheduler_cls (%s) will take precedence, "
1566+
"and use_batch_scheduler will be ignored.",
1567+
self.scheduler_cls)
1568+
15531569
# V1 should use the new scheduler by default.
15541570
# Swap it only if this arg is set to the original V0 default
15551571
if self.scheduler_cls == EngineArgs.scheduler_cls:

vllm/v1/core/sched/output.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from __future__ import annotations
55

66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING, Optional
7+
from typing import TYPE_CHECKING, Optional, Union
88

99
if TYPE_CHECKING:
1010
import numpy as np
@@ -155,3 +155,13 @@ class SchedulerOutput:
155155

156156
# KV Cache Connector metadata.
157157
kv_connector_metadata: Optional[KVConnectorMetadata] = None
158+
159+
160+
@dataclass
class ScheduledRequest:
    """A single request scheduled into the current batch.

    Bundles the per-step scheduling decision for one request together with
    the request payload that the model runner needs (either the full data
    for a request new to the runner, or the delta for a cached one).

    NOTE: the module uses ``from __future__ import annotations``, so the
    ``X | None`` / ``A | B`` union syntax is used uniformly here instead of
    mixing it with ``typing.Union`` as the surrounding fields do elsewhere.
    """

    # Unique identifier of the request being scheduled.
    request_id: str
    # Number of new (non-speculative) tokens scheduled for this step.
    num_new_tokens: int
    # Indices of encoder inputs to run this step, or None if there are none.
    encoder_inputs_to_schedule: list[int] | None
    # How many speculative tokens were scheduled for this step.
    num_scheduled_spec_tokens: int
    # The speculative token ids, or None when speculation is not in use.
    spec_token_ids: list[int] | None
    # Payload for the model runner: full data for a new request, or the
    # incremental update for a request it has already seen.
    request_data: NewRequestData | CachedRequestData

0 commit comments

Comments
 (0)