Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lmdeploy/cli/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ def add_parser_api_server():
# model args
ArgumentHelper.revision(parser)
ArgumentHelper.download_dir(parser)
ArgumentHelper.generation_config(parser)
ArgumentHelper.override_generation_config(parser)

# pytorch engine args
pt_group = parser.add_argument_group('PyTorch engine arguments')
Expand Down Expand Up @@ -318,6 +320,8 @@ def api_server(args):
reasoning_parser=args.reasoning_parser,
tool_call_parser=args.tool_call_parser,
speculative_config=speculative_config,
generation_config=args.generation_config,
override_generation_config=args.override_generation_config,
)
else:
from lmdeploy.serve.openai.launch_server import launch_server
Expand Down Expand Up @@ -350,6 +354,8 @@ def api_server(args):
reasoning_parser=args.reasoning_parser,
tool_call_parser=args.tool_call_parser,
speculative_config=speculative_config,
generation_config=args.generation_config,
override_generation_config=args.override_generation_config,
)

@staticmethod
Expand Down
26 changes: 26 additions & 0 deletions lmdeploy/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,32 @@ def hf_overrides(parser):
default=None,
help='Extra arguments to be forwarded to the HuggingFace config.')

@staticmethod
def generation_config(parser):
    """Register the ``--generation-config`` option on ``parser``.

    The option is parsed as a string and defaults to ``'auto'``.

    Args:
        parser: Object exposing ``add_argument`` (an argparse parser or
            argument group).

    Returns:
        Whatever ``parser.add_argument`` returns for the new option.
    """
    help_text = ('The folder path to the generation config. Defaults to "auto", the '
                 'generation config will be loaded from model path. If set to "lmdeploy", no '
                 'generation config is loaded, lmdeploy defaults will be used. If set to a folder '
                 'path, the generation config will be loaded from the specified folder path. '
                 'If max_new_tokens is specified in generation config, then it sets a '
                 'server-wide limit on the number of output tokens for all requests.')
    return parser.add_argument('--generation-config', type=str, default='auto', help=help_text)

@staticmethod
def override_generation_config(parser):
    """Register the ``--override-generation-config`` option on ``parser``.

    The raw command-line value is decoded with ``json.loads``; when the
    option is absent the parsed value is ``None``.

    Args:
        parser: Object exposing ``add_argument`` (an argparse parser or
            argument group).

    Returns:
        Whatever ``parser.add_argument`` returns for the new option.
    """
    help_text = ('Overrides or sets generation config. e.g. \'{"temperature": 0.5}\'. If '
                 'used with --generation-config auto, the override parameters will be merged '
                 'with the default config from the model. If used with --generation-config '
                 'lmdeploy, only the override parameters are used.')
    return parser.add_argument('--override-generation-config', type=json.loads, default=None, help=help_text)

@staticmethod
def use_logn_attn(parser):
"""Add argument use_logn_attn to parser."""
Expand Down
20 changes: 12 additions & 8 deletions lmdeploy/serve/anthropic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import shortuuid

from lmdeploy.messages import GenerationConfig
from lmdeploy.serve.core.generation_config import build_generation_config, extract_request_sampling_values
from lmdeploy.serve.openai.protocol import Tool, ToolChoice, ToolChoiceFuncName

from .protocol import (
Expand Down Expand Up @@ -341,15 +342,18 @@ def to_lmdeploy_messages(request: MessagesRequest | CountTokensRequest) -> list[
return lm_messages


def to_generation_config(request: MessagesRequest) -> GenerationConfig:
def to_generation_config(
request: MessagesRequest,
server_defaults: dict | None = None,
override_max_new_tokens: int | None = None,
) -> GenerationConfig:
"""Map Anthropic messages request to LMDeploy generation config."""

return GenerationConfig(
max_new_tokens=request.max_tokens,
do_sample=True,
top_k=40 if request.top_k is None else request.top_k,
top_p=1.0 if request.top_p is None else request.top_p,
temperature=1.0 if request.temperature is None else request.temperature,
request_values = extract_request_sampling_values(request)
return build_generation_config(
request_values,
server_defaults or {},
max_tokens=request.max_tokens,
override_max_new_tokens=override_max_new_tokens,
stop_words=request.stop_sequences,
include_stop_str_in_output=request.include_stop_str_in_output or False,
skip_special_tokens=True,
Expand Down
6 changes: 5 additions & 1 deletion lmdeploy/serve/anthropic/endpoints/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ async def create_message(request: MessagesRequest, raw_request: Request):
result_generator = server_context.async_engine.generate(
engine_messages,
session,
gen_config=to_generation_config(request),
gen_config=to_generation_config(
request,
server_defaults=server_context.server_sampling_defaults,
override_max_new_tokens=server_context.override_max_new_tokens,
),
tools=parsed_request.tools,
stream_response=True,
sequence_start=True,
Expand Down
2 changes: 1 addition & 1 deletion lmdeploy/serve/anthropic/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class MessagesRequest(BaseModel):
system: str | list[ContentBlockParam] | None = None
stop_sequences: list[str] | None = None
stream: bool = False
temperature: float | None = 1.0
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
metadata: dict[str, Any] | None = None
Expand Down
173 changes: 173 additions & 0 deletions lmdeploy/serve/core/generation_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Server-side generation config resolution and sampling parameter merge
helpers."""

from __future__ import annotations

from typing import Any

from lmdeploy.messages import GenerationConfig
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')

# Lowest-priority sampling values. ``merge_sampling_params`` uses an entry from
# this table only when neither the request nor the server defaults provide the
# key, and ``build_generation_config`` reads ``do_sample`` from it as a default.
PROTOCOL_FALLBACKS: dict[str, Any] = {
'temperature': 0.7,
'top_p': 1.0,
'top_k': 40,
'repetition_penalty': 1.0,
'min_p': 0.0,
'do_sample': True,
}

# Keys that ``extract_sampling_params`` keeps from a generation config dict.
# ``max_new_tokens`` is listed here but is popped out of the sampling defaults
# by ``resolve_server_sampling_defaults`` and returned separately.
SAMPLING_PARAM_KEYS = (
'temperature',
'top_p',
'top_k',
'min_p',
'repetition_penalty',
'max_new_tokens',
'do_sample',
)

# Attribute names that ``extract_request_sampling_values`` reads from a request
# object. Unlike ``SAMPLING_PARAM_KEYS`` this omits ``max_new_tokens`` and
# ``do_sample``.
REQUEST_SAMPLING_FIELDS = (
'temperature',
'top_p',
'top_k',
'min_p',
'repetition_penalty',
)


def _load_hf_generation_config(path: str, trust_remote_code: bool) -> dict[str, Any]:
    """Load a HuggingFace generation config from ``path`` as a plain dict.

    Args:
        path: Location handed to ``GenerationConfig.from_pretrained``.
        trust_remote_code: Forwarded unchanged to ``from_pretrained``.

    Returns:
        The result of ``to_diff_dict()`` on the loaded config, or an empty
        dict when loading raises ``OSError`` (best-effort: a missing config
        is not treated as fatal).
    """
    # Imported lazily, and under an alias so it cannot be mistaken for
    # lmdeploy's own ``GenerationConfig`` imported at module level.
    from transformers import GenerationConfig as HFGenerationConfig

    try:
        hf_config = HFGenerationConfig.from_pretrained(path, trust_remote_code=trust_remote_code)
    except OSError as err:
        # Keep the best-effort fallback, but leave a trace so that a mistyped
        # ``--generation-config`` path can be diagnosed from the logs.
        logger.info('No generation config loaded from %s (%s); continuing without one.', path, err)
        return {}
    return hf_config.to_diff_dict()


def extract_sampling_params(config: dict[str, Any]) -> dict[str, Any]:
    """Pick the supported sampling parameters out of a generation config dict.

    Only keys listed in ``SAMPLING_PARAM_KEYS`` are considered; keys that are
    absent or mapped to ``None`` are left out of the result.
    """
    selected: dict[str, Any] = {}
    for name in SAMPLING_PARAM_KEYS:
        value = config.get(name)
        if value is not None:
            selected[name] = value
    return selected


def resolve_server_sampling_defaults(
    generation_config: str,
    override: dict[str, Any] | None,
    model_path: str,
    trust_remote_code: bool,
) -> tuple[dict[str, Any], int | None]:
    """Resolve server-side default sampling params from CLI flags.

    Args:
        generation_config: ``'lmdeploy'`` to start from an empty config,
            ``'auto'`` to load the config from ``model_path``, or any other
            string, which is used as the path to load the config from.
        override: Optional mapping merged on top of the loaded config.
        model_path: Path used when ``generation_config`` is ``'auto'``.
        trust_remote_code: Forwarded to the HuggingFace config loader.

    Returns:
        A tuple of (sampling_defaults, override_max_new_tokens).
        ``override_max_new_tokens`` is a server-wide cap/default when set.

    Raises:
        TypeError: If ``override`` is truthy but not a dict.
    """
    override = override or {}
    # The CLI decodes this value with ``json.loads``, which happily returns
    # lists, strings or numbers. Reject those here with a clear message
    # instead of letting ``dict.update`` fail obscurely (or silently accept a
    # list of pairs).
    if not isinstance(override, dict):
        raise TypeError('override generation config must be a JSON object (dict), '
                        f'got {type(override).__name__}')
    src = generation_config

    if src == 'lmdeploy':
        config: dict[str, Any] = {}
    elif src == 'auto':
        config = _load_hf_generation_config(model_path, trust_remote_code)
    else:
        config = _load_hf_generation_config(src, trust_remote_code)

    # Override entries win over whatever was loaded from disk.
    config.update(override)
    sampling = extract_sampling_params(config)

    # ``max_new_tokens`` is reported separately rather than as a sampling default.
    override_max_new_tokens = sampling.pop('max_new_tokens', None)
    if override_max_new_tokens is not None:
        override_max_new_tokens = int(override_max_new_tokens)

    if sampling and src != 'lmdeploy':
        source = "the model's `generation_config.json`" if src == 'auto' else src
        logger.info(
            'Using default sampling params from %s: %s. '
            'Use `--generation-config lmdeploy` to disable.',
            source,
            sampling,
        )
    elif sampling and override:
        logger.info('Using override generation config sampling params: %s.', sampling)

    return sampling, override_max_new_tokens


def merge_sampling_params(
    request_values: dict[str, Any],
    server_defaults: dict[str, Any],
    fallbacks: dict[str, Any] | None = None,
) -> dict[str, Any]:
    """Merge sampling params with request > server > protocol fallback
    priority.

    Args:
        request_values: Highest-priority values; a key present here always
            wins, whatever its value.
        server_defaults: Used for keys the request does not provide.
        fallbacks: Lowest-priority values. ``None`` or an empty mapping
            selects ``PROTOCOL_FALLBACKS``.

    Returns:
        A new dict holding every key found in any of the three layers. Keys
        appear in a deterministic order: fallback keys first, then keys new
        in ``server_defaults``, then keys new in ``request_values``.
    """
    lowest = fallbacks or PROTOCOL_FALLBACKS
    # Later layers overwrite earlier ones, which is exactly the priority order.
    return {**lowest, **server_defaults, **request_values}


def extract_request_sampling_values(request: Any) -> dict[str, Any]:
    """Collect the sampling fields a request object explicitly sets.

    Each name in ``REQUEST_SAMPLING_FIELDS`` is looked up as an attribute of
    ``request``; attributes that are missing or ``None`` are skipped.
    """
    candidates = ((name, getattr(request, name, None)) for name in REQUEST_SAMPLING_FIELDS)
    return {name: value for name, value in candidates if value is not None}


def resolve_max_new_tokens(
    max_completion_tokens: int | None,
    max_tokens: int | None,
    server_cap: int | None,
) -> int | None:
    """Pick the output token limit, honouring an optional server-wide cap.

    ``max_completion_tokens`` takes precedence over ``max_tokens``. When the
    request supplies neither, ``server_cap`` is returned as the default; when
    both a request value and a cap exist, the smaller of the two is returned.
    """
    requested = max_tokens if max_completion_tokens is None else max_completion_tokens
    if requested is None:
        return server_cap
    return requested if server_cap is None else min(requested, server_cap)


def build_generation_config(
    request_values: dict[str, Any],
    server_defaults: dict[str, Any],
    *,
    max_completion_tokens: int | None = None,
    max_tokens: int | None = None,
    override_max_new_tokens: int | None = None,
    fallbacks: dict[str, Any] | None = None,
    **extra_kwargs: Any,
) -> GenerationConfig:
    """Build ``GenerationConfig`` from merged sampling defaults and request
    values.

    Args:
        request_values: Sampling values supplied by the request.
        server_defaults: Server-side sampling defaults.
        max_completion_tokens: Request token limit; preferred over
            ``max_tokens`` when both are given.
        max_tokens: Request token limit.
        override_max_new_tokens: Server-wide cap/default for the token limit.
        fallbacks: Optional replacement for ``PROTOCOL_FALLBACKS`` during the
            merge. It may be partial: any field it leaves unset is taken from
            ``PROTOCOL_FALLBACKS``.
        **extra_kwargs: Forwarded unchanged to ``GenerationConfig``.

    Returns:
        The assembled ``GenerationConfig``.
    """
    merged = merge_sampling_params(request_values, server_defaults, fallbacks)
    # Layer the protocol fallbacks underneath the merge result so every field
    # read below is guaranteed to exist, even with a partial custom
    # ``fallbacks``. Previously only ``do_sample`` had this safety net.
    sampling = {**PROTOCOL_FALLBACKS, **merged}
    max_new_tokens = resolve_max_new_tokens(
        max_completion_tokens,
        max_tokens,
        override_max_new_tokens,
    )
    return GenerationConfig(
        max_new_tokens=max_new_tokens,
        do_sample=sampling['do_sample'],
        top_k=sampling['top_k'],
        top_p=sampling['top_p'],
        temperature=sampling['temperature'],
        repetition_penalty=sampling['repetition_penalty'],
        min_p=sampling['min_p'],
        **extra_kwargs,
    )
Loading
Loading