middleware to track server load #18613

Draft · wants to merge 1 commit into base: main

14 changes: 3 additions & 11 deletions vllm/entrypoints/openai/api_server.py
@@ -89,7 +89,7 @@
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription)
 from vllm.entrypoints.openai.tool_parsers import ToolParserManager
-from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
+from vllm.entrypoints.utils import (LoadTrackingMiddleware, cli_env_setup,
                                     with_cancellation)
 from vllm.logger import init_logger
 from vllm.reasoning import ReasoningParserManager
@@ -548,7 +548,6 @@ async def show_version():
 }
 })
 @with_cancellation
-@load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest,
                                  raw_request: Request):
     handler = chat(raw_request)
@@ -587,7 +586,6 @@ async def create_chat_completion(request: ChatCompletionRequest,
 },
 })
 @with_cancellation
-@load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
     handler = completion(raw_request)
     if handler is None:
@@ -623,7 +621,6 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
 },
 })
 @with_cancellation
-@load_aware_call
 async def create_embedding(request: EmbeddingRequest, raw_request: Request):
     handler = embedding(raw_request)
     if handler is None:
@@ -679,7 +676,6 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request):
 },
 })
 @with_cancellation
-@load_aware_call
 async def create_pooling(request: PoolingRequest, raw_request: Request):
     handler = pooling(raw_request)
     if handler is None:
@@ -698,7 +694,6 @@ async def create_pooling(request: PoolingRequest, raw_request: Request):

 @router.post("/classify", dependencies=[Depends(validate_json_request)])
 @with_cancellation
-@load_aware_call
 async def create_classify(request: ClassificationRequest,
                           raw_request: Request):
     handler = classify(raw_request)
@@ -728,7 +723,6 @@ async def create_classify(request: ClassificationRequest,
 },
 })
 @with_cancellation
-@load_aware_call
 async def create_score(request: ScoreRequest, raw_request: Request):
     handler = score(raw_request)
     if handler is None:
@@ -756,7 +750,6 @@ async def create_score(request: ScoreRequest, raw_request: Request):
 },
 })
 @with_cancellation
-@load_aware_call
 async def create_score_v1(request: ScoreRequest, raw_request: Request):
     logger.warning(
         "To indicate that Score API is not part of standard OpenAI API, we "
@@ -783,7 +776,6 @@ async def create_score_v1(request: ScoreRequest, raw_request: Request):
 },
 })
 @with_cancellation
-@load_aware_call
 async def create_transcriptions(raw_request: Request,
                                 request: Annotated[TranscriptionRequest,
                                                    Form()]):
@@ -817,7 +809,6 @@ async def create_transcriptions(raw_request: Request,
 },
 })
 @with_cancellation
-@load_aware_call
 async def do_rerank(request: RerankRequest, raw_request: Request):
     handler = rerank(raw_request)
     if handler is None:
@@ -1135,6 +1126,8 @@ async def log_response(request: Request, call_next):
         else:
             raise ValueError(f"Invalid middleware {middleware}. "
                              f"Must be a function or a class.")
+    if args.enable_server_load_tracking:
+        app.add_middleware(LoadTrackingMiddleware)

     return app

@@ -1268,7 +1261,6 @@ async def init_app_state(
     ) if model_config.runner_type == "transcription" else None
     state.task = model_config.task

-    state.enable_server_load_tracking = args.enable_server_load_tracking
     state.server_load_metrics = 0


73 changes: 33 additions & 40 deletions vllm/entrypoints/utils.py
@@ -6,8 +6,7 @@
 from typing import Any, Optional

 from fastapi import Request
-from fastapi.responses import JSONResponse, StreamingResponse
-from starlette.background import BackgroundTask, BackgroundTasks
+from starlette.types import ASGIApp, Receive, Scope, Send

 from vllm.logger import init_logger

@@ -67,54 +66,48 @@ async def wrapper(*args, **kwargs):
     return wrapper


-def decrement_server_load(request: Request):
-    request.app.state.server_load_metrics -= 1
-
-
-def load_aware_call(func):
-
-    @functools.wraps(func)
-    async def wrapper(*args, **kwargs):
-        raw_request = kwargs.get("raw_request",
-                                 args[1] if len(args) > 1 else None)
-
-        if raw_request is None:
-            raise ValueError(
-                "raw_request required when server load tracking is enabled")
-
-        if not raw_request.app.state.enable_server_load_tracking:
-            return await func(*args, **kwargs)
-
-        raw_request.app.state.server_load_metrics += 1
-        try:
-            response = await func(*args, **kwargs)
-        except Exception:
-            raw_request.app.state.server_load_metrics -= 1
-            raise
-
-        if isinstance(response, (JSONResponse, StreamingResponse)):
-            if response.background is None:
-                response.background = BackgroundTask(decrement_server_load,
-                                                     raw_request)
-            elif isinstance(response.background, BackgroundTasks):
-                response.background.add_task(decrement_server_load,
-                                             raw_request)
-            elif isinstance(response.background, BackgroundTask):
-                # Convert the single BackgroundTask to BackgroundTasks
-                # and chain the decrement_server_load task to it
-                tasks = BackgroundTasks()
-                tasks.add_task(response.background.func,
-                               *response.background.args,
-                               **response.background.kwargs)
-                tasks.add_task(decrement_server_load, raw_request)
-                response.background = tasks
-        else:
-            raw_request.app.state.server_load_metrics -= 1
-
-        return response
-
-    return wrapper
+class LoadTrackingMiddleware:
+
+    def __init__(self, app: ASGIApp):
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send):
+        if scope["type"] != "http":
+            return await self.app(scope, receive, send)
+
+        path = scope.get("path", "")
+        if not path.startswith("/v1"):
+            return await self.app(scope, receive, send)
+
+        state = scope["app"].state
+
+        state.server_load_metrics += 1
+        done = False
+
+        async def send_wrapper(message):
+            nonlocal done
+            if (message["type"] == "http.response.body"
+                    and not message.get("more_body", False) and not done):
+                state.server_load_metrics -= 1
+                done = True
+            await send(message)
+
+        async def receive_wrapper():
+            nonlocal done
+            msg = await receive()
+            if msg["type"] == "http.disconnect" and not done:
+                state.server_load_metrics -= 1
+                done = True
+            return msg
+
+        try:
+            await self.app(scope, receive_wrapper, send_wrapper)
+        except Exception:
+            if not done:
+                state.server_load_metrics -= 1
+                done = True
+            raise


 def cli_env_setup():
     # The safest multiprocessing method is `spawn`, as the default `fork` method
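
For reviewers who want to try the counting pattern outside vLLM, below is a minimal, self-contained sketch of the same idea: a pure-ASGI middleware that increments a counter on app.state when an HTTP request under /v1 starts and decrements it when the request finishes. This is not code from the PR: the class name, the /v1/ping and /load routes, and the simplified try/finally bookkeeping are assumptions for illustration. The middleware in the diff above instead wraps receive/send so the counter is also released on client disconnect and after the final body chunk of a streaming response.

# Illustrative sketch only; not part of this PR.
from fastapi import FastAPI
from starlette.types import ASGIApp, Receive, Scope, Send


class InFlightCounterMiddleware:
    """Count in-flight /v1 requests on app.state.server_load_metrics."""

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        # Only HTTP requests under /v1 are counted; everything else passes through.
        if scope["type"] != "http" or not scope.get("path", "").startswith("/v1"):
            return await self.app(scope, receive, send)

        state = scope["app"].state
        state.server_load_metrics += 1
        try:
            await self.app(scope, receive, send)
        finally:
            # Simplified bookkeeping: decrement once the downstream app returns.
            state.server_load_metrics -= 1


app = FastAPI()
app.state.server_load_metrics = 0
app.add_middleware(InFlightCounterMiddleware)


@app.get("/v1/ping")
async def ping():
    return {"ok": True}


@app.get("/load")
async def load():
    # Snapshot of how many /v1 requests are currently being served.
    return {"server_load": app.state.server_load_metrics}

Running this with uvicorn and polling /load while issuing concurrent /v1/ping requests should show the counter rise and fall, which is the same signal the PR exposes through app.state.server_load_metrics.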