Skip to content

[V1 Scheduler] BatchScheduler to balance token-based microbatches and reduce GPU pipeline bubbles #19873

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

juncheoll
Copy link
Contributor

@juncheoll juncheoll commented Jun 19, 2025

Purpose

This PR introduces a new BatchScheduler class to improve scheduling efficiency under pipeline parallelism by reducing microbatch imbalance and pipeline bubbles.

In the existing Scheduler.schedule() implementation, requests are scheduled in large chunks up to the token budget in a single step. This causes front-loaded scheduling where most requests are packed into the first few microbatches, leading to:

  • Imbalanced workload across pipeline stages
  • Underutilized microbatch slots in later stages
  • Idle GPU time (bubbles) between batches

What this PR adds:

  • BatchScheduler, a new scheduler implementation with two-phase scheduling:
    • Phase 1: All schedulable requests are added to a scheduled_queue
    • Phase 2: The queue is drained in balanced microbatches according to an estimated token budget
  • A new method _estimate_token_budget() that dynamically computes per-step token limits based on current running token load and pipeline parallel configuration
  • Seamless integration with step_with_batch_queue() and existing batch queue logic
  • use_batch_scheduler option that activates this scheduling mode

Test Plan

Run on a 4-node LAN environment with the following configuration:

Test Environment:

  • 4 nodes over LAN, each with an RTX 4070 GPU
  • Distributed backend: ray
  • Model: meta-llama/Llama-2-13b-chat-hf
  • Scheduler variants tested: default scheduler vs BatchScheduler

Requests configuration:

  • num-requests = 256
  • Each request: input_tokens = 128, output_tokens = 128

Used options:

  --model meta-llama/Llama-2-13b-chat-hf \
  --pipeline-parallel-size 4 \
  --distributed-executor-backend ray \
  --use-batch-scheduler

Test Result

Scheduler Type Total Time to Complete 256 Requests
Default Scheduler 146.24 seconds
BatchScheduler 70.19 seconds

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @juncheoll, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new BatchScheduler designed to significantly enhance scheduling efficiency, particularly in environments utilizing pipeline parallelism. By implementing a two-phase scheduling approach and dynamically estimating token budgets, this new scheduler aims to mitigate microbatch imbalance and reduce idle GPU time, commonly known as pipeline bubbles. The result is a substantial improvement in overall request completion time, as demonstrated by the provided test results.

Highlights

  • New BatchScheduler Implementation: A dedicated BatchScheduler class has been introduced, inheriting from the base Scheduler, specifically designed to optimize request scheduling for pipeline parallelism and reduce GPU pipeline bubbles.
  • Two-Phase Scheduling Mechanism: The scheduling process is refined into two distinct phases: first, all schedulable requests are collected into an internal scheduled_queue, and then they are drained in balanced microbatches based on a dynamically estimated token budget.
  • Dynamic Token Budget Estimation: A new private method, _estimate_token_budget(), has been added to dynamically calculate the per-step token limit. This ensures a more balanced workload distribution across pipeline stages, minimizing idle GPU time.
  • Significant Performance Improvement: Benchmarks demonstrate that the BatchScheduler dramatically reduces the total time to complete 256 requests from 146.24 seconds (with the default scheduler) to 70.19 seconds, representing a speedup of over 2x.
  • Configurable Activation: A new use_batch_scheduler boolean option has been added to SchedulerConfig and exposed via EngineArgs and the command-line interface, allowing users to easily enable or disable this new scheduling mode.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@mergify mergify bot added the v1 label Jun 19, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new BatchScheduler designed to improve scheduling efficiency in pipeline parallelism by balancing microbatches. The approach involves a two-phase scheduling process: first adding all schedulable requests to an intermediate queue, and then draining this queue based on an estimated token budget. The test results demonstrate a significant performance improvement. The changes involve adding a new configuration option, updating argument parsing, defining a new data structure for scheduled requests, and implementing the new scheduler logic. While the core concept and performance gains are promising, the implementation of the schedule method in the BatchScheduler involves significant code duplication from the base class, which impacts maintainability. Additionally, there's a potential minor issue in how the token budget is checked during the queue draining phase.

Comment on lines +1089 to +1420
num_new_tokens = (
self.scheduler_config.long_prefill_token_threshold)

assert num_new_tokens > 0

# Schedule encoder inputs.
if request.has_encoder_inputs:
(encoder_inputs_to_schedule, num_new_tokens,
new_encoder_budget
) = self._try_schedule_encoder_inputs(
request, num_computed_tokens, num_new_tokens,
encoder_budget)
if num_new_tokens == 0:
# The request cannot be scheduled.
break

new_blocks = self.kv_cache_manager.allocate_slots(
request,
num_new_tokens + num_external_computed_tokens,
num_new_local_computed_tokens,
new_computed_blocks,
num_lookahead_tokens=self.num_lookahead_tokens,
delay_cache_blocks=load_kv_async,
)
if new_blocks is None:
# The request cannot be scheduled.
break

# KVTransfer: the connector uses this info to determine
# if a load is needed. Note that
# This information is used to determine if a load is
# needed for this request.
if self.connector is not None:
self.connector.update_state_after_alloc(
request,
new_computed_blocks + new_blocks,
num_external_computed_tokens,
)

self.waiting.popleft()
if load_kv_async:
# If loading async, allocate memory and put request
# into the WAITING_FOR_REMOTE_KV state.
skipped_waiting_requests.appendleft(request)
request.status = RequestStatus.WAITING_FOR_REMOTE_KVS
continue

req_index += 1
self.running.append(request)
if self.log_stats:
request.record_event(EngineCoreEventType.SCHEDULED,
scheduled_timestamp)

if self.lora_config and request.lora_request:
scheduled_loras.add(request.lora_request.lora_int_id)
request.num_computed_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
# Encoder-related.
if encoder_inputs_to_schedule:
# Allocate the encoder cache.
for i in encoder_inputs_to_schedule:
self.encoder_cache_manager.allocate(request, i)
encoder_budget = new_encoder_budget

new_block_ids = self.kv_cache_manager.get_block_ids(
request.request_id)

request_data: Union[NewRequestData, CachedRequestData]
if request.status == RequestStatus.WAITING:
request_data = NewRequestData.from_request(
request, new_block_ids)
elif request.status == RequestStatus.PREEMPTED:
request_data = self._make_cached_request_data(
request,
num_new_tokens,
len(request.spec_token_ids),
new_block_ids,
resumed_from_preemption=True,
)
else:
raise ValueError(
f"Invalid request status: {request.status}")

request.status = RequestStatus.RUNNING
self.scheduled_queue.append(
ScheduledRequest(
request.request_id,
num_new_tokens,
encoder_inputs_to_schedule,
0,
request.spec_token_ids.copy(),
request_data,
))

self.requests[
request.request_id].num_computed_tokens += num_new_tokens
self.running_tokens += num_new_tokens

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The schedule method in BatchScheduler duplicates a significant amount of logic from the base Scheduler class, particularly the loops for scheduling RUNNING and WAITING requests. This leads to code duplication and makes the code harder to maintain. Consider refactoring the common logic into a shared helper method or restructuring the classes to avoid this duplication.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice to have integration with the existing Scheduler, but I'm not yet sure that BatchScheduler is compatible with all vllm features and performs well, so I included an optional option.

spec_token_ids = scheduled_request.spec_token_ids
request_data = scheduled_request.request_data

if num_scheduled_tokens and token_budget < num_new_tokens:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The condition if num_scheduled_tokens and token_budget < num_new_tokens: for breaking the draining loop is potentially incorrect. The check token_budget < num_new_tokens should be performed even if num_scheduled_tokens is empty (i.e., for the very first request being considered from the scheduled_queue). If the first request's token count exceeds the budget, no requests should be scheduled in this batch.

Suggested change
if num_scheduled_tokens and token_budget < num_new_tokens:
if token_budget < num_new_tokens:

@juncheoll juncheoll changed the title BatchScheduler to balance token-based microbatches and reduce GPU pipeline bubbles [V1 Scheduler] BatchScheduler to balance token-based microbatches and reduce GPU pipeline bubbles Jun 19, 2025
juncheoll and others added 3 commits June 20, 2025 09:22
Reformatting

Reformatting

Reformatting

Reformatting

Reformatting

Reformatting

Signed-off-by: juncheoll <[email protected]>
@juncheoll juncheoll requested a review from hmellor as a code owner June 20, 2025 00:22
@mergify mergify bot added the documentation Improvements or additions to documentation label Jun 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
documentation Improvements or additions to documentation v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants