From 67b7621312e1c22181b12f008136d0d9079c225f Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Thu, 25 Jun 2026 08:47:38 +0000 Subject: [PATCH] feat(serve): add --generation-config CLI for server sampling defaults Align api_server with vLLM by loading HuggingFace generation_config.json as default sampling params, with optional override and lmdeploy fallback. Co-authored-by: Cursor --- lmdeploy/cli/serve.py | 6 + lmdeploy/cli/utils.py | 26 +++ lmdeploy/serve/anthropic/adapter.py | 20 +- .../serve/anthropic/endpoints/messages.py | 6 +- lmdeploy/serve/anthropic/protocol.py | 2 +- lmdeploy/serve/core/generation_config.py | 173 ++++++++++++++++++ lmdeploy/serve/openai/api_server.py | 66 ++++--- lmdeploy/serve/openai/protocol.py | 26 +-- lmdeploy/serve/openai/responses/protocol.py | 4 +- lmdeploy/serve/openai/responses/request.py | 21 ++- lmdeploy/serve/openai/responses/serving.py | 6 +- .../serve/openai/serving_chat_completion.py | 28 ++- lmdeploy/serve/openai/serving_completion.py | 28 ++- .../serve/test_generation_config.py | 103 +++++++++++ 14 files changed, 443 insertions(+), 72 deletions(-) create mode 100644 lmdeploy/serve/core/generation_config.py create mode 100644 tests/test_lmdeploy/serve/test_generation_config.py diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py index 01ac1d44f1..5d8e0174b6 100644 --- a/lmdeploy/cli/serve.py +++ b/lmdeploy/cli/serve.py @@ -92,6 +92,8 @@ def add_parser_api_server(): # model args ArgumentHelper.revision(parser) ArgumentHelper.download_dir(parser) + ArgumentHelper.generation_config(parser) + ArgumentHelper.override_generation_config(parser) # pytorch engine args pt_group = parser.add_argument_group('PyTorch engine arguments') @@ -318,6 +320,8 @@ def api_server(args): reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, + generation_config=args.generation_config, + override_generation_config=args.override_generation_config, ) else: from lmdeploy.serve.openai.launch_server import launch_server @@ -350,6 +354,8 @@ def api_server(args): reasoning_parser=args.reasoning_parser, tool_call_parser=args.tool_call_parser, speculative_config=speculative_config, + generation_config=args.generation_config, + override_generation_config=args.override_generation_config, ) @staticmethod diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 1dda62a7e8..810f54524e 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -301,6 +301,32 @@ def hf_overrides(parser): default=None, help='Extra arguments to be forwarded to the HuggingFace config.') + @staticmethod + def generation_config(parser): + """Add argument generation_config to parser.""" + return parser.add_argument( + '--generation-config', + type=str, + default='auto', + help='The folder path to the generation config. Defaults to "auto", the ' + 'generation config will be loaded from model path. If set to "lmdeploy", no ' + 'generation config is loaded, lmdeploy defaults will be used. If set to a folder ' + 'path, the generation config will be loaded from the specified folder path. ' + 'If max_new_tokens is specified in generation config, then it sets a ' + 'server-wide limit on the number of output tokens for all requests.') + + @staticmethod + def override_generation_config(parser): + """Add argument override_generation_config to parser.""" + return parser.add_argument( + '--override-generation-config', + type=json.loads, + default=None, + help='Overrides or sets generation config. e.g. \'{"temperature": 0.5}\'. If ' + 'used with --generation-config auto, the override parameters will be merged ' + 'with the default config from the model. If used with --generation-config ' + 'lmdeploy, only the override parameters are used.') + @staticmethod def use_logn_attn(parser): """Add argument use_logn_attn to parser.""" diff --git a/lmdeploy/serve/anthropic/adapter.py b/lmdeploy/serve/anthropic/adapter.py index 6975c7c8ad..0474ca6580 100644 --- a/lmdeploy/serve/anthropic/adapter.py +++ b/lmdeploy/serve/anthropic/adapter.py @@ -9,6 +9,7 @@ import shortuuid from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from .protocol import ( @@ -341,15 +342,18 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[ return lm_messages -def to_generation_config(request: MessagesRequest) -> GenerationConfig: +def to_generation_config( + request: MessagesRequest, + server_defaults: dict | None = None, + override_max_new_tokens: int | None = None, +) -> GenerationConfig: """Map Anthropic messages request to LMDeploy generation config.""" - - return GenerationConfig( - max_new_tokens=request.max_tokens, - do_sample=True, - top_k=40 if request.top_k is None else request.top_k, - top_p=1.0 if request.top_p is None else request.top_p, - temperature=1.0 if request.temperature is None else request.temperature, + request_values = extract_request_sampling_values(request) + return build_generation_config( + request_values, + server_defaults or {}, + max_tokens=request.max_tokens, + override_max_new_tokens=override_max_new_tokens, stop_words=request.stop_sequences, include_stop_str_in_output=request.include_stop_str_in_output or False, skip_special_tokens=True, diff --git a/lmdeploy/serve/anthropic/endpoints/messages.py b/lmdeploy/serve/anthropic/endpoints/messages.py index 1f5b17a54b..d390b6a868 100644 --- a/lmdeploy/serve/anthropic/endpoints/messages.py +++ b/lmdeploy/serve/anthropic/endpoints/messages.py @@ -174,7 +174,11 @@ async def create_message(request: MessagesRequest, raw_request: Request): result_generator = server_context.async_engine.generate( engine_messages, session, - gen_config=to_generation_config(request), + gen_config=to_generation_config( + request, + server_defaults=server_context.server_sampling_defaults, + override_max_new_tokens=server_context.override_max_new_tokens, + ), tools=parsed_request.tools, stream_response=True, sequence_start=True, diff --git a/lmdeploy/serve/anthropic/protocol.py b/lmdeploy/serve/anthropic/protocol.py index 03f0b1c37f..37fa0f7bd7 100644 --- a/lmdeploy/serve/anthropic/protocol.py +++ b/lmdeploy/serve/anthropic/protocol.py @@ -104,7 +104,7 @@ class MessagesRequest(BaseModel): system: str | list[ContentBlockParam] | None = None stop_sequences: list[str] | None = None stream: bool = False - temperature: float | None = 1.0 + temperature: float | None = None top_p: float | None = None top_k: int | None = None metadata: dict[str, Any] | None = None diff --git a/lmdeploy/serve/core/generation_config.py b/lmdeploy/serve/core/generation_config.py new file mode 100644 index 0000000000..b11cf26dbc --- /dev/null +++ b/lmdeploy/serve/core/generation_config.py @@ -0,0 +1,173 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Server-side generation config resolution and sampling parameter merge +helpers.""" + +from __future__ import annotations + +from typing import Any + +from lmdeploy.messages import GenerationConfig +from lmdeploy.utils import get_logger + +logger = get_logger('lmdeploy') + +PROTOCOL_FALLBACKS: dict[str, Any] = { + 'temperature': 0.7, + 'top_p': 1.0, + 'top_k': 40, + 'repetition_penalty': 1.0, + 'min_p': 0.0, + 'do_sample': True, +} + +SAMPLING_PARAM_KEYS = ( + 'temperature', + 'top_p', + 'top_k', + 'min_p', + 'repetition_penalty', + 'max_new_tokens', + 'do_sample', +) + +REQUEST_SAMPLING_FIELDS = ( + 'temperature', + 'top_p', + 'top_k', + 'min_p', + 'repetition_penalty', +) + + +def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]: + from transformers import GenerationConfig + + try: + cfg = GenerationConfig.from_pretrained(path, trust_remote_code=trust_remote_code) + return cfg.to_diff_dict() + except OSError: + return {} + + +def extract_sampling_params(config: dict[str, Any]) -> dict[str, Any]: + """Extract supported sampling parameters from a generation config dict.""" + return {key: config[key] for key in SAMPLING_PARAM_KEYS if key in config and config[key] is not None} + + +def resolve_server_sampling_defaults( + generation_config: str, + override: dict[str, Any] | None, + model_path: str, + trust_remote_code: bool, +) -> tuple[dict[str, Any], int | None]: + """Resolve server-side default sampling params from CLI flags. + + Returns: + A tuple of (sampling_defaults, override_max_new_tokens). + ``override_max_new_tokens`` is a server-wide cap/default when set. + """ + override = override or {} + src = generation_config + + if src == 'lmdeploy': + config: dict[str, Any] = {} + elif src == 'auto': + config = _load_hf_generation_config(model_path, trust_remote_code) + else: + config = _load_hf_generation_config(src, trust_remote_code) + + config.update(override) + sampling = extract_sampling_params(config) + + override_max_new_tokens = sampling.pop('max_new_tokens', None) + if override_max_new_tokens is not None: + override_max_new_tokens = int(override_max_new_tokens) + + if sampling and src != 'lmdeploy': + source = "the model's `generation_config.json`" if src == 'auto' else src + logger.info( + 'Using default sampling params from %s: %s. ' + 'Use `--generation-config lmdeploy` to disable.', + source, + sampling, + ) + elif sampling and override: + logger.info('Using override generation config sampling params: %s.', sampling) + + return sampling, override_max_new_tokens + + +def merge_sampling_params( + request_values: dict[str, Any], + server_defaults: dict[str, Any], + fallbacks: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Merge sampling params with request > server > protocol fallback + priority.""" + fallbacks = fallbacks or PROTOCOL_FALLBACKS + merged: dict[str, Any] = {} + all_keys = set(fallbacks) | set(server_defaults) | set(request_values) + for key in all_keys: + if key in request_values: + merged[key] = request_values[key] + elif key in server_defaults: + merged[key] = server_defaults[key] + elif key in fallbacks: + merged[key] = fallbacks[key] + return merged + + +def extract_request_sampling_values(request: Any) -> dict[str, Any]: + """Extract explicitly provided sampling fields from a request object.""" + values: dict[str, Any] = {} + for field in REQUEST_SAMPLING_FIELDS: + if not hasattr(request, field): + continue + value = getattr(request, field) + if value is not None: + values[field] = value + return values + + +def resolve_max_new_tokens( + max_completion_tokens: int | None, + max_tokens: int | None, + server_cap: int | None, +) -> int | None: + """Resolve output token limit with optional server-wide cap/default.""" + request_value = max_completion_tokens if max_completion_tokens is not None else max_tokens + if request_value is None: + return server_cap + if server_cap is not None: + return min(request_value, server_cap) + return request_value + + +def build_generation_config( + request_values: dict[str, Any], + server_defaults: dict[str, Any], + *, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + override_max_new_tokens: int | None = None, + fallbacks: dict[str, Any] | None = None, + **extra_kwargs: Any, +) -> GenerationConfig: + """Build ``GenerationConfig`` from merged sampling defaults and request + values.""" + merged = merge_sampling_params(request_values, server_defaults, fallbacks) + max_new_tokens = resolve_max_new_tokens( + max_completion_tokens, + max_tokens, + override_max_new_tokens, + ) + return GenerationConfig( + max_new_tokens=max_new_tokens, + do_sample=merged.get('do_sample', PROTOCOL_FALLBACKS['do_sample']), + top_k=merged['top_k'], + top_p=merged['top_p'], + temperature=merged['temperature'], + repetition_penalty=merged['repetition_penalty'], + min_p=merged['min_p'], + **extra_kwargs, + ) diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py index e8a3ca6ecc..d39150487c 100644 --- a/lmdeploy/serve/openai/api_server.py +++ b/lmdeploy/serve/openai/api_server.py @@ -47,6 +47,11 @@ ) from lmdeploy.serve.anthropic import create_anthropic_router from lmdeploy.serve.core import AsyncEngine, EngineHealthMonitor +from lmdeploy.serve.core.generation_config import ( + build_generation_config, + extract_request_sampling_values, + resolve_server_sampling_defaults, +) from lmdeploy.serve.openai.protocol import ( AbortRequest, ChatCompletionRequest, @@ -110,6 +115,8 @@ class VariableInterface: allow_terminate_by_client: bool = False enable_abort_handling: bool = False response_parser_cls: type[ResponseParser] | None = None + server_sampling_defaults: dict = {} + override_max_new_tokens: int | None = None @classmethod def create_session(cls, user_session_id: int | None = None) -> Session: @@ -147,6 +154,23 @@ def get_engine_config(cls): return cls.async_engine.backend_config +def _build_serving_generation_config(request, **extra_kwargs) -> GenerationConfig: + """Build ``GenerationConfig`` with server and request sampling merge.""" + request_values = extract_request_sampling_values(request) + max_completion_tokens = getattr(request, 'max_completion_tokens', None) + max_tokens = getattr(request, 'max_tokens', None) + if max_completion_tokens is None and hasattr(request, 'max_output_tokens'): + max_completion_tokens = getattr(request, 'max_output_tokens', None) + return build_generation_config( + request_values, + VariableInterface.server_sampling_defaults, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + override_max_new_tokens=VariableInterface.override_max_new_tokens, + **extra_kwargs, + ) + + async def _with_request_cleanup(generator, result_generators, sessions): """Yield from an API generator and cleanup when the HTTP task exits.""" session_mgr = VariableInterface.get_session_manager() @@ -486,14 +510,9 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque # (e.g. GPT-OSS clears response_format and injects the schema into messages) request = response_parser.request - gen_config = GenerationConfig( - max_new_tokens=request.max_completion_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=gen_logprobs, - top_k=request.top_k, - top_p=request.top_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, include_stop_str_in_output=request.include_stop_str_in_output, @@ -501,7 +520,6 @@ async def chat_completions_v1(request: ChatCompletionRequest, raw_request: Reque response_format=request.response_format, logits_processors=logits_processors, min_new_tokens=request.min_new_tokens, - min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, @@ -828,20 +846,13 @@ async def completions_v1(request: CompletionRequest, raw_request: Request = None if isinstance(request.stop, str): request.stop = [request.stop] random_seed = request.seed if request.seed is not None else None - max_new_tokens = (request.max_completion_tokens if request.max_completion_tokens else request.max_tokens) - gen_config = GenerationConfig( - max_new_tokens=max_new_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=request.logprobs, - top_k=request.top_k, - top_p=request.top_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, skip_special_tokens=request.skip_special_tokens, - min_p=request.min_p, random_seed=random_seed, spaces_between_special_tokens=request.spaces_between_special_tokens, migration_request=migration_request, @@ -1012,15 +1023,9 @@ async def generate(request: GenerateReqInput, raw_request: Request = None): prompt = [dict(role='user', content=[text_input] + image_input)] input_ids = None - gen_config = GenerationConfig( - max_new_tokens=request.max_tokens, - do_sample=True, + gen_config = _build_serving_generation_config( + request, logprobs=1 if request.return_logprob else None, - top_k=request.top_k, - top_p=request.top_p, - min_p=request.min_p, - temperature=request.temperature, - repetition_penalty=request.repetition_penalty, ignore_eos=request.ignore_eos, stop_words=request.stop, stop_token_ids=request.stop_token_ids, @@ -1546,6 +1551,8 @@ def serve(model_path: str, allow_terminate_by_client: bool = False, enable_abort_handling: bool = False, speculative_config: SpeculativeConfig | None = None, + generation_config: str = 'auto', + override_generation_config: dict | None = None, **kwargs): """An example to perform model inference through the command line interface. @@ -1606,6 +1613,15 @@ def serve(model_path: str, VariableInterface.allow_terminate_by_client = allow_terminate_by_client VariableInterface.enable_abort_handling = enable_abort_handling + server_defaults, override_max_new_tokens = resolve_server_sampling_defaults( + generation_config, + override_generation_config, + model_path, + trust_remote_code, + ) + VariableInterface.server_sampling_defaults = server_defaults + VariableInterface.override_max_new_tokens = override_max_new_tokens + ssl_keyfile, ssl_certfile, http_or_https = None, None, 'http' if ssl: ssl_keyfile = os.environ['SSL_KEYFILE'] diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py index 675cbf5103..40162faf15 100644 --- a/lmdeploy/serve/openai/protocol.py +++ b/lmdeploy/serve/openai/protocol.py @@ -144,8 +144,8 @@ class ChatCompletionRequest(BaseModel): model: str messages: str | list[dict[str, Any]] = Field(examples=[[{'role': 'user', 'content': 'hi'}]]) - temperature: float | None = 0.7 - top_p: float | None = 1.0 + temperature: float | None = None + top_p: float | None = None tools: list[Tool] | None = Field(default=None, examples=[None]) tool_choice: ToolChoice | AllowedToolChoice | Literal[ 'auto', 'required', 'none'] = Field(default='auto', examples=['none']) @@ -176,17 +176,17 @@ class ChatCompletionRequest(BaseModel): response_format: ResponseFormat | None = Field(default=None, examples=[None]) # additional argument of lmdeploy do_preprocess: bool | None = True - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None repetition_ngram_size: int = Field(default=0, ge=0) repetition_ngram_threshold: int = Field(default=0, ge=0) session_id: int | None = -1 ignore_eos: bool | None = False skip_special_tokens: bool | None = True spaces_between_special_tokens: bool | None = True - top_k: int | None = 40 + top_k: int | None = None seed: int | None = None min_new_tokens: int | None = Field(default=None, examples=[None]) - min_p: float = 0.0 + min_p: float | None = None enable_thinking: bool | None = None # will be deprecated in the future return_token_ids: bool | None = False return_logprob: bool | None = False @@ -352,7 +352,7 @@ class CompletionRequest(BaseModel): model: str prompt: str | list[Any] suffix: str | None = None - temperature: float | None = 0.7 + temperature: float | None = None n: int | None = 1 logprobs: int | None = None max_completion_tokens: int | None = Field( @@ -362,29 +362,29 @@ class CompletionRequest(BaseModel): 'including visible output tokens and reasoning tokens'), ) max_tokens: int | None = Field( - default=16, - examples=[16], + default=None, + examples=[None], deprecated='max_tokens is deprecated in favor of the max_completion_tokens field', ) stop: str | list[str] | None = Field(default=None, examples=[None]) stream: bool | None = False stream_options: StreamOptions | None = Field(default=None, examples=[None]) - top_p: float | None = 1.0 + top_p: float | None = None echo: bool | None = False presence_penalty: float | None = 0.0 frequency_penalty: float | None = 0.0 user: str | None = None # additional argument of lmdeploy - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None repetition_ngram_size: int = Field(default=0, ge=0) repetition_ngram_threshold: int = Field(default=0, ge=0) session_id: int | None = -1 ignore_eos: bool | None = False skip_special_tokens: bool | None = True spaces_between_special_tokens: bool | None = True - top_k: int | None = 40 # for opencompass + top_k: int | None = None # for opencompass seed: int | None = None - min_p: float = 0.0 + min_p: float | None = None class CompletionResponseChoice(BaseModel): @@ -550,7 +550,7 @@ class GenerateReqInput(BaseModel): stop_token_ids: list[int] | None = None stream: bool | None = False temperature: float = 1.0 - repetition_penalty: float | None = 1.0 + repetition_penalty: float | None = None ignore_eos: bool | None = False top_p: float = 1.0 top_k: int = 0 diff --git a/lmdeploy/serve/openai/responses/protocol.py b/lmdeploy/serve/openai/responses/protocol.py index ed08f16202..d02d220df7 100644 --- a/lmdeploy/serve/openai/responses/protocol.py +++ b/lmdeploy/serve/openai/responses/protocol.py @@ -82,10 +82,10 @@ class ResponsesRequest(BaseModel): presence_penalty: float | None = None frequency_penalty: float | None = None repetition_penalty: float | None = None - top_k: int | None = 40 + top_k: int | None = None stop: str | list[str] | None = None seed: int | None = None - min_p: float = 0.0 + min_p: float | None = None ignore_eos: bool | None = False skip_special_tokens: bool | None = True include_stop_str_in_output: bool | None = False diff --git a/lmdeploy/serve/openai/responses/request.py b/lmdeploy/serve/openai/responses/request.py index 2db1000532..6a7dc9f8df 100644 --- a/lmdeploy/serve/openai/responses/request.py +++ b/lmdeploy/serve/openai/responses/request.py @@ -9,6 +9,7 @@ from fastapi.responses import JSONResponse from lmdeploy.messages import GenerationConfig +from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName from lmdeploy.serve.openai.responses.protocol import ResponsesRequest from lmdeploy.utils import get_logger @@ -264,20 +265,22 @@ def _response_format_from_text(text: Any) -> dict[str, Any] | None: raise ValueError(f'Unsupported text.format type: {format_type!r}.') -def to_generation_config(request: ResponsesRequest) -> GenerationConfig: +def to_generation_config( + request: ResponsesRequest, + server_defaults: dict | None = None, + override_max_new_tokens: int | None = None, +) -> GenerationConfig: stop_words = [request.stop] if isinstance(request.stop, str) else request.stop - return GenerationConfig( - max_new_tokens=request.max_output_tokens, - do_sample=True, - top_k=40 if request.top_k is None else request.top_k, - top_p=1.0 if request.top_p is None else request.top_p, - temperature=1.0 if request.temperature is None else request.temperature, + request_values = extract_request_sampling_values(request) + return build_generation_config( + request_values, + server_defaults or {}, + max_completion_tokens=request.max_output_tokens, + override_max_new_tokens=override_max_new_tokens, stop_words=stop_words, ignore_eos=request.ignore_eos, skip_special_tokens=request.skip_special_tokens, include_stop_str_in_output=request.include_stop_str_in_output, response_format=_response_format_from_text(request.text), - min_p=request.min_p, random_seed=request.seed, - repetition_penalty=1.0 if request.repetition_penalty is None else request.repetition_penalty, ) diff --git a/lmdeploy/serve/openai/responses/serving.py b/lmdeploy/serve/openai/responses/serving.py index 2837e483c5..64864ece80 100644 --- a/lmdeploy/serve/openai/responses/serving.py +++ b/lmdeploy/serve/openai/responses/serving.py @@ -97,7 +97,11 @@ async def create_response(self, request: ResponsesRequest, raw_request: Request) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='input') try: - gen_config = to_generation_config(request) + gen_config = to_generation_config( + request, + server_defaults=self.server_context.server_sampling_defaults, + override_max_new_tokens=self.server_context.override_max_new_tokens, + ) except ValueError as err: return error_response(HTTPStatus.BAD_REQUEST, str(err), param='text') try: diff --git a/lmdeploy/serve/openai/serving_chat_completion.py b/lmdeploy/serve/openai/serving_chat_completion.py index 362f0bf9e5..7ac750a617 100644 --- a/lmdeploy/serve/openai/serving_chat_completion.py +++ b/lmdeploy/serve/openai/serving_chat_completion.py @@ -1,12 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING +from lmdeploy.serve.core.generation_config import ( + PROTOCOL_FALLBACKS, + extract_request_sampling_values, + merge_sampling_params, +) + from .protocol import ChatCompletionRequest if TYPE_CHECKING: from .api_server import VariableInterface +def _effective_sampling(request: ChatCompletionRequest, server_context: 'VariableInterface') -> dict: + return merge_sampling_params( + extract_request_sampling_values(request), + server_context.server_sampling_defaults, + PROTOCOL_FALLBACKS, + ) + + def check_request(request: ChatCompletionRequest, server_context: 'VariableInterface') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() @@ -32,15 +46,17 @@ def check_request(request: ChatCompletionRequest, server_context: 'VariableInter if session_manager.has(request.session_id): return f'The session_id {request.session_id!r} is occupied.' + sampling = _effective_sampling(request, server_context) + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' - if not (0 < request.top_p <= 1): - return f'The top_p {request.top_p!r} must be in (0, 1].' - if request.top_k < 0: - return f'The top_k {request.top_k!r} cannot be a negative integer.' - if not (0 <= request.temperature <= 2): - return f'The temperature {request.temperature!r} must be in [0, 2]' + if not (0 < sampling['top_p'] <= 1): + return f'The top_p {sampling["top_p"]!r} must be in (0, 1].' + if sampling['top_k'] < 0: + return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.' + if not (0 <= sampling['temperature'] <= 2): + return f'The temperature {sampling["temperature"]!r} must be in [0, 2]' # Validate input_ids and image_data constraints. # messages has higher priority. input_ids and image_data are only used when diff --git a/lmdeploy/serve/openai/serving_completion.py b/lmdeploy/serve/openai/serving_completion.py index 759972db36..3047717c1f 100644 --- a/lmdeploy/serve/openai/serving_completion.py +++ b/lmdeploy/serve/openai/serving_completion.py @@ -1,12 +1,26 @@ # Copyright (c) OpenMMLab. All rights reserved. from typing import TYPE_CHECKING +from lmdeploy.serve.core.generation_config import ( + PROTOCOL_FALLBACKS, + extract_request_sampling_values, + merge_sampling_params, +) + from .protocol import CompletionRequest if TYPE_CHECKING: from .api_server import VariableInterface +def _effective_sampling(request: CompletionRequest, server_context: 'VariableInterface') -> dict: + return merge_sampling_params( + extract_request_sampling_values(request), + server_context.server_sampling_defaults, + PROTOCOL_FALLBACKS, + ) + + def check_request(request: CompletionRequest, server_context: 'VariableInterface') -> str: engine_config = server_context.get_engine_config() session_manager = server_context.get_session_manager() @@ -24,14 +38,16 @@ def check_request(request: CompletionRequest, server_context: 'VariableInterface if session_manager.has(request.session_id): return f'The session_id {request.session_id!r} is occupied.' + sampling = _effective_sampling(request, server_context) + # check sampling settings if request.n <= 0: return f'The n {request.n!r} must be a positive int.' - if not (0 < request.top_p <= 1): - return f'The top_p {request.top_p!r} must be in (0, 1].' - if request.top_k < 0: - return f'The top_k {request.top_k!r} cannot be a negative integer.' - if not (0 <= request.temperature <= 2): - return f'The temperature {request.temperature!r} must be in [0, 2]' + if not (0 < sampling['top_p'] <= 1): + return f'The top_p {sampling["top_p"]!r} must be in (0, 1].' + if sampling['top_k'] < 0: + return f'The top_k {sampling["top_k"]!r} cannot be a negative integer.' + if not (0 <= sampling['temperature'] <= 2): + return f'The temperature {sampling["temperature"]!r} must be in [0, 2]' return '' diff --git a/tests/test_lmdeploy/serve/test_generation_config.py b/tests/test_lmdeploy/serve/test_generation_config.py new file mode 100644 index 0000000000..a5a9199a2f --- /dev/null +++ b/tests/test_lmdeploy/serve/test_generation_config.py @@ -0,0 +1,103 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from unittest.mock import patch + +from lmdeploy.serve.core.generation_config import ( + PROTOCOL_FALLBACKS, + build_generation_config, + extract_request_sampling_values, + merge_sampling_params, + resolve_max_new_tokens, + resolve_server_sampling_defaults, +) +from lmdeploy.serve.openai.protocol import ChatCompletionRequest, CompletionRequest + + +def test_merge_sampling_params_priority(): + merged = merge_sampling_params( + {'temperature': 0.2}, + {'temperature': 0.5, 'top_k': 10}, + PROTOCOL_FALLBACKS, + ) + assert merged['temperature'] == 0.2 + assert merged['top_k'] == 10 + assert merged['top_p'] == PROTOCOL_FALLBACKS['top_p'] + + +def test_merge_sampling_params_uses_server_then_fallback(): + merged = merge_sampling_params({}, {'temperature': 0.5}, PROTOCOL_FALLBACKS) + assert merged['temperature'] == 0.5 + assert merged['top_k'] == PROTOCOL_FALLBACKS['top_k'] + + +def test_extract_request_sampling_values_only_non_null(): + request = ChatCompletionRequest(model='test', messages='hi', temperature=0.3) + values = extract_request_sampling_values(request) + assert values == {'temperature': 0.3} + + +def test_resolve_max_new_tokens_uses_server_default(): + assert resolve_max_new_tokens(None, None, 128) == 128 + + +def test_resolve_max_new_tokens_caps_request_value(): + assert resolve_max_new_tokens(256, None, 128) == 128 + assert resolve_max_new_tokens(None, 256, 128) == 128 + + +def test_resolve_max_new_tokens_prefers_max_completion_tokens(): + assert resolve_max_new_tokens(64, 256, None) == 64 + + +def test_build_generation_config_from_merged_values(): + gen_config = build_generation_config( + {'temperature': 0.2}, + {'top_k': 5}, + max_completion_tokens=32, + override_max_new_tokens=64, + ) + assert gen_config.temperature == 0.2 + assert gen_config.top_k == 5 + assert gen_config.max_new_tokens == 32 + assert gen_config.do_sample is True + + +@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') +def test_resolve_server_sampling_defaults_auto(mock_load): + mock_load.return_value = { + 'temperature': 0.6, + 'top_p': 0.8, + 'max_new_tokens': 2048, + } + defaults, cap = resolve_server_sampling_defaults('auto', None, '/fake/model', False) + assert defaults == {'temperature': 0.6, 'top_p': 0.8} + assert cap == 2048 + mock_load.assert_called_once_with('/fake/model', False) + + +def test_resolve_server_sampling_defaults_lmdeploy(): + defaults, cap = resolve_server_sampling_defaults('lmdeploy', None, '/fake/model', False) + assert defaults == {} + assert cap is None + + +@patch('lmdeploy.serve.core.generation_config._load_hf_generation_config') +def test_resolve_server_sampling_defaults_with_override(mock_load): + mock_load.return_value = {'temperature': 0.6, 'top_k': 20} + defaults, cap = resolve_server_sampling_defaults( + 'auto', + {'temperature': 0.5, 'max_new_tokens': 100}, + '/fake/model', + False, + ) + assert defaults == {'temperature': 0.5, 'top_k': 20} + assert cap == 100 + + +def test_completion_request_sampling_merge(): + request = CompletionRequest(model='test', prompt='hello') + gen_config = build_generation_config( + extract_request_sampling_values(request), + {'temperature': 0.9}, + ) + assert gen_config.temperature == 0.9 + assert gen_config.top_k == PROTOCOL_FALLBACKS['top_k']