From dec9def919703c54380323b1eefc103839d417ab Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Fri, 5 Jun 2026 03:07:23 -0700 Subject: [PATCH 1/7] [None][feat] Coordinator/worker disagg serving; fold routing into disagg service Replace the WEB_CONCURRENCY multi-worker no-op (and the router_http_server / disagg_app scaffolding) with a coordinator/worker model: - disaggregated (WEB_CONCURRENCY>1) becomes a pure coordinator on port-1 that owns the ctx/gen routers, readiness, cluster/worker events, and the centralized ZMQ ingest bind, and serves only the internal /select, /finish, /cluster_info, /health API (coordinator_server.py). - The worker fleet runs via one uvicorn process group (workers=N) over a shared listening socket on the public port, rebuilt from a stateless import-string factory (create_worker_app); uvicorn owns supervision + graceful shutdown. - Placement is split on Router: extract_routing_key (client/worker side) + select_by_key / finish_by_handle (coordinator side). Round-robin -> empty key, conversation -> conversation_id (handle-based load release), centralized -> block hashes. The worker holds a CoordinatorHttpRouter that posts the key to the coordinator; single-process calls the router directly. - New DisaggCoordinator abstraction (disagg_coordinator.py): DisaggCoordinatorService (in-process, owns routers) and CoordinatorClient (worker, delegates over HTTP). OpenAIDisaggregatedService reads ctx_router/gen_router off the coordinator and drives get_next_server / finish_request uniformly, so serving is identical in both modes. - Remove router_http_server.py, disagg_app.py, RemoteHttpRouter and the remote_http router type, plus their tests; add test_coordinator_worker.py. Verified in-container (gb200): test_coordinator_worker 2/2 across 5 repeats, test_per_rank_routing + test_centralized_kv_cache_router + test_openai_disagg_service 76 passed. 
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- docs/source/features/disagg-serving.md | 89 +++ tensorrt_llm/commands/serve.py | 219 +++++- tensorrt_llm/llmapi/disagg_utils.py | 14 + tensorrt_llm/llmapi/llm.py | 16 +- tensorrt_llm/llmapi/llm_args.py | 23 + tensorrt_llm/serve/coordinator_server.py | 122 +++ tensorrt_llm/serve/disagg_coordinator.py | 425 +++++++++++ tensorrt_llm/serve/openai_disagg_server.py | 62 +- tensorrt_llm/serve/openai_disagg_service.py | 135 +--- tensorrt_llm/serve/openai_server.py | 21 + tensorrt_llm/serve/router.py | 709 +++++++++--------- tensorrt_llm/serve/router_utils.py | 376 ++++++++++ .../disaggregated/test_coordinator_e2e.py | 410 ++++++++++ .../disaggregated/test_coordinator_worker.py | 247 ++++++ .../test_openai_disagg_service.py | 24 +- 15 files changed, 2395 insertions(+), 497 deletions(-) create mode 100644 tensorrt_llm/serve/coordinator_server.py create mode 100644 tensorrt_llm/serve/disagg_coordinator.py create mode 100644 tensorrt_llm/serve/router_utils.py create mode 100644 tests/unittest/disaggregated/test_coordinator_e2e.py create mode 100644 tests/unittest/disaggregated/test_coordinator_worker.py diff --git a/docs/source/features/disagg-serving.md b/docs/source/features/disagg-serving.md index 7083b009a8f6..b70fd2f7ee9e 100644 --- a/docs/source/features/disagg-serving.md +++ b/docs/source/features/disagg-serving.md @@ -233,6 +233,95 @@ Example (two-node deployment): - **Client entrypoint** - Send requests or use a load balancer forwarding to `node-a:8000` and `node-b:8000` +### Coordinator and Worker Fleet + +A single disaggregated server process is itself a single-threaded orchestrator and can become a throughput bottleneck (it terminates every client connection, runs routing, and proxies the ctx→gen hop). 
To scale the orchestrator on one node without standing up multiple independent instances, `trtllm-serve disaggregated` can run a **fleet** of stateless disaggregated-server worker processes behind a shared **coordinator**. + +The two roles split as follows: + +- **Coordinator** — a single process that owns all cluster state: the ctx/gen routers, worker readiness, and (for the KV-cache-aware router) the single ZMQ event-ingest endpoint. It exposes an internal coordination API (`/select`, `/finish`, `/cluster_info`, `/health`). +- **Fleet workers** — `num_workers` stateless disaggregated servers sharing the public port (via uvicorn `--workers`). Each holds a lightweight delegating client: it computes the routing key locally (e.g. block hashes) and delegates the placement decision to the coordinator over HTTP. Workers own no routing state, so routing stays globally consistent no matter which worker terminates a connection. + +This is controlled by two fields in the disaggregated config: + +- `num_workers` (int, default `1`) — number of disaggregated-server worker processes to run on the public port. +- `disagg_coordinator_url` (str, optional) — URL of an already-running coordinator. When set, this process starts **no** coordinator and its fleet delegates to that external one. + +The three resulting topologies: + +| `num_workers` | `disagg_coordinator_url` | Behavior | +|---------------|--------------------------|----------| +| `1` | unset | Single self-contained server with an in-process coordinator (the default; unchanged from earlier examples). | +| `> 1` | unset | An **implicit** coordinator starts in this process (on `port - 1`) and a fleet of `num_workers` delegating servers runs on the public port. | +| any | set | **No** coordinator starts here; a fleet of `num_workers` delegating servers points at the external `disagg_coordinator_url`. 
| + +```{note} +The fleet is most useful with a *stateful* router (`kv_cache_aware`, `conversation`) where placement must be globally consistent — that decision is delegated to the coordinator. With a *stateless* router (`round_robin`, `load_balancing`) each worker simply places locally and no coordinator round-trip occurs. +``` + +#### Example: implicit coordinator + 4-worker fleet + +Extend the `disagg_config.yaml` from the [trtllm-serve](#trtllm-serve) example with `num_workers` and a router type: + +```yaml +hostname: localhost +port: 8000 +backend: pytorch +# Run 4 stateless disaggregated-server workers on port 8000, with an implicit +# coordinator started in-process on port 7999 (port - 1). +num_workers: 4 +context_servers: + num_instances: 2 + urls: + - "localhost:8001" + - "localhost:8002" + router: + type: kv_cache_aware +generation_servers: + num_instances: 1 + urls: + - "localhost:8003" + router: + type: kv_cache_aware +``` + +Launch it exactly as before — the coordinator and fleet are started for you: + +``` +trtllm-serve disaggregated -c disagg_config.yaml +``` + +Clients still send requests to the public endpoint (`localhost:8000`); the fleet transparently delegates routing to the coordinator. 
+ +#### Example: external coordinator + +To point a fleet at a coordinator already running elsewhere (for example, one shared across nodes), set `disagg_coordinator_url` and omit the coordinator from this process: + +```yaml +hostname: localhost +port: 8000 +backend: pytorch +num_workers: 4 +disagg_coordinator_url: "http://coordinator-host:7999" +context_servers: + num_instances: 2 + urls: + - "localhost:8001" + - "localhost:8002" + router: + type: kv_cache_aware +generation_servers: + num_instances: 1 + urls: + - "localhost:8003" + router: + type: kv_cache_aware +``` + +```{note} +A fleet worker fails fast if its coordinator is unreachable: on startup it probes the coordinator's `/cluster_info` with bounded retry (up to `--server_start_timeout` seconds) and exits with an error rather than coming up and returning `Cluster is not ready` for every request. +``` + ## Environment Variables TRT-LLM uses some environment variables to control the behavior of disaggregated service. diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 30c728d0b78e..a5e51b7092a3 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -1,4 +1,5 @@ import asyncio +import atexit import gc import importlib import inspect @@ -1238,6 +1239,43 @@ def disaggregated( # Inherited by child processes via env var; used for deduplication at query time. os.environ[DisaggLauncherEnvs.TLLM_DISAGG_DEPLOYMENT_ID] = uuid.uuid4().hex + metadata_server_cfg = parse_metadata_server_config_file( + metadata_server_config_file) + + # Disable GC by default: under high concurrency, frequent GC pauses delay forwarding requests to CTX/GEN workers and shrink their batch sizes. + if os.getenv("TRTLLM_DISAGG_SERVER_DISABLE_GC", "1") == "1": + gc.disable() + + # Startup topology is driven by explicit config (num_workers + + # disagg_coordinator_url), NOT the WEB_CONCURRENCY env var. Three cases: + # (a) disagg_coordinator_url set -> don't start a coordinator here; the + # fleet delegates to that external one (num_workers sizes the fleet). 
+ # (b) url absent, num_workers>1 -> start an implicit in-process + # coordinator (port-1) + a delegating uvicorn fleet on the public port. + # (c) url absent, num_workers==1 -> one self-contained server with a local + # (in-process) coordinator. + num_workers = disagg_cfg.num_workers + coordinator_url = disagg_cfg.disagg_coordinator_url + + if coordinator_url: + # (a) External coordinator: fork a fleet of delegating servers (or a + # single one) pointed at it; never start a coordinator in this process. + _serve_disagg_fleet(disagg_cfg, config_file, + metadata_server_config_file, request_timeout, + server_start_timeout, num_workers, coordinator_url) + return + + if num_workers > 1: + # (b) Implicit coordinator in this process (on port-1) + a delegating + # uvicorn fleet (workers=N) on the public port. See below. + _serve_coordinator_and_fleet(disagg_cfg, config_file, + metadata_server_config_file, + metadata_server_cfg, request_timeout, + server_start_timeout, num_workers) + return + + # (c) num_workers==1, no external coordinator: a single disagg server with an + # in-process (local) coordinator. Pre-bind the socket (validates port), serve. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: try: s.bind((disagg_cfg.hostname, disagg_cfg.port)) @@ -1246,9 +1284,6 @@ def disaggregated( f"Failed to bind socket to {disagg_cfg.hostname}:{disagg_cfg.port}: {e}" ) - metadata_server_cfg = parse_metadata_server_config_file( - metadata_server_config_file) - server = OpenAIDisaggServer( config=disagg_cfg, req_timeout_secs=request_timeout, @@ -1256,22 +1291,164 @@ def disaggregated( metadata_server_cfg=metadata_server_cfg, metrics_interval_secs=metrics_log_interval) - # Disable GC by default - # When concurrency is high, the number of Python objects increases, so - # GC runs frequently and takes a long time to process. In this case, - # requests are not immediately forwarded to CTX workers and GEN workers, - # causing them to run with small batch sizes. 
Disabling GC can mitigate - # this problem. - # By testing this feature, we didn't observe significant RSS or VMS - # increment, and observed that `count0` (obtained by `gc.get_count()`) - # increases by fewer than 1,000 after every 200,000 requests, while the - # maximum value of `count0` exceeded 3,000,000 during the test. - if os.getenv("TRTLLM_DISAGG_SERVER_DISABLE_GC", "1") == "1": - gc.disable() - asyncio.run(server(disagg_cfg.hostname, disagg_cfg.port, sockets=[s])) +def _launch_disagg_fleet(disagg_cfg, config_file, metadata_server_config_file, + request_timeout, server_start_timeout, num_workers, + coordinator_url): + """Fork a uvicorn fleet of delegating disagg servers pointed at ``coordinator_url``. + + Each worker is an ordinary ``OpenAIDisaggServer`` built with ``coordinator_url`` + so it holds a remote ``CoordinatorClient``. Invoked as ``python -m uvicorn`` so + there is no bespoke worker command; uvicorn owns the shared socket, worker + supervision, and graceful shutdown. MPI/PMIX/SLURM env is stripped so a worker + (a plain HTTP process) never joins an MPI namespace. Returns the Popen handle. + """ + public_host, public_port = disagg_cfg.hostname, disagg_cfg.port + child_env = { + k: v for k, v in os.environ.items() + if not k.startswith(("SLURM_", "PMIX_", "PMI_", "OMPI_", "UCX_", + "I_MPI_", "HYDRA_", "MPI_")) + } + # num_workers is explicit config now; ensure no stale WEB_CONCURRENCY leaks in + # and re-forks each plain-HTTP worker into a nested fleet. 
+ child_env.pop("WEB_CONCURRENCY", None) + child_env[DisaggWorkerEnvs.TLLM_DISAGG_COORDINATOR_URL] = coordinator_url + child_env[DisaggWorkerEnvs.TLLM_DISAGG_CONFIG_FILE] = os.path.abspath( + config_file) + if metadata_server_config_file: + child_env[DisaggWorkerEnvs.TLLM_DISAGG_METADATA_CONFIG_FILE] = \ + os.path.abspath(metadata_server_config_file) + child_env[DisaggWorkerEnvs.TLLM_DISAGG_REQUEST_TIMEOUT] = str(request_timeout) + child_env[DisaggWorkerEnvs.TLLM_DISAGG_SERVER_START_TIMEOUT] = str( + server_start_timeout) + cmd = [sys.executable, "-m", "uvicorn", "--factory", + "--host", str(public_host), "--port", str(public_port), + "--workers", str(num_workers), "--timeout-keep-alive", "10", + "tensorrt_llm.commands.serve:create_disagg_server_app"] + logger.info(f"Launching disagg fleet: {num_workers} uvicorn workers on " + f"{public_host}:{public_port}, coordinator={coordinator_url}") + logger.info(f"Disagg fleet command: {' '.join(cmd)}") + fleet = subprocess.Popen(cmd, env=child_env, stdout=sys.stdout, + stderr=sys.stderr, start_new_session=True) + logger.info(f"Disagg fleet launched (pid={fleet.pid})") + + def _cleanup(): + if fleet.poll() is None: + fleet.terminate() + try: + fleet.wait(timeout=10) + except Exception: + fleet.kill() + + atexit.register(_cleanup) + return fleet + + +def _serve_disagg_fleet(disagg_cfg, config_file, metadata_server_config_file, + request_timeout, server_start_timeout, num_workers, + coordinator_url): + """External coordinator: fork the delegating fleet and block on it. + + No coordinator is started in this process -- the fleet delegates to the + already-running coordinator at ``coordinator_url``. 
+ """ + fleet = _launch_disagg_fleet(disagg_cfg, config_file, + metadata_server_config_file, request_timeout, + server_start_timeout, num_workers, + coordinator_url) + rc = fleet.wait() + if rc != 0: + raise RuntimeError(f"Disagg fleet exited with code {rc}") + + +def _serve_coordinator_and_fleet(disagg_cfg, config_file, + metadata_server_config_file, + metadata_server_cfg, request_timeout, + server_start_timeout, num_workers): + """workers>1, no external URL: coordinator server here + a delegating fleet. + + This process runs the coordinator server (owns the ctx/gen routers, cluster + state, and centralized ZMQ ingest) on ``port-1``; a uvicorn fleet on the + public port delegates to it. + """ + from tensorrt_llm.serve.coordinator_server import CoordinatorServer + from tensorrt_llm.serve.disagg_coordinator import DisaggCoordinatorService + + public_host, public_port = disagg_cfg.hostname, disagg_cfg.port + coord_port = int(os.environ.get("TLLM_DISAGG_COORDINATOR_PORT", + public_port - 1)) + coord_url = f"http://{public_host}:{coord_port}" + + # 1. Launch the delegating fleet pointed at the implicit coordinator we start + # below (port-1). Workers hold CoordinatorClients (no core), so they can't + # race the coordinator's single ZMQ ingest bind. + _launch_disagg_fleet(disagg_cfg, config_file, metadata_server_config_file, + request_timeout, server_start_timeout, num_workers, + coord_url) + + # 2. Build + serve the coordinator in this process. It OWNS routing state and + # builds the owner routers itself (single shared namespace-aware core + ONE + # ZMQ ingest server, started once here). 
+ def _client_factory(router, role, max_retries=1): + from tensorrt_llm.serve.openai_client import OpenAIHttpClient + return OpenAIHttpClient(router, role, request_timeout, max_retries) + + coordinator = DisaggCoordinatorService( + disagg_cfg, _client_factory, + metadata_config=metadata_server_cfg, + server_start_timeout_secs=server_start_timeout) + logger.info(f"Coordinator serving on {public_host}:{coord_port} " + f"(fleet on public port {public_port})") + asyncio.run(CoordinatorServer(coordinator)(public_host, coord_port)) + + +def create_disagg_server_app(): + """uvicorn import-string factory: build one disagg server's FastAPI app. + + Rebuilt inside each uvicorn worker process (workers=N). Fully stateless -- + reads config + coordinator URL from the env ``_launch_disagg_fleet`` + exported (see ``DisaggWorkerEnvs``); the server holds a remote + ``CoordinatorClient`` so routing/readiness are delegated to the coordinator. + """ + # A worker is a plain HTTP process, never an MPI rank; drop WEB_CONCURRENCY so + # it is never itself re-forked into multiple uvicorn workers. + os.environ.pop("WEB_CONCURRENCY", None) + if os.getenv("TRTLLM_DISAGG_SERVER_DISABLE_GC", "1") == "1": + gc.disable() + + # All N fleet workers share one stdout; tag every trtllm log line from this + # worker with its PID so interleaved fleet output is attributable. 
+ import logging as _logging + for _h in _logging.getLogger("TRT-LLM").handlers: + _h.setFormatter(_logging.Formatter( + fmt=f"[%(asctime)s] [fleet-worker pid={os.getpid()}] %(message)s", + datefmt="%m/%d/%Y-%H:%M:%S")) + + config_file = os.environ[DisaggWorkerEnvs.TLLM_DISAGG_CONFIG_FILE] + coordinator_url = os.environ[DisaggWorkerEnvs.TLLM_DISAGG_COORDINATOR_URL] + metadata_config_file = os.environ.get( + DisaggWorkerEnvs.TLLM_DISAGG_METADATA_CONFIG_FILE) + request_timeout = int(os.environ.get( + DisaggWorkerEnvs.TLLM_DISAGG_REQUEST_TIMEOUT, "180")) + server_start_timeout = int(os.environ.get( + DisaggWorkerEnvs.TLLM_DISAGG_SERVER_START_TIMEOUT, "180")) + + disagg_cfg = parse_disagg_config_file(config_file) + metadata_server_cfg = parse_metadata_server_config_file( + metadata_config_file) + + server = OpenAIDisaggServer( + config=disagg_cfg, + req_timeout_secs=request_timeout, + server_start_timeout_secs=server_start_timeout, + metadata_server_cfg=metadata_server_cfg, + coordinator_url=coordinator_url) + logger.info(f"Disagg server app built, coordinator={coordinator_url}") + return server.app + + def set_cuda_device(): if (os.getenv("OMPI_COMM_WORLD_RANK")): env_global_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) @@ -1367,6 +1544,16 @@ class DisaggLauncherEnvs(StrEnum): TLLM_DISAGG_ROLE = "TRTLLM_DISAGG_ROLE" +class DisaggWorkerEnvs(StrEnum): + # Passed from the `disaggregated` coordinator to the forked worker fleet + # (uvicorn workers=N) via env, then read by create_disagg_server_app in each worker. 
+ TLLM_DISAGG_COORDINATOR_URL = "TRTLLM_DISAGG_COORDINATOR_URL" + TLLM_DISAGG_CONFIG_FILE = "TRTLLM_DISAGG_CONFIG_FILE" + TLLM_DISAGG_METADATA_CONFIG_FILE = "TRTLLM_DISAGG_METADATA_CONFIG_FILE" + TLLM_DISAGG_REQUEST_TIMEOUT = "TRTLLM_DISAGG_REQUEST_TIMEOUT" + TLLM_DISAGG_SERVER_START_TIMEOUT = "TRTLLM_DISAGG_SERVER_START_TIMEOUT" + + def _launch_disaggregated_server(disagg_config_file: str, llm_args: dict): # Launching the server instance_idx = os.environ.get(DisaggLauncherEnvs.TLLM_DISAGG_INSTANCE_IDX) diff --git a/tensorrt_llm/llmapi/disagg_utils.py b/tensorrt_llm/llmapi/disagg_utils.py index a93325ee60e7..42835e04fd4e 100644 --- a/tensorrt_llm/llmapi/disagg_utils.py +++ b/tensorrt_llm/llmapi/disagg_utils.py @@ -94,6 +94,16 @@ class DisaggServerConfig(): # the orchestrator relays a string instead of materializing the token-id list # on its event loop. Text-only, non-harmony deployments (see _get_ctx_request). gen_tokids_ctxbytes: bool = False + # Number of uvicorn disagg-server worker processes to fork on the public port. + # >1 means a fleet of delegating servers behind one coordinator. Replaces the + # WEB_CONCURRENCY env var (explicit config over implicit env). + num_workers: int = 1 + # URL of an already-running coordinator (e.g. "http://host:8332"). When set, + # this process does NOT start a coordinator -- the fleet delegates to this + # external one. When absent and num_workers>1, an implicit coordinator is + # started in-process. When absent and num_workers==1, a single self-contained + # server with a local (in-process) coordinator is run. 
+ disagg_coordinator_url: Optional[str] = None @dataclass @@ -145,6 +155,8 @@ def extract_disagg_cfg(hostname: str = 'localhost', 'generation_first'] = 'context_first', gen_strip_message_history: bool = False, gen_tokids_ctxbytes: bool = False, + num_workers: int = 1, + disagg_coordinator_url: Optional[str] = None, **kwargs: Any) -> DisaggServerConfig: context_servers = context_servers or {} generation_servers = generation_servers or {} @@ -194,6 +206,8 @@ def extract_disagg_cfg(hostname: str = 'localhost', config.schedule_style = schedule_style config.gen_strip_message_history = gen_strip_message_history config.gen_tokids_ctxbytes = gen_tokids_ctxbytes + config.num_workers = num_workers + config.disagg_coordinator_url = disagg_coordinator_url return config diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 995298b3060b..727152fdf4fe 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -169,7 +169,10 @@ def __init__(self, self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor) self._orchestrator_type = kwargs.get("orchestrator_type", None) - self._llm_id = None + hostname = socket.gethostname() + pid = os.getpid() + timestamp = int(time.time() * 1000) + self._llm_id = f"{hostname}-{pid}-{timestamp}" self._disaggregated_params: Optional[dict] = None log_level = logger.level @@ -218,6 +221,8 @@ def __init__(self, revision=revision, tokenizer_revision=tokenizer_revision, **kwargs) + if hasattr(self.args, 'llm_id'): + self.args.llm_id = self._llm_id except Exception as e: logger.error( @@ -272,6 +277,9 @@ def __init__(self, self.llm_build_stats = LlmBuildStats() self._build_model() + if self._executor is not None: + self._executor.llm_id = self._llm_id + except Exception: if self.mpi_session is not None: self.mpi_session.shutdown() @@ -311,12 +319,6 @@ def __init__(self, @property @set_api_status("beta") def llm_id(self) -> str: - if self._llm_id is None: - hostname = socket.gethostname() - pid = os.getpid() - 
timestamp = int(time.time() * 1000) - self._llm_id = f"{hostname}-{pid}-{timestamp}" - return self._llm_id @property diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 1cdc3756306e..124f328af867 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2793,6 +2793,23 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): description= "Whether KV cache manager v2 uses SWA scratch reuse during prefill.") + # Per-rank centralized KV cache routing + centralized_router_report_address: Optional[str] = Field( + default=None, + status="prototype", + description= + "ZMQ address to push KV cache events to for centralized routing " + "(e.g. 'tcp://orchestrator:5557'). When set with per_rank_routing=True, " + "each DP rank reports its own events independently.") + + per_rank_routing: bool = Field( + default=False, + status="prototype", + description= + "Enable per-rank KV cache event reporting. Each DP rank reports its " + "own cache blocks independently (no allgather). Requires " + "use_kv_cache_manager_v2=True and event_buffer_max_size > 0.") + def _to_pybind(self): config = _KvCacheConfig( enable_block_reuse=self.enable_block_reuse, @@ -4213,6 +4230,12 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) + llm_id: Optional[str] = Field( + default=None, + description="Stable instance identifier propagated to all ranks via MPI broadcast.", + status="prototype", + ) + # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) diff --git a/tensorrt_llm/serve/coordinator_server.py b/tensorrt_llm/serve/coordinator_server.py new file mode 100644 index 000000000000..a7c7750d4d36 --- /dev/null +++ b/tensorrt_llm/serve/coordinator_server.py @@ -0,0 +1,122 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Coordinator HTTP server for disaggregated serving. + +One coordinator process owns all cluster state (routers, readiness, worker +events, and -- for the centralized router -- the single ZMQ event-ingest bind) +and answers the internal coordination API that the forked worker processes call: + + POST /select {"role", "routing_key", "req_id", "exclude_server"} + -> {"server": "host:port", "info": {...}, "req_id": ...} + POST /finish {"role", "req_id", "success"} -> {} + GET /cluster_info -> {...} + GET /health -> 200 when ready + GET /version + +The routing key is produced worker-side by ``Router.routing_key`` and consumed +here by ``Router.get_next_server_by_key`` (see ``serve/router.py``), so this +endpoint is generic across the stateful router types that use it (centralized -> +block hashes, conversation -> conversation_id). Single-process by design; it owns +the ZMQ ingest bind for centralized mode. 
+""" + +import asyncio +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response + +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.disagg_coordinator import DisaggCoordinatorService +from tensorrt_llm.version import __version__ as VERSION + +TIMEOUT_KEEP_ALIVE = 10 # seconds + + +class CoordinatorServer: + """Serve a :class:`DisaggCoordinatorService`'s coordination API over HTTP.""" + + def __init__(self, coordinator: DisaggCoordinatorService) -> None: + self._coordinator = coordinator + + @asynccontextmanager + async def lifespan(app: FastAPI): + await self._coordinator.start() + yield + await self._coordinator.stop() + + self.app = FastAPI(lifespan=lifespan) + self.app.add_api_route("/select", self.select, methods=["POST"]) + self.app.add_api_route("/finish", self.finish, methods=["POST"]) + self.app.add_api_route("/cluster_info", self.cluster_info, + methods=["GET"]) + self.app.add_api_route("/health", self.health, methods=["GET"]) + self.app.add_api_route("/version", self.version, methods=["GET"]) + + async def select(self, raw_req: Request) -> Response: + try: + body = await raw_req.json() + except Exception as e: + return JSONResponse(status_code=400, + content={"error": f"invalid JSON body: {e}"}) + if not isinstance(body, dict) or "role" not in body: + return JSONResponse( + status_code=400, + content={"error": "body must include 'role' and 'routing_key'"}) + try: + server, info, req_id = await self._coordinator.select( + body["role"], body.get("routing_key"), body.get("req_id"), + body.get("exclude_server")) + except ValueError as e: + return JSONResponse(status_code=503, content={"error": str(e)}) + except Exception as e: # noqa: BLE001 + logger.error(f"CoordinatorServer.select failed: {e}") + return JSONResponse(status_code=500, content={"error": str(e)}) + return JSONResponse(content={"server": server, "info": info, + "req_id": req_id}) + + 
async def finish(self, raw_req: Request) -> Response: + try: + body = await raw_req.json() + except Exception as e: + return JSONResponse(status_code=400, + content={"error": f"invalid JSON body: {e}"}) + await self._coordinator.finish(body.get("role", "gen"), + body.get("req_id"), + body.get("success", True)) + return JSONResponse(content={}) + + async def cluster_info(self) -> Response: + return JSONResponse(content=await self._coordinator.cluster_info()) + + async def health(self) -> Response: + return Response(status_code=200 if await self._coordinator.is_ready() + else 503) + + async def version(self) -> Response: + return JSONResponse(content={"version": VERSION}) + + async def __call__(self, host: str, port: int) -> None: + # Single-process by design: owns routing state + the centralized ZMQ + # ingest bind. workers=1 forced so a leaked WEB_CONCURRENCY can't fork it. + config = uvicorn.Config(self.app, host=host, port=port, workers=1, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + await uvicorn.Server(config).serve() + + +def serve_coordinator(host: str, port: int, + coordinator: DisaggCoordinatorService) -> None: + asyncio.run(CoordinatorServer(coordinator)(host, port)) diff --git a/tensorrt_llm/serve/disagg_coordinator.py b/tensorrt_llm/serve/disagg_coordinator.py new file mode 100644 index 000000000000..7ef705fb32fc --- /dev/null +++ b/tensorrt_llm/serve/disagg_coordinator.py @@ -0,0 +1,425 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Coordination for disaggregated serving. + +A :class:`DisaggCoordinator` owns everything that is *not* a completion: the +ctx/gen routers, readiness, cluster info, and worker/auto-scaling events. The +completions service holds one and reads ``ctx_router`` / ``gen_router`` off it, +then drives ``router.get_next_server`` / ``router.finish_request`` uniformly -- +so serving a completion is decoupled from managing the cluster and is identical +whether this process owns the routers or delegates to a remote coordinator. + +Two implementations for the coordinator/worker deployment: + +* :class:`DisaggCoordinatorService` -- runs in the coordinator (and in the + collapsed single-process path). Owns the real ctx/gen ``Router`` objects, + server preparation/monitoring, the auto-scaling ``DisaggClusterManager`` + + worker events, and readiness. Its :meth:`select` / :meth:`finish` are the + coordinator's ``/select`` / ``/finish`` handlers. +* :class:`CoordinatorClient` -- runs in each forked worker. Stateful routers + (conversation, centralized) are wrapped in a :class:`CoordinatorDelegatingRouter` + that posts the routing key to ``/select`` (finish -> ``/finish``); stateless + routers (round_robin, load_balancing) place locally in the worker. Readiness / + cluster_info proxy the coordinator over HTTP. 
+""" + +import asyncio +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Tuple + +import aiohttp + +from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, + MetadataServerConfig, ServerRole, + get_ctx_gen_server_addrs) +from tensorrt_llm.logger import logger +from tensorrt_llm.serve.cluster_storage import (ClusterStorage, WatchEventType, + create_cluster_storage) +from tensorrt_llm.serve.disagg_auto_scaling import (DisaggClusterManager, + WorkerInfo) +from tensorrt_llm.serve.metadata_server import create_metadata_server +from tensorrt_llm.serve.openai_client import OpenAIClient +from tensorrt_llm.serve.router import (CoordinatorDelegatingRouter, Router, + build_disagg_routers) + +__all__ = [ + "DisaggCoordinator", + "DisaggCoordinatorService", + "CoordinatorClient", +] + + +class DisaggCoordinator(ABC): + """Abstract coordinator: ctx/gen routers + readiness + cluster info + lifecycle. + + Placement and finish are driven through ``ctx_router`` / ``gen_router`` + (``Router.get_next_server`` / ``Router.finish_request``), so this surface only + exposes the routers plus readiness/info/lifecycle. + """ + + @property + @abstractmethod + def ctx_router(self) -> Router: + ... + + @property + @abstractmethod + def gen_router(self) -> Router: + ... + + @abstractmethod + async def is_ready(self) -> bool: + ... + + @abstractmethod + async def cluster_info(self) -> Dict[str, Any]: + ... + + async def start(self) -> None: + ... + + async def stop(self) -> None: + ... + + +class DisaggCoordinatorService(DisaggCoordinator): + """In-process coordinator owning the ctx/gen routers and all cluster state. + + Used in the coordinator process and in the single-process (workers==1) path. 
+ """ + + def __init__( + self, + config: DisaggServerConfig, + client_factory, + metadata_config: Optional[MetadataServerConfig] = None, + server_preparation_func=None, + server_start_timeout_secs: int = 180, + health_check_interval_secs: int = 3, + ): + self._config = config + self._client_factory = client_factory + self._metadata_config = metadata_config + # The coordinator OWNS routing state, so it builds the owner routers here + # (is_delegating_client=False): a centralized deployment gets ONE shared + # namespace-aware core and starts its single ZMQ ingest server exactly + # once. This is the sole place owner routers are created -- the fleet + # worker holds no owner router, only a CoordinatorClient's delegating + # surfaces. server_preparation_func (e.g. steady-clock sync) is wired into + # the routers at build time. + self._metadata_server = create_metadata_server(metadata_config) + ctx_servers, gen_servers = get_ctx_gen_server_addrs( + config.server_configs) + self._ctx_router, self._gen_router = build_disagg_routers( + config.ctx_router_config, config.gen_router_config, + ctx_servers, gen_servers, metadata_config, self._metadata_server, + server_preparation_func, disagg_node_id=config.node_id, + is_delegating_client=False) + # The coordinator owns the disagg cluster storage (auto-scaling backend): + # it drives the DisaggClusterManager below and, when the storage is an + # in-process HTTP server, its routes are mounted on the coordinator app. 
+ self._cluster_storage: Optional[ClusterStorage] = ( + create_cluster_storage(config.disagg_cluster_config.cluster_uri, + config.disagg_cluster_config.cluster_name) + if config.disagg_cluster_config else None) + self._server_start_timeout_secs = server_start_timeout_secs + self._health_check_interval_secs = health_check_interval_secs + + self._ctx_client: Optional[OpenAIClient] = None + self._gen_client: Optional[OpenAIClient] = None + self._disagg_cluster_manager: Optional[DisaggClusterManager] = None + + @property + def ctx_router(self) -> Router: + return self._ctx_router + + @property + def gen_router(self) -> Router: + return self._gen_router + + @property + def cluster_storage(self) -> Optional[ClusterStorage]: + return self._cluster_storage + + def set_clients(self, ctx_client: OpenAIClient, + gen_client: OpenAIClient) -> None: + self._ctx_client = ctx_client + self._gen_client = gen_client + + # -- coordinator-path placement (workers call these via the HTTP server) -- + + async def select(self, role: str, routing_key, req_id, + exclude_server: Optional[str]) -> Tuple[str, dict, Optional[str]]: + router = self._router_for_role(role) + return await router.get_next_server_by_key(routing_key, req_id=req_id, + exclude_server=exclude_server) + + async def finish(self, role: str, req_id, + success: bool = True) -> None: + await self._router_for_role(role).finish_request_by_id(req_id, success) + + def _router_for_role(self, role: str) -> Router: + return (self._ctx_router + if str(role).lower().startswith("c") else self._gen_router) + + async def start(self) -> None: + await self._ctx_router.prepare_servers() + await self._gen_router.prepare_servers() + if self._ctx_client is None or self._gen_client is None: + self._ctx_client = self._client_factory( + self._ctx_router, ServerRole.CONTEXT, self._config.max_retries) + self._gen_client = self._client_factory( + self._gen_router, ServerRole.GENERATION, + self._config.max_retries) + + if 
self._config.disagg_cluster_config and self._cluster_storage: + logger.info("Starting disagg cluster manager") + self._disagg_cluster_manager = DisaggClusterManager( + self._config.disagg_cluster_config, self._cluster_storage) + await self._disagg_cluster_manager.start() + await self._disagg_cluster_manager.watch_workers( + on_event=self._on_worker_event) + logger.info("Disagg cluster manager started") + else: + if self._metadata_server and self._metadata_config: + logger.info("Starting server monitoring via metadata service") + await self._ctx_router.start_server_monitoring( + self._metadata_config.refresh_interval) + await self._gen_router.start_server_monitoring( + self._metadata_config.refresh_interval) + await self._wait_for_all_servers_ready() + + async def stop(self) -> None: + if self._disagg_cluster_manager: + await self._disagg_cluster_manager.stop() + if self._metadata_server: + await self._ctx_router.stop_server_monitoring() + await self._gen_router.stop_server_monitoring() + + async def is_ready(self) -> bool: + if self._disagg_cluster_manager: + return await self._disagg_cluster_manager.is_ready_with_router( + self._ctx_router.num_prepared_servers, + self._gen_router.num_prepared_servers, + ) + return True + + async def cluster_info(self) -> Dict[str, Any]: + info = {"is_ready": await self.is_ready()} + # Expose the block-hash granularity so a delegating client hashes with + # the SAME tokens_per_block as the workers (the owner adopts it from + # /server_info). Clients don't monitor servers, so they can't learn it + # otherwise. 
+ tpb = getattr(self._ctx_router, "_tokens_per_block", None) + if tpb is not None: + info["tokens_per_block"] = tpb + if self._disagg_cluster_manager: + info.update(await self._disagg_cluster_manager.cluster_info()) + return info + + async def _wait_for_all_servers_ready(self) -> None: + import os + gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1" + + async def check_servers_ready(): + elapsed_time = 0 + interval = self._health_check_interval_secs + while elapsed_time < self._server_start_timeout_secs: + if gen_only: + unready_ctx_servers = [] + else: + _, unready_ctx_servers = await self._ctx_client.check_ready() + _, unready_gen_servers = await self._gen_client.check_ready() + if len(unready_ctx_servers) == 0 and len( + unready_gen_servers) == 0: + logger.info("All servers are ready" if not gen_only else + "Generation servers are ready (context skipped)") + return + logger.info( + f"Waiting for servers, context: {unready_ctx_servers}, " + f"generation: {unready_gen_servers}") + await asyncio.sleep(interval) + elapsed_time += interval + + try: + await asyncio.wait_for(check_servers_ready(), + timeout=self._server_start_timeout_secs) + except asyncio.TimeoutError: + raise TimeoutError( + "Timeout waiting for context and generation servers to be ready") + + async def _on_worker_event(self, worker_info: WorkerInfo, + event_type: WatchEventType): + router_map = { + ServerRole.CONTEXT: self._ctx_router, + ServerRole.GENERATION: self._gen_router, + } + worker_addr = f"{worker_info.host}:{worker_info.port}" + try: + router = router_map[worker_info.role] + if event_type == WatchEventType.SET: + await router.add_server(worker_addr) + elif event_type == WatchEventType.DELETE: + await router.remove_server(worker_addr) + logger.info(f"Worker {event_type.name} event: " + f"{worker_info.worker_id}, {worker_addr}") + except KeyError: + logger.error( + f"Unknown worker role: {worker_info.role}, Worker " + f"{worker_info.worker_id} event: {event_type.name}") + + 
+class CoordinatorClient(DisaggCoordinator): + """Worker-side coordinator: delegate stateful routing to the coordinator. + + A *stateful* router (conversation, centralized -- it exposes + ``get_next_server_by_key``) is wrapped in a :class:`CoordinatorDelegatingRouter` + so the worker computes the small routing key locally and the coordinator makes + the placement (placement -> ``/select``, finish -> ``/finish``). A *stateless* + router (round_robin, load_balancing) is used as-is and places locally in the + worker -- no coordinator round-trip. Readiness / cluster_info always proxy the + coordinator over HTTP. + + Args: + remote_url: Coordinator base URL (e.g. ``http://host:PORT``). + config: The disagg config; the client builds its own delegating routers + of the configured type (same config as the coordinator so the keys it + extracts line up). + """ + + def __init__(self, remote_url: str, config: DisaggServerConfig, + metadata_config: Optional[MetadataServerConfig] = None, + request_timeout_s: float = 5.0, + startup_timeout_s: float = 180.0): + self._remote_url = remote_url.rstrip("/") + self._request_timeout_s = request_timeout_s + self._startup_timeout_s = startup_timeout_s + self._session: Optional[aiohttp.ClientSession] = None + # A delegating client builds coreless surfaces (is_delegating_client=True): + # they compute the routing key locally (routing_key()) but never bind an + # ingest port or own a core -- placement is delegated to the coordinator. + # This is the sole place delegating routers are created. 
+ ctx_servers, gen_servers = get_ctx_gen_server_addrs( + config.server_configs) + ctx_router, gen_router = build_disagg_routers( + config.ctx_router_config, config.gen_router_config, + ctx_servers, gen_servers, metadata_config, + create_metadata_server(metadata_config), + disagg_node_id=config.node_id, is_delegating_client=True) + self._ctx_router = self._maybe_delegate(ctx_router, "context") + self._gen_router = self._maybe_delegate(gen_router, "generation") + + def _maybe_delegate(self, local_router: Router, role: str) -> Router: + # Stateful routers expose get_next_server_by_key -> delegate placement to + # the coordinator; stateless ones place locally (used unchanged). + if hasattr(local_router, "get_next_server_by_key"): + return CoordinatorDelegatingRouter(self._remote_url, local_router, + role, self._request_timeout_s) + return local_router + + @property + def ctx_router(self) -> Router: + return self._ctx_router + + @property + def gen_router(self) -> Router: + return self._gen_router + + @property + def session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = aiohttp.ClientSession() + return self._session + + async def start(self) -> None: + # Fail fast: a delegating server is useless without its coordinator. Probe + # /cluster_info with bounded retry; if it never becomes reachable within + # startup_timeout_s, raise so the server exits non-zero instead of coming + # up and 500-ing every request against a missing coordinator. + info = await self._await_coordinator() + # A delegating client doesn't monitor servers, so it can't learn the + # workers' tokens_per_block from /server_info. Fetch it from the + # coordinator (above) and apply it to the local routers' block-hashing so + # the keys the client computes line up with the coordinator/workers. 
+ tpb = info.get("tokens_per_block") + if tpb is not None: + for router in (self._ctx_router, self._gen_router): + local = getattr(router, "_local", router) + if getattr(local, "_tokens_per_block", None) != tpb: + local._tokens_per_block = tpb + local._tpb_auto = False + logger.info( + f"CoordinatorClient: adopted coordinator " + f"tokens_per_block={tpb} for {getattr(local, '_namespace', '?')}") + + async def _await_coordinator(self) -> Dict[str, Any]: + """Poll the coordinator's /cluster_info until reachable, or raise. + + Returns the cluster_info dict once the coordinator answers HTTP 200. A 200 + means the coordinator process is up (not that all workers are ready -- + that's is_ready's job). Raises RuntimeError if it stays unreachable past + startup_timeout_s so the delegating server fails fast instead of serving + against a missing coordinator.""" + loop = asyncio.get_event_loop() + deadline = loop.time() + self._startup_timeout_s + attempt = 0 + while True: + try: + async with self.session.get( + f"{self._remote_url}/cluster_info", + timeout=self._request_timeout_s) as resp: + if resp.status == 200: + logger.info("CoordinatorClient: coordinator reachable at " + f"{self._remote_url}") + return await resp.json() + last_err = f"HTTP {resp.status}" + except Exception as e: # noqa: BLE001 + last_err = str(e) + attempt += 1 + if loop.time() >= deadline: + raise RuntimeError( + f"Coordinator at {self._remote_url} not reachable after " + f"{self._startup_timeout_s}s ({attempt} attempts, last " + f"error: {last_err}); aborting delegating server startup") + logger.info(f"CoordinatorClient: waiting for coordinator at " + f"{self._remote_url} (attempt {attempt}, {last_err})") + await asyncio.sleep(2.0) + + async def is_ready(self) -> bool: + try: + async with self.session.get( + f"{self._remote_url}/health", + timeout=self._request_timeout_s) as resp: + return resp.status == 200 + except Exception as e: # noqa: BLE001 + logger.warning(f"CoordinatorClient health check 
failed: {e}") + return False + + async def cluster_info(self) -> Dict[str, Any]: + try: + async with self.session.get( + f"{self._remote_url}/cluster_info", + timeout=self._request_timeout_s) as resp: + if resp.status == 200: + return await resp.json() + except Exception as e: # noqa: BLE001 + logger.warning(f"CoordinatorClient cluster_info failed: {e}") + return {"is_ready": False} + + async def stop(self) -> None: + if self._session is not None: + await self._session.close() + self._session = None + await self._ctx_router.close() + await self._gen_router.close() diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 733b11a7aed7..06c43efbad4f 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -33,7 +33,6 @@ from tensorrt_llm.llmapi import tracing from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, MetadataServerConfig, ServerRole, - get_ctx_gen_server_addrs, get_global_disagg_request_id) from tensorrt_llm.logger import logger from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer, @@ -41,6 +40,8 @@ from tensorrt_llm.serve.conversation_id import resolve_request_conversation_id from tensorrt_llm.serve.metadata_server import create_metadata_server from tensorrt_llm.serve.openai_client import OpenAIClient, OpenAIHttpClient +from tensorrt_llm.serve.disagg_coordinator import (CoordinatorClient, + DisaggCoordinatorService) from tensorrt_llm.serve.openai_disagg_service import ( OpenAIDisaggregatedService, ResponseHooks) from tensorrt_llm.serve.openai_protocol import (UCompletionRequest, @@ -48,7 +49,7 @@ from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware, get_steady_clock_now_in_seconds) -from tensorrt_llm.serve.router import Router, create_router +from tensorrt_llm.serve.router import Router from tensorrt_llm.version import __version__ 
as VERSION # yapf: enale @@ -85,29 +86,43 @@ def __init__(self, req_timeout_secs: int = 180, server_start_timeout_secs: int = 180, metadata_server_cfg: Optional[MetadataServerConfig] = None, - metrics_interval_secs: int = 0): + metrics_interval_secs: int = 0, + coordinator_url: Optional[str] = None): self._config = config self._req_timeout_secs = req_timeout_secs self._server_start_timeout_secs = server_start_timeout_secs self._metadata_server_cfg = metadata_server_cfg self._metrics_interval_secs = metrics_interval_secs + # When set, this is a forked worker: routing/readiness are delegated to + # the coordinator at coordinator_url (CoordinatorClient). Otherwise this + # process owns the routers + cluster state (DisaggCoordinatorService). + self._coordinator_url = coordinator_url - self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs) - self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock, disagg_node_id=config.node_id) - self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock, disagg_node_id=config.node_id) - self._metadata_server = create_metadata_server(metadata_server_cfg) self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests) - self._disagg_cluster_storage = create_cluster_storage(config.disagg_cluster_config.cluster_uri, config.disagg_cluster_config.cluster_name) if config.disagg_cluster_config else None + # The server does NOT build routers. Router ownership is decided (and the + # routers built) by the coordinator object: DisaggCoordinatorService is + # the owner (builds core + ingest); CoordinatorClient is the delegating + # client (builds coreless surfaces). The server just holds whichever one + # matches its deployment and reads .ctx_router / .gen_router off it. 
+ if self._coordinator_url: + self._coordinator = CoordinatorClient( + self._coordinator_url, self._config, metadata_server_cfg, + request_timeout_s=self._req_timeout_secs, + startup_timeout_s=self._server_start_timeout_secs) + else: + self._coordinator = DisaggCoordinatorService( + self._config, self._create_client, + metadata_config=self._metadata_server_cfg, + server_preparation_func=self._sync_server_clock, + server_start_timeout_secs=self._server_start_timeout_secs) + self._ctx_router = self._coordinator.ctx_router + self._gen_router = self._coordinator.gen_router self._service = OpenAIDisaggregatedService( - self._config, self._ctx_router, self._gen_router, self._create_client, - metadata_server=self._metadata_server, - metadata_config=self._metadata_server_cfg, + self._config, self._coordinator, self._create_client, req_timeout_secs=self._req_timeout_secs, - server_start_timeout_secs=self._server_start_timeout_secs, - perf_metrics_collector=self._perf_metrics_collector, - disagg_cluster_storage=self._disagg_cluster_storage) + perf_metrics_collector=self._perf_metrics_collector) try: otlp_cfg = config.otlp_config @@ -122,9 +137,7 @@ def __init__(self, @asynccontextmanager async def lifespan(app) -> None: - # Prepare servers (sync server clock) when static ctx/gen server list is used - await self._ctx_router.prepare_servers() - await self._gen_router.prepare_servers() + # The cluster manager (via setup) owns server preparation + monitoring. await self._service.setup() yield await self._service.teardown() @@ -149,6 +162,9 @@ def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) return client def register_routes(self): + # The disagg service owns only the request-serving endpoints (/v1/*) and + # perf metrics. Readiness / cluster topology are the coordinator's state, + # so /health and /cluster_info hook straight to self._coordinator. 
self.app.add_api_route("/v1/completions", self._wrap_entry_point(self._service.openai_completion), methods=["POST"]) self.app.add_api_route("/v1/chat/completions", self._wrap_entry_point(self._service.openai_chat_completion), methods=["POST"]) self.app.add_api_route("/health", self.health, methods=["GET"]) @@ -158,8 +174,12 @@ def register_routes(self): # import prometheus_client lazily to break the `set_prometheus_multiproc_dir` from prometheus_client import make_asgi_app self.app.mount("/prometheus/metrics", make_asgi_app()) - if self._disagg_cluster_storage and isinstance(self._disagg_cluster_storage, HttpClusterStorageServer): - self._disagg_cluster_storage.add_routes(self.app) + # Single-process (local coordinator): mount the in-process HTTP cluster + # storage routes on this app. In worker mode the coordinator is remote and + # owns those routes (CoordinatorClient has no cluster_storage). + cluster_storage = getattr(self._coordinator, "cluster_storage", None) + if isinstance(cluster_storage, HttpClusterStorageServer): + cluster_storage.add_routes(self.app) @staticmethod def _extract_conversation_id(req: UCompletionRequest, raw_req: Request): @@ -205,12 +225,12 @@ def _handle_exception(self, exception): async def health(self) -> Response: - if not await self._service.is_ready(): + if not await self._coordinator.is_ready(): return Response(status_code=500) return Response(status_code=200) async def cluster_info(self) -> JSONResponse: - return JSONResponse(content=await self._service.cluster_info()) + return JSONResponse(content=await self._coordinator.cluster_info()) async def version(self) -> JSONResponse: return JSONResponse(content={"version": VERSION}) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index cf5b50af0fb3..69819bf3be62 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -20,17 +20,13 @@ from tensorrt_llm.llmapi.disagg_utils import ( 
ConditionalDisaggConfig, - DisaggClusterConfig, DisaggServerConfig, - MetadataServerConfig, ServerRole, get_global_disagg_request_id, ) from tensorrt_llm.logger import logger -from tensorrt_llm.serve.cluster_storage import ClusterStorage, WatchEventType -from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterManager, WorkerInfo -from tensorrt_llm.serve.metadata_server import JsonDictionary from tensorrt_llm.serve.openai_client import OpenAIClient +from tensorrt_llm.serve.disagg_coordinator import DisaggCoordinator from tensorrt_llm.serve.openai_protocol import ( ChatCompletionRequest, CompletionRequest, @@ -54,28 +50,23 @@ class OpenAIDisaggregatedService(OpenAIService): def __init__( self, config: DisaggServerConfig, - ctx_router: Router, - gen_router: Router, + coordinator: "DisaggCoordinator", client_factory: Callable[[Router, ServerRole], OpenAIClient], - metadata_server: Optional[JsonDictionary] = None, - metadata_config: Optional[MetadataServerConfig] = None, req_timeout_secs: int = 180, - server_start_timeout_secs: int = 180, perf_metrics_collector: Optional[DisaggPerfMetricsCollector] = None, - disagg_cluster_storage: Optional[ClusterStorage] = None, - health_check_interval_secs: int = 3, ): self._config = config - self._ctx_router = ctx_router - self._gen_router = gen_router + # The coordinator owns readiness, cluster info, and worker events. The + # service takes its ctx/gen routers and drives get_next_server / + # finish_request uniformly -- so serving is identical whether the router + # is the real one (single-process) or a CoordinatorDelegatingRouter that + # forwards placement to a remote coordinator (worker). 
+ self._cluster = coordinator + self._ctx_router = coordinator.ctx_router + self._gen_router = coordinator.gen_router self._client_factory = client_factory - self._metadata_server = metadata_server - self._metadata_config = metadata_config self._req_timeout_secs = req_timeout_secs - self._server_start_timeout_secs = server_start_timeout_secs self._perf_metrics_collector = perf_metrics_collector - self._cluster_storage = disagg_cluster_storage - self._health_check_interval_secs = health_check_interval_secs # Opt-in body-shrink for generation_only requests; see _get_gen_request. self._strip_gen_message_history = config.gen_strip_message_history # Opt-in: ask context workers to return prompt_token_ids as base64 int32. @@ -83,7 +74,6 @@ def __init__( self._ctx_client = None self._gen_client = None - self._disagg_cluster_manager = None self._schedule_style = DisaggScheduleStyle.CONTEXT_FIRST match self._config.schedule_style: @@ -423,110 +413,36 @@ async def _check_gen_only_disagg(self, request: UCompletionRequest) -> bool: return True return False - async def cluster_info(self) -> Dict[str, Any]: - cluster_info = {"is_ready": await self.is_ready()} - if self._disagg_cluster_manager: - cluster_info.update(await self._disagg_cluster_manager.cluster_info()) - return cluster_info - async def is_ready(self) -> bool: - if self._disagg_cluster_manager: - return await self._disagg_cluster_manager.is_ready_with_router( - self._ctx_router.num_prepared_servers, - self._gen_router.num_prepared_servers, - ) - return True - - @property - def disagg_cluster_config(self) -> Optional[DisaggClusterConfig]: - return self._config.disagg_cluster_config + # Per-request readiness gate for the /v1/ handlers (the server's /health + # and /cluster_info hook the coordinator directly). Cluster topology + # (cluster_info) is the coordinator's concern, not the request service's. 
+ return await self._cluster.is_ready() @property def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]: return self._config.conditional_disagg_config async def setup(self) -> None: + # Build the request-sending clients from the coordinator's routers + # (worker-mode: get_next_server / finish_request on those routers proxy + # the coordinator). Share them with the coordinator service so it can run + # readiness checks against the same pool (no-op on CoordinatorClient). self._ctx_client = self._client_factory( self._ctx_router, ServerRole.CONTEXT, self._config.max_retries ) self._gen_client = self._client_factory( - self._gen_router, ServerRole.GENERATION, self._config.max_retries + self._gen_router, ServerRole.GENERATION, + self._config.max_retries ) - - if self.disagg_cluster_config and self._cluster_storage: - logger.info("Starting disagg cluster manager") - self._disagg_cluster_manager = DisaggClusterManager( - self.disagg_cluster_config, self._cluster_storage - ) - await self._disagg_cluster_manager.start() - await self._disagg_cluster_manager.watch_workers(on_event=self._on_worker_event) - logger.info("Disagg cluster manager started") - else: - if self._metadata_server and self._metadata_config: - logger.info("Starting server monitoring via metadata service") - await self._ctx_router.start_server_monitoring( - self._metadata_config.refresh_interval - ) - await self._gen_router.start_server_monitoring( - self._metadata_config.refresh_interval - ) - await self._wait_for_all_servers_ready() + if hasattr(self._cluster, "set_clients"): + self._cluster.set_clients(self._ctx_client, self._gen_client) + await self._cluster.start() async def teardown(self) -> None: await self._ctx_client.shutdown() await self._gen_client.shutdown() - - if self._disagg_cluster_manager: - await self._disagg_cluster_manager.stop() - - if self._metadata_server: - await self._ctx_router.stop_server_monitoring() - await self._gen_router.stop_server_monitoring() - - async 
def _wait_for_all_servers_ready(self) -> None: - # Skip context servers if TRTLLM_DISAGG_BENCHMARK_GEN_ONLY is set - gen_only = os.getenv("TRTLLM_DISAGG_BENCHMARK_GEN_ONLY") == "1" - - async def check_servers_ready(): - elapsed_time = 0 - interval = self._health_check_interval_secs - while elapsed_time < self._server_start_timeout_secs: - if gen_only: - unready_ctx_servers = [] - else: - _, unready_ctx_servers = await self._ctx_client.check_ready() - _, unready_gen_servers = await self._gen_client.check_ready() - if len(unready_ctx_servers) == 0 and len(unready_gen_servers) == 0: - if gen_only: - logger.info("Generation servers are ready (context servers skipped)") - else: - logger.info("All servers are ready") - return - logger.info( - f"Waiting for servers, context: {unready_ctx_servers}, generation: {unready_gen_servers}" - ) - await asyncio.sleep(interval) - elapsed_time += interval - - try: - await asyncio.wait_for(check_servers_ready(), timeout=self._server_start_timeout_secs) - except asyncio.TimeoutError: - raise TimeoutError("Timeout waiting for context and generation servers to be ready") - - async def _on_worker_event(self, worker_info: WorkerInfo, event_type: WatchEventType): - router_map = {ServerRole.CONTEXT: self._ctx_router, ServerRole.GENERATION: self._gen_router} - worker_addr = f"{worker_info.host}:{worker_info.port}" - try: - router = router_map[worker_info.role] - if event_type == WatchEventType.SET: - await router.add_server(worker_addr) - elif event_type == WatchEventType.DELETE: - await router.remove_server(worker_addr) - logger.info(f"Worker {event_type.name} event: {worker_info.worker_id}, {worker_addr}") - except KeyError: - logger.error( - f"Unknown worker role: {worker_info.role}, Worker {worker_info.worker_id} event: {event_type.name}" - ) + await self._cluster.stop() async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None: if ctx_response: @@ -566,7 +482,8 @@ async def _send_disagg_request_gen_first( ctx_req, 
gen_req = None, None disagg_request_id = get_global_disagg_request_id(self._config.node_id) if need_ctx: - ctx_server, ctx_server_info = await self._ctx_router.get_next_server(request) + ctx_server, ctx_server_info = await self._ctx_router.get_next_server( + request) ctx_req = self._get_ctx_request(request, disagg_request_id) gen_req = self._get_gen_request( request, diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index cfae37de17f5..8cd63f8723a4 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -245,6 +245,9 @@ def __init__( self.disagg_cluster_worker = None self.resource_governor = None + self._active_request_count = 0 + self._active_request_count_lock = __import__('threading').Lock() + # Skip loading AutoProcessor and model_config for VISUAL_GEN models # These are LLM-specific and can cause unnecessary memory usage if self._is_visual_gen: @@ -342,6 +345,21 @@ async def validation_exception_handler(_, exc): self.app.add_middleware(ServerArrivalTimeMiddleware) + # Track active /v1/ requests (in-flight request gauge). 
+ server_ref = self + + @self.app.middleware("http") + async def _track_active_requests(request, call_next): + if request.url.path.startswith("/v1/"): + with server_ref._active_request_count_lock: + server_ref._active_request_count += 1 + try: + return await call_next(request) + finally: + with server_ref._active_request_count_lock: + server_ref._active_request_count -= 1 + return await call_next(request) + def _get_iteration_stats_buffer_maxlen(self) -> Optional[int]: if isinstance(self.generator, VisualGen): return None @@ -1908,6 +1926,9 @@ async def update_weights(self, async def get_server_info(self) -> JSONResponse: content = {"disaggregated_params": self.generator.disaggregated_params} + llm_id = getattr(self.generator, "llm_id", None) + if llm_id is not None: + content["worker_id"] = llm_id args = getattr(self.generator, "args", None) if args is not None: if args.max_batch_size is not None: diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 4673afbd4c7f..153b4693e84e 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -20,12 +20,7 @@ from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union import aiohttp -from transformers import AutoTokenizer -from tensorrt_llm.bindings.internal.batch_manager import \ - BlockKey as _NativeBlockKey -from tensorrt_llm.bindings.internal.batch_manager import \ - BlockKeyHasher as _NativeBlockKeyHasher from tensorrt_llm.llmapi.disagg_utils import (MetadataServerConfig, RouterConfig, ServerRole) from tensorrt_llm.logger import logger @@ -40,17 +35,13 @@ from tensorrt_llm.serve.metadata_server import JsonDictionary from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, CompletionRequest) - -KV_CACHE_HASH_ALGO_DEFAULT = kv_cache_hash.KV_CACHE_HASH_ALGO_DEFAULT -KV_CACHE_HASH_ALGO_V1 = kv_cache_hash.KV_CACHE_HASH_ALGO_V1 -KV_CACHE_HASH_ALGO_V2 = kv_cache_hash.KV_CACHE_HASH_ALGO_V2 -KV_CACHE_HASH_ALGO_V2_SHA256_64 = 
kv_cache_hash.KV_CACHE_HASH_ALGO_V2_SHA256_64 -get_cache_salt_id = kv_cache_hash.get_cache_salt_id -hash_v1_block_key = kv_cache_hash.hash_v1_block_key -truncate_sha256_hash_to_int64 = kv_cache_hash.truncate_sha256_hash_to_int64 - -OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest] -BlockHash = Union[int, str] +# Shared tokenization / block-hashing utilities (single source of truth). +# Re-exported here for backward compat. +from tensorrt_llm.serve.router_utils import ( # noqa: F401 + KV_CACHE_HASH_ALGO_DEFAULT, KV_CACHE_HASH_ALGO_V1, KV_CACHE_HASH_ALGO_V2, + KV_CACHE_HASH_ALGO_V2_SHA256_64, BlockHash, BlockHashMixin, OpenAIRequest, + block_key_hasher, get_cache_salt_id, get_request_num_tokens, + hash_v1_block_key, truncate_sha256_hash_to_int64, v2_sha256_block_hasher) # Max number of conversations whose home-server pin is retained (LRU). ROUTE_AFFINITY_CACHE_SIZE = 50000 @@ -59,32 +50,6 @@ ROUTE_AFFINITY_TOKEN_PREFIX = 256 -def get_request_num_tokens(request: OpenAIRequest) -> int: - if request.disaggregated_params is None or request.disaggregated_params.request_type == "context_only": - if isinstance(request, ChatCompletionRequest): - raise ValueError( - "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet" - ) - - if isinstance(request.prompt, str) or \ - (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)): - prompts = [request.prompt] - else: - prompts = request.prompt - - num_tokens = sum(len(prompt) for prompt in prompts) - elif request.disaggregated_params.request_type == "generation_only": - raise ValueError( - "LoadBalancing router with tokens doesn't support generation_only requests" - ) - else: - raise ValueError( - f"Unsupported request type: {request.disaggregated_params.request_type}" - ) - - return num_tokens - - class ServerState: def __init__( @@ -107,13 +72,17 @@ def _session(self) -> Optional[aiohttp.ClientSession]: return self._session_provider() if self._session_provider else None async 
def increment_load(self, request: OpenAIRequest): - num_tokens = get_request_num_tokens(request) if self._use_tokens else 0 + # request may be None on the coordinator-delegated path (no request + # object crosses the HTTP hop); token accounting is skipped then. + num_tokens = (get_request_num_tokens(request) + if self._use_tokens and request is not None else 0) async with self._lock: self._num_active_requests += 1 self._num_active_tokens += num_tokens async def decrement_load(self, request: OpenAIRequest): - num_tokens = get_request_num_tokens(request) if self._use_tokens else 0 + num_tokens = (get_request_num_tokens(request) + if self._use_tokens and request is not None else 0) async with self._lock: self._num_active_requests -= 1 self._num_active_tokens -= num_tokens @@ -141,6 +110,7 @@ def __init__( self._kv_cache_block_tables: dict[str, set[BlockHash]] = { KV_CACHE_HASH_ALGO_V1: self._kv_cache_block_table } + self._event_only_blocks: set[BlockHash] = set() self._kv_cache_hash_algo = KV_CACHE_HASH_ALGO_DEFAULT self._tokens_per_block = tokens_per_block self._poll_task: Optional[asyncio.Task] = None @@ -198,32 +168,50 @@ def update_with_events(self, events: Iterable[dict]): if event["type"] == "created": self.set_hash_algo(hash_algo) if event["type"] == "stored": - self.add_blocks( - (block["block_hash"] for block in event["blocks"]), - hash_algo=hash_algo) + block_hashes = [block["block_hash"] for block in event["blocks"]] + self.add_blocks(block_hashes, hash_algo=hash_algo) + self._event_only_blocks.update(block_hashes) elif event["type"] == "removed": self.remove_blocks(event["block_hashes"], hash_algo=hash_algo) + self._event_only_blocks.difference_update(event["block_hashes"]) async def poll_events(self, session: aiohttp.ClientSession): async with session.post( f"{self._base_url}/kv_cache_events") as response: events_raw = await response.json() + # DIAG: confirm which servers actually get polled and how many events + # each returns (diagnoses single-server 
block-table population). + logger.info( + f"POLL_DIAG server={self._server} " + f"n_events={len(events_raw) if events_raw is not None else 'None'}") return events_raw + _event_match_log_counter = 0 + async def matched_tokens( self, block_hashes: list[list[BlockHash]], hash_algo: str = KV_CACHE_HASH_ALGO_DEFAULT) -> int: match_count = 0 + event_match_count = 0 async with self._lock: block_table = self._block_table(hash_algo) for hash_list in block_hashes: for block_hash in hash_list: - # TODO: 1) parent hash verification, 2) partial matching if block_hash in block_table: match_count += self._tokens_per_block + if block_hash in self._event_only_blocks: + event_match_count += self._tokens_per_block else: break + KvCacheAwareServerState._event_match_log_counter += 1 + if KvCacheAwareServerState._event_match_log_counter <= 20 or KvCacheAwareServerState._event_match_log_counter % 100 == 0: + logger.info( + f"EVENT_MATCH_DIAG server={self._server} " + f"total_match={match_count} event_match={event_match_count} " + f"event_blocks={len(self._event_only_blocks)} " + f"total_blocks={len(block_table)} " + f"query_hashes={[h for hl in block_hashes for h in hl][:3]}") return match_count async def decrement_load(self, request: OpenAIRequest): @@ -369,6 +357,34 @@ def __init__( self._health_check_timeout = metadata_server_cfg.health_check_timeout if metadata_server_cfg else None self._server_preparation_func = server_preparation_func self._prepared_ready_servers: set[str] = set() + # Routing-latency diagnostics (gated by TLLM_LOG_ROUTE_TIMING=1). Records + # wall time spent in get_next_server per request and logs percentiles + # periodically. Lets us compare the per-request routing cost across + # router types. 
+ import os + self._log_route_timing = ( + os.environ.get("TLLM_LOG_ROUTE_TIMING", "0") == "1") + self._rt_samples: list = [] + self._rt_n = 0 + + def _record_route_timing(self, dt_s: float) -> None: + """Record one get_next_server latency sample; log percentiles every + 500 calls. No-op unless TLLM_LOG_ROUTE_TIMING=1.""" + if not self._log_route_timing: + return + self._rt_samples.append(dt_s * 1000.0) # ms + self._rt_n += 1 + if self._rt_n % 500 == 0: + import statistics + s = sorted(self._rt_samples) + n = len(s) + p = lambda q: s[min(int(q * n), n - 1)] + logger.info( + f"[route_timing] {type(self).__name__} n={self._rt_n} " + f"get_next_server_ms: mean={statistics.mean(s):.2f} " + f"p50={p(0.5):.2f} p90={p(0.9):.2f} p99={p(0.99):.2f} " + f"max={s[-1]:.2f}") + self._rt_samples = [] # reset window async def close(self): """Close the shared HTTP session.""" @@ -774,226 +790,6 @@ async def finish_request(self, await self._unregister_request(request) -def block_key_hasher(token_ids: list[int], - parent_hash: Optional[int] = None, - cache_salt_id: Optional[int] = None) -> int: - parent = 0 if parent_hash is None else parent_hash - # Fast path: the native C++ BlockKeyHasher is bit-exact with - # hash_v1_block_key and avoids the per-token Python loop. Its hash() binding - # takes no cache_salt_id, so fall back to Python only when a salt is set - # (rare opt-in; never in the unsalted agent/chat completion path). 
- if cache_salt_id is None: - return _NativeBlockKeyHasher.hash(_NativeBlockKey(token_ids), parent) - return hash_v1_block_key(token_ids, - parent_hash=parent, - cache_salt_id=cache_salt_id) - - -def v2_sha256_block_hasher(token_ids: list[int], - parent_hash: Optional[str] = None, - cache_salt_id: Optional[int] = None) -> str: - parent_key = (V2RootBlock.make_key(ReuseScope(salt=cache_salt_id)) - if parent_hash is None else bytes.fromhex(parent_hash)) - return V2Block.make_key(parent_key, token_ids).hex() - - -class BlockHashMixin: - """Shared tokenization and block-hash computation. - - Used by routers that need KV-cache-aware prefix matching. - """ - - def _init_block_hashing(self, - tokens_per_block: Optional[int] = None, - custom_tokenizer: Optional[str] = None): - env_tokens_per_block = os.environ.get( - "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK") - if env_tokens_per_block is not None: - tokens_per_block = int(env_tokens_per_block) - self._tpb_auto = tokens_per_block is None - self._tokens_per_block = 32 if tokens_per_block is None \ - else tokens_per_block - self._tokenizers: dict = {} - self._custom_tokenizer = custom_tokenizer - logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}" - f"{' (auto, adopts worker)' if self._tpb_auto else ''}" - f", custom_tokenizer={self._custom_tokenizer}") - - def _get_tokenizer(self, model: str): - if model not in self._tokenizers: - if self._custom_tokenizer: - from tensorrt_llm.tokenizer import load_custom_tokenizer - self._tokenizers[model] = load_custom_tokenizer( - self._custom_tokenizer, model) - else: - from tensorrt_llm.tokenizer import \ - maybe_fix_byte_level_tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model, trust_remote_code=True) - # Work around Transformers 5.x LlamaTokenizer overriding - # tokenizer.json's ByteLevel pre-tokenizer with Metaspace, - # which silently strips spaces from prompts (see tokenizer.py). 
- self._tokenizers[model] = maybe_fix_byte_level_tokenizer( - tokenizer, model, trust_remote_code=True) - return self._tokenizers[model] - - def _encode_with_prefix_cache(self, rendered: str, key: int, - tokenizer) -> list[int]: - cache = getattr(self, "_tok_prefix_cache", None) - if cache is None: - cache = self._tok_prefix_cache = OrderedDict() - entry = cache.get(key) - if entry is not None and len(rendered) > len(entry[0]) and \ - rendered.startswith(entry[0]): - ids = entry[1] + tokenizer.encode(rendered[len(entry[0]):], - add_special_tokens=False) - else: - ids = tokenizer.encode(rendered, add_special_tokens=False) - cache[key] = (rendered, ids) - cache.move_to_end(key) - while len(cache) > 1024: - cache.popitem(last=False) - return ids - - def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: - # Handle ChatCompletionRequest (has messages, not prompt) - if isinstance(request, ChatCompletionRequest): - if request.prompt_token_ids is not None: - return [request.prompt_token_ids] - tokenizer = self._get_tokenizer(request.model) - tool_dicts = (None if getattr(request, "tools", None) is None else [ - tool.model_dump() if hasattr(tool, "model_dump") else tool - for tool in request.tools - ]) - chat_template_kwargs = (request.chat_template_kwargs if getattr( - request, "chat_template_kwargs", None) else {}) - rendered = tokenizer.apply_chat_template( - [ - msg if isinstance(msg, dict) else dict(msg) - for msg in request.messages - ], - add_generation_prompt=request.add_generation_prompt, - tokenize=False, - return_dict=False, - tools=tool_dicts, - **chat_template_kwargs, - ) - if isinstance(rendered, str): - key = hash("".join( - str( - msg.get("content") if isinstance(msg, dict) else - getattr(msg, "content", "")) - for msg in request.messages[:2])) - result = self._encode_with_prefix_cache(rendered, key, - tokenizer) - else: - result = list(rendered) - request.prompt_token_ids = result - return [result] - - # Handle CompletionRequest (has prompt) - 
prompts = request.prompt - if isinstance(prompts, list) and isinstance(prompts[0], list): - return prompts - elif isinstance(prompts, list) and isinstance(prompts[0], int): - return [prompts] - elif isinstance(prompts, str): - prompts = [prompts] - else: - assert isinstance(prompts, list) and isinstance(prompts[0], str) - - tokenizer = self._get_tokenizer(request.model) - token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts] - # Replace string prompts with token IDs so the worker server - # skips re-tokenization - request.prompt = (token_lists - if len(token_lists) > 1 else token_lists[0]) - return token_lists - - def _compute_block_hashes( - self, - token_lists: list[list[int]], - hash_algo: str = KV_CACHE_HASH_ALGO_DEFAULT, - cache_salt_id: Optional[int] = None, - ) -> list[list[BlockHash]]: - if hash_algo == KV_CACHE_HASH_ALGO_V1: - block_hasher = block_key_hasher - elif hash_algo == KV_CACHE_HASH_ALGO_V2: - block_hasher = v2_sha256_block_hasher - elif hash_algo == KV_CACHE_HASH_ALGO_V2_SHA256_64: - reuse_scope = ReuseScope(salt=cache_salt_id) - block_hashes: list[list[BlockHash]] = [] - for token_list in token_lists: - hash_list = [] - parent_key = V2RootBlock.make_key(reuse_scope) - for t in range(0, len(token_list) - 1, self._tokens_per_block): - t_end = min(t + self._tokens_per_block, len(token_list) - 1) - parent_key = V2Block.make_key(parent_key, - token_list[t:t_end]) - hash_list.append(truncate_sha256_hash_to_int64(parent_key)) - block_hashes.append(hash_list) - return block_hashes - else: - raise ValueError( - f"Unsupported KV cache hash algorithm: {hash_algo}") - - block_hashes: list[list[BlockHash]] = [] - for token_list in token_lists: - hash_list = [] - # in KvCacheManager, the last token is not included in the block key - for t in range(0, len(token_list) - 1, self._tokens_per_block): - t_end = min(t + self._tokens_per_block, len(token_list) - 1) - hash_list.append( - block_hasher(token_list[t:t_end], - None if t == 0 else 
hash_list[-1], - cache_salt_id)) - block_hashes.append(hash_list) - return block_hashes - - def _tokenize_and_compute_block_hashes( - self, - request: OpenAIRequest) -> tuple[list[list[int]], list[list[int]]]: - """Synchronous tokenize + block-hash, combined for thread offload. - - Factored into one method so ``get_next_server`` can offload the whole - CPU-bound step via ``asyncio.to_thread`` in a single call, keeping - the orchestrator's asyncio event loop free to dispatch other - requests in parallel. - """ - token_lists = self._tokenize(request) - block_hashes = self._compute_block_hashes(token_lists) - return token_lists, block_hashes - - def _tokenize_and_compute_block_hashes_by_algo( - self, - request: OpenAIRequest, - hash_algos: Iterable[str], - cache_salt_id: Optional[int] = None, - ) -> tuple[list[list[int]], dict[str, list[list[BlockHash]]]]: - """Synchronous tokenize + per-algorithm block hashes for thread offload.""" - token_lists = self._tokenize(request) - return token_lists, { - hash_algo: - self._compute_block_hashes(token_lists, - hash_algo, - cache_salt_id=cache_salt_id) - for hash_algo in set(hash_algos) - } - - @staticmethod - def _text_to_int_sequences(texts: list[str]) -> list[list[int]]: - """Convert text strings to lists of unicode code points. - - Usable as input to ``_compute_block_hashes``. 
- """ - return [[ord(c) for c in text] for text in texts] - - @staticmethod - def _get_request_cache_salt_id(request: OpenAIRequest) -> Optional[int]: - cache_salt = getattr(request, "cache_salt", None) - return None if cache_salt is None else get_cache_salt_id(cache_salt) - - class KvCacheAwareRouter(BlockHashMixin, LoadBalancingMixin, Router): _server_state_class = KvCacheAwareServerState @@ -1007,19 +803,23 @@ def __init__(self, max_batch_size: int = 64, tokens_per_block: Optional[int] = None, custom_tokenizer: Optional[str] = None, + tokenizer_dir: Optional[str] = None, track_routed_blocks: bool = True, load_weight: float = 0.25, load_cap: float = float("inf"), **kwargs): super().__init__(server_role, servers, metadata_server_cfg, metadata_server, **kwargs) - self._init_block_hashing(tokens_per_block, custom_tokenizer) + self._init_block_hashing(tokens_per_block, custom_tokenizer, + tokenizer_dir) self._init_load_balancing(servers, use_tokens) # TODO: use max_num_tokens? per server? self._max_batch_size = max_batch_size self._load_weight = load_weight self._load_cap = load_cap self._track_routed_blocks = track_routed_blocks + # request key -> (flat block hashes, hash_algo). Key is id(request) on the + # standalone path, the disagg req_id on the coordinator path. 
self._pending_routed_blocks: dict[int, tuple[list[BlockHash], str]] = {} def _create_server_state(self, server: str) -> KvCacheAwareServerState: @@ -1032,19 +832,19 @@ async def close(self): await state.cancel_poll_task() await super().close() - def _stash_routed_blocks_on_route(self, request: OpenAIRequest, + def _stash_routed_blocks_on_route(self, key: int, block_hashes: list[list[BlockHash]], hash_algo: str) -> None: if not self._track_routed_blocks: return flat = [h for hl in block_hashes for h in hl] - self._pending_routed_blocks[id(request)] = (flat, hash_algo) + self._pending_routed_blocks[key] = (flat, hash_algo) - def _apply_routed_blocks_on_finish(self, request: OpenAIRequest, + def _apply_routed_blocks_on_finish(self, key: int, server: Optional[str], success: bool) -> None: # Pop unconditionally to avoid leaks; apply only when eligible. - entry = self._pending_routed_blocks.pop(id(request), None) + entry = self._pending_routed_blocks.pop(key, None) if not (self._track_routed_blocks and success): return if entry is None: @@ -1133,66 +933,68 @@ async def get_next_server( self, request: OpenAIRequest, exclude_server: Optional[str] = None) -> tuple[str, dict]: - async with self._lock: - servers = list([ - server for server in self._server_state.keys() - if server != exclude_server - ]) - if not servers: - raise ValueError( - f"No available servers after excluding {exclude_server}") + # Standalone (in-process) entry point = routing_key(tokenize+hash) then + # the shared _route core -- the SAME core the coordinator path uses. + key = await asyncio.to_thread(self._routing_key_sync, request) + server, info, _handle = await self._route( + key, exclude_server=exclude_server, request=request) + return server, info + + def _routing_key_sync(self, request: OpenAIRequest) -> dict: + """Tokenize + per-algo block-hash (CPU-bound; run in a thread). 
Returns a + plain dict so the coordinator path can send it over HTTP unchanged.""" cache_salt_id = self._get_request_cache_salt_id(request) - hash_algo_by_server = { - server: self._get_server_hash_algo(server) - for server in servers - } - # Tokenize + block-hash is CPU-bound (~50 ms p50 for a 40 k-token - # chat request with a Rust-backed tokenizer). Running it directly - # inside the async handler blocks the orchestrator's event loop and - # serializes all concurrent requests through it; with HuggingFace - # tokenizers releasing the GIL, offloading to a thread lets multiple - # tokenize calls run in parallel and frees the event loop to - # dispatch HTTP traffic to the CTX/GEN workers meanwhile. - token_lists, block_hashes_by_algo = await asyncio.to_thread( - self._tokenize_and_compute_block_hashes_by_algo, request, - hash_algo_by_server.values(), cache_salt_id) - # select the server by (KV match - load), bounded by load_cap - workloads = [ - self._server_state[server].num_active_requests() - for server in servers - ] - load_fractions = [ - workloads[i] / self._max_batch_size for i in range(len(servers)) - ] - scores = [] - matches = [] - for i in range(len(servers)): - server = servers[i] - hash_algo = hash_algo_by_server[server] - block_hashes = block_hashes_by_algo[hash_algo] - # https://github.com/ai-dynamo/dynamo/blob/main/docs/kv_cache_routing.md#kv-cache-routing-and-load-balancing + # Hash for every algo any server might use (usually one). + algos = {self._get_server_hash_algo(s) + for s in self._server_state.keys()} or None + token_lists, block_hashes_by_algo = \ + self._tokenize_and_compute_block_hashes_by_algo( + request, algos, cache_salt_id) + return {"token_lists": token_lists, + "block_hashes_by_algo": block_hashes_by_algo, + "conv_key": self._content_affinity_key(request)} + + async def _route(self, key, exclude_server=None, request=None, req_id=None): + """THE single routing core, shared by the standalone and coordinator + paths. 
Scores each server by (matched_tokens/tokens_per_block - + load_weight*load), applies load_cap, conversation affinity and RR + tie-break, then registers load + remembers routed blocks. The ONLY + difference between the two callers is request identity: standalone passes + the request (load keyed by id(request), finished via finish_request); + the coordinator passes req_id (the disagg request id) and finishes via + finish_request_by_id(req_id). Returns (server, info, req_id-or-None).""" + import time as _time + _rt_t0 = _time.monotonic() + token_lists = (key or {}).get("token_lists") or [] + block_hashes_by_algo = (key or {}).get("block_hashes_by_algo") or {} + conv_key = (key or {}).get("conv_key") + async with self._lock: + servers = [s for s in self._server_state.keys() + if s != exclude_server] + if not servers: + raise ValueError( + f"No available servers after excluding {exclude_server}") + + def _hashes(server): + algo = self._get_server_hash_algo(server) + return algo, block_hashes_by_algo.get(algo, []) + + workloads = [self._server_state[s].num_active_requests() + for s in servers] + load_fractions = [workloads[i] / self._max_batch_size + for i in range(len(servers))] + scores, matches = [], [] + for i, server in enumerate(servers): + algo, bh = _hashes(server) matches.append(await self._server_state[server].matched_tokens( - block_hashes, hash_algo)) - score = matches[-1] / self._tokens_per_block - self._load_weight * \ - workloads[i] - scores.append(score) - # Optional hard cap: drop servers at/over load_cap; fall back to all if - # none remain. Disabled by default (load_cap=inf) to match the original - # score-only selection. 
- candidate_idx = [ - i for i, lf in enumerate(load_fractions) if lf < self._load_cap - ] - if not candidate_idx: - candidate_idx = list(range(len(servers))) - # Conversation affinity: pin all turns of a conversation (keyed by a - # content-derived prefix hash, no conversation-id header) to the server - # it first landed on, so a worker eviction shrinking the match score - # cannot scatter the conversation off its warm home. New conversations - # (no pin yet) fall through to the score, which balances them by load. + bh, algo)) + scores.append(matches[-1] / self._tokens_per_block + - self._load_weight * workloads[i]) + candidate_idx = [i for i, lf in enumerate(load_fractions) + if lf < self._load_cap] or list(range(len(servers))) affinity = getattr(self, "_route_affinity", None) if affinity is None: affinity = self._route_affinity = OrderedDict() - conv_key = self._content_affinity_key(request) winner = None if conv_key is not None: pinned = affinity.get(conv_key) @@ -1211,33 +1013,84 @@ async def get_next_server( affinity.move_to_end(conv_key) while len(affinity) > ROUTE_AFFINITY_CACHE_SIZE: affinity.popitem(last=False) - hash_algo = hash_algo_by_server[server] - block_hashes = block_hashes_by_algo[hash_algo] + hash_algo, block_hashes = _hashes(server) + + # Register load + remember routed blocks in the SAME maps the standalone + # path uses; only the KEY differs. Standalone keys by id(request); the + # coordinator path keys by the disagg request id (req_id) -- the sole id + # that crosses the HTTP hop on /finish. No invented token, no fallback, + # no parallel map. 
+ key = id(request) if req_id is None else req_id async with self._lock: - await self._register_request(server, request) - self._stash_routed_blocks_on_route(request, block_hashes, hash_algo) + await self._server_state[server].increment_load(request) + self._req_routing_table[key] = server + self._stash_routed_blocks_on_route(key, block_hashes, hash_algo) + self._record_route_timing(_time.monotonic() - _rt_t0) return server, { - "block_hashes": block_hashes, # list[list[int | str]] + "block_hashes": block_hashes, "hash_algo": hash_algo, - "token_lists": token_lists, # list[list[int]] - "matches": matches, # list[int] + "token_lists": token_lists, + "matches": matches, "server_info": self._server_info.get(server, {}), - } + }, req_id async def finish_request(self, request: OpenAIRequest, session: Optional[aiohttp.ClientSession] = None, success: bool = True): + # Standalone entry point: key by id(request); pass request so token-load + # accounting matches the increment_load(request) done at route time. + await self._finish(id(request), success, request=request, + session=session) + + async def _finish(self, key, success, request=None, session=None): + """THE single finish core, shared by finish_request (standalone, key = + id(request), request passed) and finish_request_by_id (coordinator, key = + disagg req_id, request=None -- symmetric with its increment_load(None)). 
+ Pops the SAME maps _route filled, decrements load, commits routed blocks + on success, and refreshes the block table.""" async with self._lock: - server = self._req_routing_table.pop(id(request), None) + server = self._req_routing_table.pop(key, None) if server is not None and server in self._server_state: await self._server_state[server].decrement_load(request) - self._apply_routed_blocks_on_finish(request, server, success) + self._apply_routed_blocks_on_finish(key, server, success) + self._poll_server_on_finish(server, session) + + def _poll_server_on_finish(self, server, session=None): + """Refresh a server's KV-cache block table from /kv_cache_events after a + request finishes. Shared by finish_request (standalone) and + finish_request_by_id (coordinator) so the block table always stays warm -- + the delegated path is inert without this (matches score 0).""" if (server is not None and server in self._server_state and self._events_aligned(server)): # Fire-and-forget; poll runs in background and coalesces per server. self._server_state[server].schedule_poll_and_update(session) + # ---- coordinator delegation: thin wrappers over the shared _route core --- + # Under the disagg coordinator (WEB_CONCURRENCY>1) the fleet worker computes + # routing_key() locally and the coordinator (which owns _server_state via + # /kv_cache_events polling) runs get_next_server_by_key(). Both go through the + # exact same _route core as the standalone get_next_server -- no divergence. + + def routing_key(self, request: OpenAIRequest): + """Worker-side: tokenize + block-hash. 
Same dict _route consumes, JSON- + serializable for the /select POST.""" + return self._routing_key_sync(request) + + async def get_next_server_by_key(self, routing_key, exclude_server=None, + req_id=None): + """Coordinator-side placement: the SAME _route core, keyed by the + caller's disagg request id (req_id) instead of id(request).""" + return await self._route(routing_key, exclude_server=exclude_server, + request=None, req_id=req_id) + + async def finish_request_by_id(self, req_id, success=True): + """Coordinator-side finish: the SAME _finish core, keyed by the disagg + request id instead of id(request).""" + if req_id is None: + return + await self._finish(req_id, success) + def _on_servers_updated(self, old_servers, new_servers): new_state = {} for server in new_servers: @@ -1386,6 +1239,9 @@ def __init__(self, } # id(request) -> (server, weight, monotonic_timestamp) self._req_content_entry: dict[int, tuple[str, int, float]] = {} + # Coordinator-delegated path only: disagg req_id -> server, between + # select and finish (id(request) can't cross the HTTP hop). + self._coord_pending: dict = {} # ── content-based load tracking ── @@ -1682,6 +1538,151 @@ async def finish_request(self, logger.debug(f"ConversationRouter: FINISH server={server}, " f"content_loads={loads}") + # -- coordinator-path: conversation_id-only sticky routing -- + # The coordinator has no request object, so per-request load is tracked by an + # opaque handle instead of id(request). Only explicit conversation_id sessions + # are supported over the coordinator (no implicit content match). 
+ + def routing_key(self, request: OpenAIRequest): + """The conversation_id (or None); no tokenization on the worker.""" + return self._get_conversation_id(request) + + async def get_next_server_by_key(self, routing_key, exclude_server=None, + req_id=None): + conv_id = routing_key + self._validate_servers_available() + async with self._lock: + entry = self._session_table.get(conv_id) if conv_id else None + if (entry is not None and entry[0] in self._server_state + and entry[0] != exclude_server): + server = entry[0] + self._session_table.move_to_end(conv_id) + else: + server = self._select_least_loaded(exclude_server) + if server is None: + raise ValueError( + f"No available servers after excluding {exclude_server}") + if conv_id: + self._update_session(conv_id, server, []) + # Request-count load (no request object at the coordinator). Keyed by + # the disagg req_id -- the sole id crossing the HTTP hop on /finish. + self._server_content_load[server] = ( + self._server_content_load.get(server, 0) + 1) + if req_id is not None: + self._coord_pending[req_id] = server + return server, {"server_info": self._server_info.get(server, {})}, req_id + + async def finish_request_by_id(self, req_id, success=True): + del success + if req_id is None: + return + async with self._lock: + server = self._coord_pending.pop(req_id, None) + if server and server in self._server_content_load: + self._server_content_load[server] = max( + 0, self._server_content_load[server] - 1) + + +class CoordinatorDelegatingRouter(Router): + """Worker-side Router that delegates placement to the disagg coordinator. + + Used only for *stateful* routers (conversation, kv_cache_aware): the worker + must not keep its own copy of that state, so it wraps a local router of the same + type and, for each request, computes the small ``routing_key`` locally and + POSTs it to the coordinator's ``/select``. ``finish_request`` POSTs the + returned handle to ``/finish`` so the coordinator releases per-request state. 
+ Server-pool / prepare / close operations delegate to the wrapped local router. + + Stateless routers (round_robin, load_balancing) are NOT wrapped -- the worker + holds the real router and places locally, so they never reach this class (see + ``CoordinatorClient``). ``OpenAIClient`` already drives + ``router.get_next_server`` / ``router.finish_request``, so the completions + service needs no worker-specific branching. + """ + + def __init__(self, coordinator_url: str, local_router: "Router", role: str, + request_timeout_s: float = 5.0): + # Intentionally NOT calling Router.__init__: this is a thin proxy whose + # server-pool state lives on the wrapped local router (see __getattr__). + self._coordinator_url = coordinator_url.rstrip("/") + self._local = local_router + self._role = role # "context" | "generation" + self._request_timeout_s = request_timeout_s + self._session: Optional[aiohttp.ClientSession] = None + + def __getattr__(self, name): + # servers / prepare_servers / num_prepared_servers / start_server_monitoring + # / routing_key / ... all delegate to the local router. + return getattr(self._local, name) + + @property + def session(self) -> aiohttp.ClientSession: + if self._session is None: + self._session = aiohttp.ClientSession() + return self._session + + def _on_servers_updated(self, old_servers, new_servers): + pass + + def _request_id(self, request: OpenAIRequest) -> int: + """The request's disagg id -- the sole cross-process key for select/finish. + Context requests carry disagg_request_id; generation requests inherit it + as ctx_request_id. 
OpenAIDisaggregatedService always sets it before + routing, so a missing id is a bug -- assert, don't paper over.""" + dp = request.disaggregated_params + assert dp is not None, "delegated routing requires disaggregated_params" + rid = (dp.disagg_request_id if self._role == "context" + else dp.ctx_request_id) + assert rid is not None, ( + f"delegated {self._role} routing requires a disagg request id " + f"(disagg_request_id/ctx_request_id) on the request") + return rid + + async def get_next_server( + self, + request: OpenAIRequest, + exclude_server: Optional[str] = None) -> tuple[str, dict]: + key = self._local.routing_key(request) + # Send the disagg request id as the sole cross-process request key; the + # coordinator keys its pending-request state by it for /finish. + payload = {"role": self._role, "routing_key": key, + "req_id": self._request_id(request), + "exclude_server": exclude_server} + async with self.session.post( + f"{self._coordinator_url}/select", json=payload, + timeout=self._request_timeout_s) as resp: + if resp.status != 200: + raise ValueError( + f"coordinator /select returned {resp.status}: " + f"{await resp.text()}") + body = await resp.json() + info = body.get("info") or {} + return body["server"], info + + async def finish_request(self, + request: OpenAIRequest, + session: Optional[aiohttp.ClientSession] = None, + success: bool = True): + del session + try: + async with self.session.post( + f"{self._coordinator_url}/finish", + json={"role": self._role, + "req_id": self._request_id(request), + "success": success}, + timeout=self._request_timeout_s) as resp: + if resp.status != 200: + logger.warning( + f"coordinator /finish returned {resp.status}") + except Exception as e: # noqa: BLE001 + logger.warning(f"CoordinatorDelegatingRouter finish failed: {e}") + + async def close(self): + if self._session is not None: + await self._session.close() + self._session = None + await self._local.close() + def create_router( router_config: 
Optional[RouterConfig], @@ -1728,3 +1729,35 @@ def create_router( metadata_server, server_preparation_func=server_preparation_func, **extra_args) + + +def build_disagg_routers( + ctx_router_config: Optional[RouterConfig], + gen_router_config: Optional[RouterConfig], + ctx_servers: Optional[List[str]], + gen_servers: Optional[List[str]], + metadata_server_cfg: Optional[MetadataServerConfig] = None, + metadata_server: Optional[JsonDictionary] = None, + server_preparation_func: Optional[Callable[[str], Awaitable[None]]] = None, + disagg_node_id: int = 0, + is_delegating_client: bool = False, +) -> tuple[Router, Router]: + """Build the ctx and gen routers for one disagg process. + + Each side is built independently via :func:`create_router`. Stateful router + types (conversation, kv_cache_aware) expose ``routing_key`` / + ``get_next_server_by_key``; when this process is a delegating client, the + caller (:class:`CoordinatorClient`) wraps those in a + :class:`CoordinatorDelegatingRouter` so placement is delegated to the + coordinator. Stateless types (round_robin, load_balancing) place locally. + ``is_delegating_client`` is accepted for call-site symmetry; router + construction itself does not depend on it. + """ + del is_delegating_client # no build-time behavior differs by this flag + ctx_router = create_router(ctx_router_config, ctx_servers, + metadata_server_cfg, metadata_server, + server_preparation_func, disagg_node_id) + gen_router = create_router(gen_router_config, gen_servers, + metadata_server_cfg, metadata_server, + server_preparation_func, disagg_node_id) + return ctx_router, gen_router diff --git a/tensorrt_llm/serve/router_utils.py b/tensorrt_llm/serve/router_utils.py new file mode 100644 index 000000000000..9484824c9430 --- /dev/null +++ b/tensorrt_llm/serve/router_utils.py @@ -0,0 +1,376 @@ +# Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared router utilities: request tokenization and KV-cache block hashing. + +Extracted from ``router.py`` so the surface routers there can share a single +implementation of block hashing without importing the whole router module. +""" + +import os +from collections import OrderedDict +from typing import Iterable, List, Optional, Union + +from transformers import AutoTokenizer + +from tensorrt_llm.bindings.internal.batch_manager import \ + BlockKey as _NativeBlockKey +from tensorrt_llm.bindings.internal.batch_manager import \ + BlockKeyHasher as _NativeBlockKeyHasher +from tensorrt_llm.logger import logger +from tensorrt_llm.runtime import kv_cache_hash +from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import \ + Block as V2Block +from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import \ + ReuseScope +from tensorrt_llm.runtime.kv_cache_manager_v2._block_radix_tree import \ + RootBlock as V2RootBlock +from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, + CompletionRequest) + +KV_CACHE_HASH_ALGO_DEFAULT = kv_cache_hash.KV_CACHE_HASH_ALGO_DEFAULT +KV_CACHE_HASH_ALGO_V1 = kv_cache_hash.KV_CACHE_HASH_ALGO_V1 +KV_CACHE_HASH_ALGO_V2 = kv_cache_hash.KV_CACHE_HASH_ALGO_V2 +KV_CACHE_HASH_ALGO_V2_SHA256_64 = kv_cache_hash.KV_CACHE_HASH_ALGO_V2_SHA256_64 +get_cache_salt_id = kv_cache_hash.get_cache_salt_id +hash_v1_block_key = kv_cache_hash.hash_v1_block_key 
+truncate_sha256_hash_to_int64 = kv_cache_hash.truncate_sha256_hash_to_int64 + +OpenAIRequest = Union[CompletionRequest, ChatCompletionRequest] +BlockHash = Union[int, str] + +__all__ = [ + "KV_CACHE_HASH_ALGO_DEFAULT", + "KV_CACHE_HASH_ALGO_V1", + "KV_CACHE_HASH_ALGO_V2", + "KV_CACHE_HASH_ALGO_V2_SHA256_64", + "get_cache_salt_id", + "hash_v1_block_key", + "truncate_sha256_hash_to_int64", + "OpenAIRequest", + "BlockHash", + "get_request_num_tokens", + "block_key_hasher", + "v2_sha256_block_hasher", + "BlockHashMixin", + "PrefixBlockSet", +] + + +class PrefixBlockSet: + """Single-owner block-hash index -- a flat ``set`` of held block hashes. + + A KV-cache block hash folds in its parent chain (the worker computes it with + ``BlockKeyHasher.hash(block_key, parent_hash)``), so every block-hash *value* + is globally unique to one position in one prefix path. A request's ordered + block-hash list is itself the prefix path, so longest-common-prefix against a + set of held blocks is just "walk the list until the first hash the owner + doesn't hold" -- no explicit tree needed. + + This is the exact structure the orchestrator ``KvCacheAwareServerState`` uses + (a ``set[block_hash]`` per server, walked until the first miss). It is the + right index whenever there is only ONE logical owner -- e.g. the centralized + router's per-instance ``combined_trie`` (owner = instance) and each rank's + trie (owner = that rank). Those are only ever queried for a single owner's + prefix depth (:meth:`match_one`), so a ``hash -> {owner}`` reverse map with + per-depth set intersections would be pure overhead here. + + The ``owner_id`` argument on :meth:`add` / :meth:`remove` / :meth:`match_one` + is accepted and ignored so this is a drop-in for the single-owner call sites. 
+ """ + + __slots__ = ("_blocks", ) + + def __init__(self) -> None: + self._blocks: set[int] = set() + + def add(self, owner_id: str, block_hashes: Iterable[int]) -> None: + self._blocks.update(block_hashes) + + def remove(self, owner_id: str, block_hashes: Iterable[int]) -> None: + self._blocks.difference_update(block_hashes) + + def remove_worker(self, owner_id: str) -> None: + self._blocks.clear() + + def match_one(self, owner_id: str, block_hashes: List[int]) -> int: + """Consecutive prefix-block count held by the (single) owner. + + Identical to ``KvCacheAwareServerState.matched_tokens``: walk the query + path, counting blocks present in the set, and stop at the first miss. + """ + blocks = self._blocks + depth = 0 + for h in block_hashes: + if h not in blocks: + break + depth += 1 + return depth + + def has_worker(self, owner_id: str) -> bool: + return bool(self._blocks) + + +def get_request_num_tokens(request: OpenAIRequest) -> int: + if request.disaggregated_params is None or request.disaggregated_params.request_type == "context_only": + if isinstance(request, ChatCompletionRequest): + raise ValueError( + "LoadBalancing router with tokens doesn't support ChatCompletionRequest yet" + ) + + if isinstance(request.prompt, str) or \ + (isinstance(request.prompt, list) and isinstance(request.prompt[0], int)): + prompts = [request.prompt] + else: + prompts = request.prompt + + num_tokens = sum(len(prompt) for prompt in prompts) + elif request.disaggregated_params.request_type == "generation_only": + raise ValueError( + "LoadBalancing router with tokens doesn't support generation_only requests" + ) + else: + raise ValueError( + f"Unsupported request type: {request.disaggregated_params.request_type}" + ) + + return num_tokens + + +def block_key_hasher(token_ids: list[int], + parent_hash: Optional[int] = None, + cache_salt_id: Optional[int] = None) -> int: + parent = 0 if parent_hash is None else parent_hash + # Fast path: the native C++ BlockKeyHasher is bit-exact 
with + # hash_v1_block_key and avoids the per-token Python loop. Its hash() binding + # takes no cache_salt_id, so fall back to Python only when a salt is set + # (rare opt-in; never in the unsalted agent/chat completion path). + if cache_salt_id is None: + return _NativeBlockKeyHasher.hash(_NativeBlockKey(token_ids), parent) + return hash_v1_block_key(token_ids, + parent_hash=parent, + cache_salt_id=cache_salt_id) + + +def v2_sha256_block_hasher(token_ids: list[int], + parent_hash: Optional[str] = None, + cache_salt_id: Optional[int] = None) -> str: + parent_key = (V2RootBlock.make_key(ReuseScope(salt=cache_salt_id)) + if parent_hash is None else bytes.fromhex(parent_hash)) + return V2Block.make_key(parent_key, token_ids).hex() + + +class BlockHashMixin: + """Shared tokenization and block-hash computation. + + Used by routers that need KV-cache-aware prefix matching. + """ + + def _init_block_hashing(self, + tokens_per_block: Optional[int] = None, + custom_tokenizer: Optional[str] = None, + tokenizer_dir: Optional[str] = None): + env_tokens_per_block = os.environ.get( + "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK") + if env_tokens_per_block is not None: + tokens_per_block = int(env_tokens_per_block) + self._tpb_auto = tokens_per_block is None + self._tokens_per_block = 32 if tokens_per_block is None \ + else tokens_per_block + self._tokenizers: dict = {} + self._custom_tokenizer = custom_tokenizer + self._tokenizer_dir = tokenizer_dir + logger.info(f"BlockHashMixin: tokens_per_block={self._tokens_per_block}" + f"{' (auto, adopts worker)' if self._tpb_auto else ''}" + f", custom_tokenizer={self._custom_tokenizer}") + + def _get_tokenizer(self, model: str): + if model not in self._tokenizers: + model_path = self._tokenizer_dir or model + if self._custom_tokenizer: + from tensorrt_llm.tokenizer import load_custom_tokenizer + self._tokenizers[model] = load_custom_tokenizer( + self._custom_tokenizer, model_path) + else: + from tensorrt_llm.tokenizer import \ + 
maybe_fix_byte_level_tokenizer + tokenizer = AutoTokenizer.from_pretrained( + model_path, trust_remote_code=True) + self._tokenizers[model] = maybe_fix_byte_level_tokenizer( + tokenizer, model_path, trust_remote_code=True) + return self._tokenizers[model] + + def _encode_with_prefix_cache(self, rendered: str, key: int, + tokenizer) -> list[int]: + cache = getattr(self, "_tok_prefix_cache", None) + if cache is None: + cache = self._tok_prefix_cache = OrderedDict() + entry = cache.get(key) + if entry is not None and len(rendered) > len(entry[0]) and \ + rendered.startswith(entry[0]): + ids = entry[1] + tokenizer.encode(rendered[len(entry[0]):], + add_special_tokens=False) + else: + ids = tokenizer.encode(rendered, add_special_tokens=False) + cache[key] = (rendered, ids) + cache.move_to_end(key) + while len(cache) > 1024: + cache.popitem(last=False) + return ids + + def _tokenize(self, request: OpenAIRequest) -> list[list[int]]: + # Handle ChatCompletionRequest (has messages, not prompt) + if isinstance(request, ChatCompletionRequest): + if request.prompt_token_ids is not None: + return [request.prompt_token_ids] + tokenizer = self._get_tokenizer(request.model) + tool_dicts = (None if getattr(request, "tools", None) is None else [ + tool.model_dump() if hasattr(tool, "model_dump") else tool + for tool in request.tools + ]) + chat_template_kwargs = (request.chat_template_kwargs if getattr( + request, "chat_template_kwargs", None) else {}) + rendered = tokenizer.apply_chat_template( + [ + msg if isinstance(msg, dict) else dict(msg) + for msg in request.messages + ], + add_generation_prompt=request.add_generation_prompt, + tokenize=False, + return_dict=False, + tools=tool_dicts, + **chat_template_kwargs, + ) + if isinstance(rendered, str): + key = hash("".join( + str( + msg.get("content") if isinstance(msg, dict) else + getattr(msg, "content", "")) + for msg in request.messages[:2])) + result = self._encode_with_prefix_cache(rendered, key, + tokenizer) + else: + result = 
list(rendered) + request.prompt_token_ids = result + return [result] + + # Handle CompletionRequest (has prompt) + prompts = request.prompt + if isinstance(prompts, list) and isinstance(prompts[0], list): + return prompts + elif isinstance(prompts, list) and isinstance(prompts[0], int): + return [prompts] + elif isinstance(prompts, str): + prompts = [prompts] + else: + assert isinstance(prompts, list) and isinstance(prompts[0], str) + + tokenizer = self._get_tokenizer(request.model) + token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts] + # Replace string prompts with token IDs so the worker server + # skips re-tokenization + request.prompt = (token_lists + if len(token_lists) > 1 else token_lists[0]) + return token_lists + + def _compute_block_hashes( + self, + token_lists: list[list[int]], + hash_algo: str = KV_CACHE_HASH_ALGO_DEFAULT, + cache_salt_id: Optional[int] = None, + ) -> list[list[BlockHash]]: + if hash_algo == KV_CACHE_HASH_ALGO_V1: + block_hasher = block_key_hasher + elif hash_algo == KV_CACHE_HASH_ALGO_V2: + block_hasher = v2_sha256_block_hasher + elif hash_algo == KV_CACHE_HASH_ALGO_V2_SHA256_64: + reuse_scope = ReuseScope(salt=cache_salt_id) + block_hashes: list[list[BlockHash]] = [] + for token_list in token_lists: + hash_list = [] + parent_key = V2RootBlock.make_key(reuse_scope) + for t in range(0, len(token_list) - 1, self._tokens_per_block): + t_end = min(t + self._tokens_per_block, len(token_list) - 1) + parent_key = V2Block.make_key(parent_key, + token_list[t:t_end]) + hash_list.append(truncate_sha256_hash_to_int64(parent_key)) + block_hashes.append(hash_list) + return block_hashes + else: + raise ValueError( + f"Unsupported KV cache hash algorithm: {hash_algo}") + + block_hashes: list[list[BlockHash]] = [] + for token_list in token_lists: + hash_list = [] + # in KvCacheManager, the last token is not included in the block key + for t in range(0, len(token_list) - 1, self._tokens_per_block): + t_end = min(t + 
self._tokens_per_block, len(token_list) - 1) + hash_list.append( + block_hasher(token_list[t:t_end], + None if t == 0 else hash_list[-1], + cache_salt_id)) + block_hashes.append(hash_list) + return block_hashes + + def _tokenize_and_compute_block_hashes( + self, + request: OpenAIRequest) -> tuple[list[list[int]], list[list[int]]]: + """Synchronous tokenize + block-hash, combined for thread offload. + + Factored into one method so ``get_next_server`` can offload the whole + CPU-bound step via ``asyncio.to_thread`` in a single call, keeping + the orchestrator's asyncio event loop free to dispatch other + requests in parallel. + """ + token_lists = self._tokenize(request) + block_hashes = self._compute_block_hashes(token_lists) + return token_lists, block_hashes + + def _tokenize_and_compute_block_hashes_with_salt( + self, request: OpenAIRequest, + cache_salt_id: Optional[int] = None, + ) -> tuple[list[list[int]], list[list[int]]]: + token_lists = self._tokenize(request) + block_hashes = self._compute_block_hashes( + token_lists, cache_salt_id=cache_salt_id) + return token_lists, block_hashes + + def _tokenize_and_compute_block_hashes_by_algo( + self, + request: OpenAIRequest, + hash_algos: Iterable[str], + cache_salt_id: Optional[int] = None, + ) -> tuple[list[list[int]], dict[str, list[list[BlockHash]]]]: + """Synchronous tokenize + per-algorithm block hashes for thread offload.""" + token_lists = self._tokenize(request) + return token_lists, { + hash_algo: + self._compute_block_hashes(token_lists, + hash_algo, + cache_salt_id=cache_salt_id) + for hash_algo in set(hash_algos) + } + + @staticmethod + def _text_to_int_sequences(texts: list[str]) -> list[list[int]]: + """Convert text strings to lists of unicode code points. + + Usable as input to ``_compute_block_hashes``. 
+ """ + return [[ord(c) for c in text] for text in texts] + + @staticmethod + def _get_request_cache_salt_id(request: OpenAIRequest) -> Optional[int]: + cache_salt = getattr(request, "cache_salt", None) + return None if cache_salt is None else get_cache_salt_id(cache_salt) diff --git a/tests/unittest/disaggregated/test_coordinator_e2e.py b/tests/unittest/disaggregated/test_coordinator_e2e.py new file mode 100644 index 000000000000..a6a7337ad4dc --- /dev/null +++ b/tests/unittest/disaggregated/test_coordinator_e2e.py @@ -0,0 +1,410 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""End-to-end coordinator/worker disagg serving with mocked ctx/gen workers. + +CPU-only, MPI-free, single-process (three uvicorn threads): + + * mocked ctx + gen HTTP workers serve ``/health`` + ``/v1/completions`` + (ctx returns a context_only response with disaggregated_params so the disagg + server proceeds to gen; gen returns the final completion text), + * a real ``CoordinatorServer`` (wrapping a ``DisaggCoordinatorService``) runs on + an internal port -- the gen router is a *stateful* conversation router, so gen + placement is delegated to it via ``/select``; the ctx router is round-robin + (placed locally in the disagg server), + * a real ``OpenAIDisaggServer`` in worker mode (``coordinator_url`` set, so it + holds a ``CoordinatorClient``) serves the public ``/v1/completions``. 
+ +A real HTTP completion is sent to the disagg server and must round-trip +ctx -> (coordinator /select) -> gen, returning the gen worker's text. This +exercises the whole chain including the coordinator HTTP hop. +""" + +import asyncio +import os +import subprocess +import sys +import tempfile +import threading +import time + +import aiohttp +import pytest +import uvicorn +import yaml +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response + +from tensorrt_llm.logger import logger + +from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig, + DisaggServerConfig, RouterConfig, + ServerRole) +from tensorrt_llm.serve.coordinator_server import CoordinatorServer +from tensorrt_llm.serve.disagg_coordinator import DisaggCoordinatorService +from tensorrt_llm.serve.openai_client import OpenAIHttpClient +from tensorrt_llm.serve.openai_disagg_server import OpenAIDisaggServer + +GEN_TEXT = "HELLO_FROM_GEN" + +# The uvicorn worker threads / CLI-output pump thread are background threads that +# outlive a strict thread snapshot; exempt this module (same as the other e2e). 
+pytestmark = pytest.mark.threadleak(enabled=False) + + +@pytest.fixture(autouse=True) +def _reset_prometheus_registry(): + """Role-prefixed Prometheus counters are registered in the global default + registry; clear it between tests so a second server build in the same pytest + process does not hit duplicate-timeseries errors.""" + from prometheus_client import REGISTRY + yield + for collector in list(REGISTRY._collector_to_names): + try: + REGISTRY.unregister(collector) + except Exception: + pass + + +def _free_port(): + import socket + s = socket.socket() + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.close() + return port + + +class _UvicornThread: + """Run a FastAPI app in a background uvicorn server thread.""" + + def __init__(self, app, port): + self.port = port + self._server = uvicorn.Server( + uvicorn.Config(app, host="127.0.0.1", port=port, + log_level="warning")) + self._thread = threading.Thread(target=self._server.run, daemon=True) + + def __enter__(self): + self._thread.start() + for _ in range(100): + if self._server.started: + break + time.sleep(0.1) + return self + + def __exit__(self, *a): + self._server.should_exit = True + self._thread.join(timeout=10) + + +def _mock_worker_app(role: str) -> FastAPI: + """A ctx or gen worker: /health + /server_info + /v1/completions.""" + app = FastAPI() + + @app.get("/health") + async def health(): + return Response(status_code=200) + + @app.get("/server_info") + async def server_info(): + return JSONResponse({"kv_cache_hash_algo": "v1"}) + + @app.post("/v1/completions") + async def completions(raw: Request): + body = await raw.json() + dp = body.get("disaggregated_params") or {} + model = body.get("model", "m") + if dp.get("request_type") == "context_only": + # Context phase: return disagg params so the disagg server proceeds + # to the gen worker (finish_reason "length" => needs generation). 
+ rid = dp.get("disagg_request_id") + return JSONResponse({ + "id": "cmpl-ctx", + "object": "text_completion", + "created": 0, + "model": model, + "prompt_token_ids": [1, 2, 3], + "choices": [{ + "index": 0, + "text": "", + "finish_reason": "length", + "disaggregated_params": { + "request_type": "context_only", + "ctx_request_id": rid, + "disagg_request_id": rid, + }, + }], + "usage": {"prompt_tokens": 3, "completion_tokens": 0, + "total_tokens": 3}, + }) + # Generation phase: final answer. + return JSONResponse({ + "id": "cmpl-gen", + "object": "text_completion", + "created": 0, + "model": model, + "choices": [{"index": 0, "text": GEN_TEXT, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 3, "completion_tokens": 2, + "total_tokens": 5}, + }) + + return app + + +class _ReadinessClient: + """Minimal client the coordinator uses only for server readiness probing. + + In coordinator/worker mode the coordinator never sends completions (the disagg + servers do), so its readiness client needs no metrics -- reusing the real + ``check_ready_for_servers`` keeps the probe faithful while avoiding a second + set of role-prefixed Prometheus counters in this single-process test. 
+ """ + + def __init__(self, router): + self._router = router + self._session = aiohttp.ClientSession() + + async def check_ready(self): + ready, unready = await OpenAIHttpClient.check_ready_for_servers( + self._session, self._router.servers) + if ready: + await self._router.prepare_servers(ready) + return ready, unready + + async def shutdown(self): + await self._session.close() + + +def _make_config(ctx_url, gen_url, public_port): + host_port = lambda u: (u.split(":")[0], int(u.split(":")[1])) + ctx_host, ctx_port = host_port(ctx_url) + gen_host, gen_port = host_port(gen_url) + return DisaggServerConfig( + server_configs=[ + CtxGenServerConfig(type="ctx", hostname=ctx_host, port=ctx_port), + CtxGenServerConfig(type="gen", hostname=gen_host, port=gen_port), + ], + hostname="127.0.0.1", + port=public_port, + # ctx: stateless (placed locally in the disagg server); + # gen: stateful conversation router (placement delegated to coordinator). + ctx_router_config=RouterConfig(type="round_robin", + server_role=ServerRole.CONTEXT), + gen_router_config=RouterConfig(type="conversation", + server_role=ServerRole.GENERATION)) + + +class _CoordinatorThread: + """Run a CoordinatorServer (DisaggCoordinatorService) in a uvicorn thread.""" + + def __init__(self, config): + self.port = _free_port() + self.url = f"http://127.0.0.1:{self.port}" + # The coordinator builds its own owner routers from config. 
+ self._coordinator = DisaggCoordinatorService( + config, + client_factory=lambda router, role, mr=1: _ReadinessClient(router)) + self._impl = _UvicornThread(CoordinatorServer(self._coordinator).app, + self.port) + + def __enter__(self): + self._impl.__enter__() + return self + + def __exit__(self, *a): + self._impl.__exit__(*a) + + +async def _wait_healthy(url, timeout_s=30.0): + deadline = time.time() + timeout_s + async with aiohttp.ClientSession() as sess: + while time.time() < deadline: + try: + async with sess.get(f"{url}/health", timeout=1) as r: + if r.status == 200: + return True + except Exception: + pass + await asyncio.sleep(0.2) + return False + + +def test_disagg_completion_e2e_through_coordinator(): + with _UvicornThread(_mock_worker_app("ctx"), _free_port()) as ctx, \ + _UvicornThread(_mock_worker_app("gen"), _free_port()) as gen: + ctx_url = f"127.0.0.1:{ctx.port}" + gen_url = f"127.0.0.1:{gen.port}" + public_port = _free_port() + config = _make_config(ctx_url, gen_url, public_port) + + with _CoordinatorThread(config) as coord: + assert asyncio.run(_wait_healthy(coord.url)), \ + "coordinator never became healthy" + + disagg = OpenAIDisaggServer(config=config, + coordinator_url=coord.url) + with _UvicornThread(disagg.app, public_port) as server: + base = f"http://127.0.0.1:{server.port}" + assert asyncio.run(_wait_healthy(base)), \ + "disagg server never became healthy" + + async def drive(): + async with aiohttp.ClientSession() as sess: + payload = {"model": "m", "prompt": "hello", + "max_tokens": 8} + # X-Session-ID -> conversation_id, so the gen router + # (conversation) delegates placement to the coordinator. + headers = {"X-Session-ID": "conv-e2e"} + async with sess.post(f"{base}/v1/completions", + json=payload, headers=headers, + timeout=30) as r: + assert r.status == 200, await r.text() + return await r.json() + + body = asyncio.run(drive()) + + # The full ctx -> coordinator/select -> gen chain returned the gen text. 
+ assert body["choices"][0]["text"] == GEN_TEXT, body + assert body["choices"][0]["finish_reason"] == "stop" + + +def _write_config(path, ctx_url, gen_url, public_port): + """A disagg config YAML: round-robin ctx, conversation gen (delegated).""" + cfg = { + "hostname": "127.0.0.1", + "port": public_port, + "context_servers": { + "num_instances": 1, + "urls": [ctx_url], + "router": {"type": "round_robin"}, + }, + "generation_servers": { + "num_instances": 1, + "urls": [gen_url], + "router": {"type": "conversation"}, + }, + } + with open(path, "w") as f: + yaml.safe_dump(cfg, f) + + +def test_disagg_completion_e2e_web_concurrency_4(): + """WEB_CONCURRENCY=4 through the real CLI: `trtllm-serve disaggregated` + forks a coordinator (port-1) + a uvicorn fleet of 4 disagg servers on the + public port. Mock ctx/gen workers run in this test process; a real HTTP + completion round-trips through one of the four workers -> coordinator -> gen. + """ + logger.set_level("info") # trtllm logger defaults to "error"; show progress + WORKERS = 4 + with _UvicornThread(_mock_worker_app("ctx"), _free_port()) as ctx, \ + _UvicornThread(_mock_worker_app("gen"), _free_port()) as gen: + ctx_url = f"127.0.0.1:{ctx.port}" + gen_url = f"127.0.0.1:{gen.port}" + # port-1 is the coordinator, so pick a public port with room below it. + public_port = _free_port() + coord_port = public_port - 1 + + with tempfile.TemporaryDirectory() as td: + cfg_path = os.path.join(td, "disagg.yaml") + _write_config(cfg_path, ctx_url, gen_url, public_port) + + env = dict(os.environ) + env["WEB_CONCURRENCY"] = str(WORKERS) + # Unbuffered so the child's launch logs stream out live (else stdout + # to a pipe is block-buffered and nothing shows until it exits). + env["PYTHONUNBUFFERED"] = "1" + # trtllm logger defaults to "error"; raise it so the coordinator/fleet + # launch logs are visible in the streamed [cli] output. 
+ env["TLLM_LOG_LEVEL"] = "info" + # A plain HTTP fleet, never an MPI rank -- strip any launcher env so + # the CLI's own strip is not even relied upon. + for k in list(env): + if k.startswith(("SLURM_", "PMIX_", "PMI_", "OMPI_", "UCX_", + "I_MPI_", "HYDRA_", "MPI_")): + env.pop(k) + + logger.info(f"mock ctx={ctx_url} gen={gen_url}; launching " + f"`trtllm-serve disaggregated` WEB_CONCURRENCY={WORKERS}, " + f"public={public_port} coordinator={coord_port}") + proc = subprocess.Popen( + [sys.executable, "-m", "tensorrt_llm.commands.serve", + "disaggregated", "-c", cfg_path], + env=env, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, bufsize=1, start_new_session=True) + + # Stream the CLI child's stdout live (prefixed) so coordinator/fleet + # startup is visible in real time instead of only at teardown. + def _pump(): + for line in proc.stdout: + logger.info(f"[cli] {line.rstrip()}") + + pump = threading.Thread(target=_pump, daemon=True) + pump.start() + try: + base = f"http://127.0.0.1:{public_port}" + coord = f"http://127.0.0.1:{coord_port}" + + async def _wait_all(): + # Both the coordinator (port-1) and the public fleet must be + # up before the fleet reports ready (fleet is_ready proxies + # the coordinator). + logger.info("waiting for coordinator health...") + assert await _wait_healthy(coord, 120.0), \ + "coordinator never became healthy" + logger.info("coordinator healthy; waiting for fleet health...") + assert await _wait_healthy(base, 120.0), \ + "disagg fleet never became healthy" + logger.info("fleet healthy") + + asyncio.run(_wait_all()) + + async def drive(): + # Fire several requests so the kernel spreads them across the + # 4 uvicorn workers. Every one must round-trip to GEN_TEXT. 
+ async with aiohttp.ClientSession() as sess: + texts = [] + for i in range(8): + payload = {"model": "m", "prompt": f"hello-{i}", + "max_tokens": 8} + headers = {"X-Session-ID": f"conv-{i}"} + async with sess.post(f"{base}/v1/completions", + json=payload, headers=headers, + timeout=30) as r: + assert r.status == 200, await r.text() + texts.append((await r.json())["choices"][0]["text"]) + logger.info(f"request {i} -> {texts[-1]!r}") + return texts + + texts = asyncio.run(drive()) + assert all(t == GEN_TEXT for t in texts), texts + logger.info(f"all {len(texts)} requests round-tripped to GEN_TEXT") + finally: + # Kill the whole process group: the CLI parent + coordinator + + # all uvicorn workers. Terminating only proc leaves the workers + # holding the stdout pipe open, so _pump never sees EOF. + logger.info("terminating CLI process group") + import signal + pgid = os.getpgid(proc.pid) + os.killpg(pgid, signal.SIGTERM) + try: + proc.wait(timeout=15) + except subprocess.TimeoutExpired: + os.killpg(pgid, signal.SIGKILL) + proc.wait() + pump.join(timeout=10) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v", "-s"])) diff --git a/tests/unittest/disaggregated/test_coordinator_worker.py b/tests/unittest/disaggregated/test_coordinator_worker.py new file mode 100644 index 000000000000..b7fd0ec9be5a --- /dev/null +++ b/tests/unittest/disaggregated/test_coordinator_worker.py @@ -0,0 +1,247 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Coordinator/worker disagg routing: cross-process placement contract. + +CPU-only, MPI-free. Wires the real coordinator surface to a real worker-side +coordinator: + + * fake ctx/gen HTTP workers answer ``/health`` (readiness only), + * a real ``CoordinatorServer`` (wrapping a ``DisaggCoordinatorService`` over the + configured routers) runs in a uvicorn thread on an internal port, + * a ``CoordinatorClient`` (what a worker holds) wraps only *stateful* routers + in a ``CoordinatorDelegatingRouter`` whose ``get_next_server`` computes the + routing key locally and POSTs it to the coordinator's ``/select``; + ``finish_request`` releases coordinator-side state via ``/finish`` and the + returned handle. *Stateless* routers (round_robin) are used as-is and place + locally in the worker. + +This proves the routing split: stateful routers (conversation, kv_cache_aware) +delegate to the coordinator via ``routing_key`` + ``get_next_server_by_key``, +while stateless routers never touch the coordinator. +""" + +import asyncio +import json +import threading +import time +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer + +import aiohttp +import pytest +import uvicorn + +from tensorrt_llm.llmapi.disagg_utils import (CtxGenServerConfig, + DisaggServerConfig, RouterConfig, + ServerRole) +from tensorrt_llm.serve.coordinator_server import CoordinatorServer +from tensorrt_llm.serve.disagg_coordinator import (CoordinatorClient, + DisaggCoordinatorService) +from tensorrt_llm.serve.openai_protocol import (CompletionRequest, + DisaggregatedParams) + + +@pytest.fixture(autouse=True) +def _reset_prometheus_registry(): + """Each coordinator builds role-prefixed Prometheus counters in the global + default registry (via its readiness OpenAIHttpClients). 
In production the + coordinator is a single process; here two coordinators share one pytest + process, so clear the registry between tests to avoid duplicate-timeseries + registration errors.""" + from prometheus_client import REGISTRY + yield + for collector in list(REGISTRY._collector_to_names): + try: + REGISTRY.unregister(collector) + except Exception: + pass + + +def _free_port(): + import socket + s = socket.socket() + # SO_REUSEADDR so a port left in TIME_WAIT by a sibling server in the same + # suite can be rebound immediately (closes the alloc->bind race window). + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.close() + return port + + +class _FakeWorker: + """Minimal HTTP worker: /health -> 200 (used for readiness only).""" + + def __init__(self): + self.port = _free_port() + + class Handler(BaseHTTPRequestHandler): + + def log_message(self, *a): + pass + + def do_GET(self): + if self.path == "/health": + self.send_response(200) + else: + self.send_response(404) + self.end_headers() + + self._httpd = ThreadingHTTPServer(("127.0.0.1", self.port), Handler) + self._thread = threading.Thread(target=self._httpd.serve_forever, + daemon=True) + + @property + def url(self): + return f"127.0.0.1:{self.port}" + + def __enter__(self): + self._thread.start() + return self + + def __exit__(self, *a): + self._httpd.shutdown() + + +def _make_config(ctx_urls, gen_urls, ctx_router_type, gen_router_type): + server_configs = [ + CtxGenServerConfig(type="ctx", hostname=u.split(":")[0], + port=int(u.split(":")[1])) for u in ctx_urls + ] + [ + CtxGenServerConfig(type="gen", hostname=u.split(":")[0], + port=int(u.split(":")[1])) for u in gen_urls + ] + return DisaggServerConfig( + server_configs=server_configs, + ctx_router_config=RouterConfig(type=ctx_router_type, + server_role=ServerRole.CONTEXT), + gen_router_config=RouterConfig(type=gen_router_type, + server_role=ServerRole.GENERATION)) + + +def 
_client_factory(router, role, max_retries=1): + from tensorrt_llm.serve.openai_client import OpenAIHttpClient + return OpenAIHttpClient(router, role, 30, max_retries) + + +class _CoordinatorThread: + """Run a CoordinatorServer (DisaggCoordinatorService) in a background thread.""" + + def __init__(self, config): + self.port = _free_port() + self.url = f"http://127.0.0.1:{self.port}" + # The coordinator builds its own owner routers from config. + self._cluster = DisaggCoordinatorService(config, _client_factory) + self._server = uvicorn.Server( + uvicorn.Config(CoordinatorServer(self._cluster).app, + host="127.0.0.1", port=self.port, + log_level="warning")) + self._thread = threading.Thread(target=self._server.run, daemon=True) + + def __enter__(self): + self._thread.start() + for _ in range(100): + if self._server.started: + break + time.sleep(0.1) + return self + + def __exit__(self, *a): + self._server.should_exit = True + self._thread.join(timeout=10) + + +async def _wait_coord_ready(url, timeout_s=30.0): + deadline = time.time() + timeout_s + async with aiohttp.ClientSession() as sess: + while time.time() < deadline: + try: + async with sess.get(f"{url}/health", timeout=1) as r: + if r.status == 200: + return True + except Exception: + pass + await asyncio.sleep(0.2) + return False + + +def test_stateless_router_places_locally_in_worker(): + """A round-robin (stateless) router is NOT wrapped: the worker places locally + with the real router and never calls the coordinator.""" + from tensorrt_llm.serve.router import (CoordinatorDelegatingRouter, + RoundRobinRouter) + with _FakeWorker() as ctx0, _FakeWorker() as gen0, _FakeWorker() as gen1: + config = _make_config([ctx0.url], [gen0.url, gen1.url], + "round_robin", "round_robin") + with _CoordinatorThread(config) as coord: + assert asyncio.run(_wait_coord_ready(coord.url)), \ + "coordinator never became healthy" + + async def drive(): + remote = CoordinatorClient(coord.url, config) + # Stateless -> real local 
router, not a delegating proxy. + assert isinstance(remote.gen_router, RoundRobinRouter) + assert not isinstance(remote.gen_router, + CoordinatorDelegatingRouter) + picks = [] + for _ in range(4): + req = CompletionRequest(model="m", prompt="hello") + server, _info = await remote.gen_router.get_next_server(req) + picks.append(server) + await remote.gen_router.finish_request(req) + await remote.stop() + return picks + + picks = asyncio.run(drive()) + assert set(picks) == {gen0.url, gen1.url}, \ + f"local round-robin should hit both gen workers, got {picks}" + + +def test_conversation_coordinator_sticky_by_conv_id(): + """Same conversation_id sticks to one gen worker; a stateful (conversation) + router delegates placement to the coordinator via /select.""" + from tensorrt_llm.serve.router import CoordinatorDelegatingRouter + with _FakeWorker() as ctx0, _FakeWorker() as gen0, _FakeWorker() as gen1: + config = _make_config([ctx0.url], [gen0.url, gen1.url], + "round_robin", "conversation") + with _CoordinatorThread(config) as coord: + assert asyncio.run(_wait_coord_ready(coord.url)) + + def _req(conv_id): + return CompletionRequest( + model="m", prompt="hi", + disaggregated_params=DisaggregatedParams( + request_type="context_only", conversation_id=conv_id)) + + async def drive(): + remote = CoordinatorClient(coord.url, config) + # Stateful -> wrapped in a coordinator-delegating router. + assert isinstance(remote.gen_router, + CoordinatorDelegatingRouter) + assert await remote.is_ready() is True + first, _ = await remote.gen_router.get_next_server(_req("conv-A")) + # Repeated conv-A requests must land on the same worker. 
+ repeats = [] + for _ in range(3): + s, _ = await remote.gen_router.get_next_server(_req("conv-A")) + repeats.append(s) + await remote.stop() + return first, repeats + + first, repeats = asyncio.run(drive()) + assert all(s == first for s in repeats), \ + f"conv-A must be sticky, got first={first} repeats={repeats}" + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v", "-s"])) diff --git a/tests/unittest/disaggregated/test_openai_disagg_service.py b/tests/unittest/disaggregated/test_openai_disagg_service.py index 8a99100b54d4..cf6141cecd88 100644 --- a/tests/unittest/disaggregated/test_openai_disagg_service.py +++ b/tests/unittest/disaggregated/test_openai_disagg_service.py @@ -29,6 +29,7 @@ ServerRole, ) from tensorrt_llm.serve.disagg_auto_scaling import DisaggClusterManager, WorkerInfo +from tensorrt_llm.serve.disagg_coordinator import DisaggCoordinatorService from tensorrt_llm.serve.openai_disagg_service import OpenAIDisaggregatedService from tensorrt_llm.serve.openai_protocol import ( ChatCompletionRequest, @@ -57,11 +58,20 @@ def _client_factory(*_args, **_kwargs): def _make_service(schedule_style: str) -> OpenAIDisaggregatedService: config = DisaggServerConfig(server_configs=[], schedule_style=schedule_style) + # The coordinator builds its own (empty) routers from config; override them + # with mocks so tests can stub placement / readiness directly. + cluster = DisaggCoordinatorService(config, client_factory=_client_factory) ctx_router = AsyncMock(spec=Router) gen_router = AsyncMock(spec=Router) - return OpenAIDisaggregatedService( - config, ctx_router, gen_router, client_factory=_client_factory + cluster._ctx_router = ctx_router + cluster._gen_router = gen_router + service = OpenAIDisaggregatedService( + config, cluster, client_factory=_client_factory ) + # Convenience handles for tests that stub placement / readiness directly. 
+ service._ctx_router = ctx_router + service._gen_router = gen_router + return service def _make_completion_response( @@ -207,16 +217,18 @@ async def test_is_ready_waits_for_router_preparation(): ), AsyncMock(), ) - service._disagg_cluster_manager = cluster_manager + # Readiness now lives on the DisaggCoordinatorService the service holds. + local = service._cluster + local._disagg_cluster_manager = cluster_manager cluster_manager._current_ctx_workers["ctx"] = WorkerInfo( worker_id="ctx", role=ServerRole.CONTEXT ) - service._ctx_router = SimpleNamespace(num_prepared_servers=0) - service._gen_router = SimpleNamespace(num_prepared_servers=1) + local._ctx_router = SimpleNamespace(num_prepared_servers=0) + local._gen_router = SimpleNamespace(num_prepared_servers=1) assert await service.is_ready() is False - service._ctx_router.num_prepared_servers = 1 + local._ctx_router.num_prepared_servers = 1 assert await service.is_ready() is False cluster_manager._current_gen_workers["gen"] = WorkerInfo( From 027ba10a236c67bdf2874022a6867edb4025be03 Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:02:54 -0700 Subject: [PATCH 2/7] generate request id from coordinator Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- tensorrt_llm/serve/coordinator_server.py | 6 ++++ tensorrt_llm/serve/disagg_coordinator.py | 25 ++++++++++++++- tensorrt_llm/serve/openai_client.py | 6 ++-- tensorrt_llm/serve/openai_disagg_server.py | 8 ++--- tensorrt_llm/serve/router.py | 20 ++++++++---- .../disaggregated/test_coordinator_worker.py | 32 +++++++++++++++++++ .../test_disagg_openai_client.py | 5 ++- 7 files changed, 87 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/serve/coordinator_server.py b/tensorrt_llm/serve/coordinator_server.py index a7c7750d4d36..02b72fa0cd58 100644 --- a/tensorrt_llm/serve/coordinator_server.py +++ b/tensorrt_llm/serve/coordinator_server.py @@ -60,6 +60,8 @@ async def lifespan(app: 
FastAPI): self.app = FastAPI(lifespan=lifespan) self.app.add_api_route("/select", self.select, methods=["POST"]) self.app.add_api_route("/finish", self.finish, methods=["POST"]) + self.app.add_api_route("/disagg_request_id", + self.disagg_request_id, methods=["GET"]) self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"]) self.app.add_api_route("/health", self.health, methods=["GET"]) @@ -98,6 +100,10 @@ async def finish(self, raw_req: Request) -> Response: body.get("success", True)) return JSONResponse(content={}) + async def disagg_request_id(self) -> Response: + return JSONResponse(content={"disagg_request_id": + await self._coordinator.get_disagg_request_id()}) + async def cluster_info(self) -> Response: return JSONResponse(content=await self._coordinator.cluster_info()) diff --git a/tensorrt_llm/serve/disagg_coordinator.py b/tensorrt_llm/serve/disagg_coordinator.py index 7ef705fb32fc..05f41c1f0d4e 100644 --- a/tensorrt_llm/serve/disagg_coordinator.py +++ b/tensorrt_llm/serve/disagg_coordinator.py @@ -42,7 +42,8 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, MetadataServerConfig, ServerRole, - get_ctx_gen_server_addrs) + get_ctx_gen_server_addrs, + get_global_disagg_request_id) from tensorrt_llm.logger import logger from tensorrt_llm.serve.cluster_storage import (ClusterStorage, WatchEventType, create_cluster_storage) @@ -93,6 +94,12 @@ async def stop(self) -> None: ... + + @abstractmethod + async def get_disagg_request_id(self) -> int: + ... + + class DisaggCoordinatorService(DisaggCoordinator): """In-process coordinator owning the ctx/gen routers and all cluster state. @@ -162,9 +169,15 @@ def set_clients(self, ctx_client: OpenAIClient, async def select(self, role: str, routing_key, req_id, exclude_server: Optional[str]) -> Tuple[str, dict, Optional[str]]: router = self._router_for_role(role) + if req_id is None: + # The coordinator owns IDs absent from generation requests. 
+ req_id = get_global_disagg_request_id(self._config.node_id) return await router.get_next_server_by_key(routing_key, req_id=req_id, exclude_server=exclude_server) + async def get_disagg_request_id(self) -> int: + return get_global_disagg_request_id(self._config.node_id) + async def finish(self, role: str, req_id, success: bool = True) -> None: await self._router_for_role(role).finish_request_by_id(req_id, success) @@ -396,6 +409,16 @@ async def _await_coordinator(self) -> Dict[str, Any]: f"{self._remote_url} (attempt {attempt}, {last_err})") await asyncio.sleep(2.0) + async def get_disagg_request_id(self) -> int: + async with self.session.get( + f"{self._remote_url}/disagg_request_id", + timeout=self._request_timeout_s) as resp: + if resp.status != 200: + raise RuntimeError( + f"coordinator /disagg_request_id returned {resp.status}") + body = await resp.json() + return body["disagg_request_id"] + async def is_ready(self) -> bool: try: async with self.session.get( diff --git a/tensorrt_llm/serve/openai_client.py b/tensorrt_llm/serve/openai_client.py index 265a771daf6c..6552cd2ab25e 100644 --- a/tensorrt_llm/serve/openai_client.py +++ b/tensorrt_llm/serve/openai_client.py @@ -16,7 +16,7 @@ import asyncio import traceback from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Type +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Tuple, Type import aiohttp @@ -103,7 +103,7 @@ def __init__( max_retries: int = 1, retry_interval_sec: int = 1, session: Optional[aiohttp.ClientSession] = None, - disagg_id_generator: Optional[Callable[[], int]] = None, + disagg_id_generator: Optional[Callable[[], Awaitable[int]]] = None, ): self._router = router self._role = role @@ -180,7 +180,7 @@ async def _post_with_retry( if attempt > 0 and self._disagg_id_generator is not None: dp = getattr(request, "disaggregated_params", None) if dp is not None and getattr(dp, "disagg_request_id", None) is 
not None: - dp.disagg_request_id = self._disagg_id_generator() + dp.disagg_request_id = await self._disagg_id_generator() # Serialize once on the orchestrator's single event-loop thread: # model_dump_json (pydantic-core) is ~2.3x faster than # model_dump(mode="json") + aiohttp json= (json.dumps). Decodes to diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 06c43efbad4f..71fd54e62291 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -32,8 +32,7 @@ from tensorrt_llm.executor.executor import CppExecutorError from tensorrt_llm.llmapi import tracing from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, - MetadataServerConfig, ServerRole, - get_global_disagg_request_id) + MetadataServerConfig, ServerRole) from tensorrt_llm.logger import logger from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer, create_cluster_storage) @@ -154,10 +153,11 @@ async def validation_exception_handler(_, exc): self.register_routes() def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) -> OpenAIClient: - node_id = self._config.node_id + async def disagg_id_generator(): + return await self._coordinator.get_disagg_request_id() client = OpenAIHttpClient( router, role, self._req_timeout_secs, max_retries, - disagg_id_generator=lambda: get_global_disagg_request_id(node_id)) + disagg_id_generator=disagg_id_generator) self._perf_metrics_collector.add_client(client) return client diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 153b4693e84e..2040e66e197b 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -490,6 +490,7 @@ async def get_next_server( exclude_server: Optional[str] = None) -> tuple[str, dict]: '''Select server by request and return some intermediate information, exclude_server is a server to exclude from the selection''' + @abstractmethod async def 
finish_request(self, request: OpenAIRequest, @@ -1631,11 +1632,11 @@ def _request_id(self, request: OpenAIRequest) -> int: routing, so a missing id is a bug -- assert, don't paper over.""" dp = request.disaggregated_params assert dp is not None, "delegated routing requires disaggregated_params" - rid = (dp.disagg_request_id if self._role == "context" - else dp.ctx_request_id) - assert rid is not None, ( - f"delegated {self._role} routing requires a disagg request id " - f"(disagg_request_id/ctx_request_id) on the request") + rid = dp.disagg_request_id + if self._role != "generation": + assert rid is not None, ( + f"delegated {self._role} routing requires a disagg request id " + f"(disagg_request_id/ctx_request_id) on the request") return rid async def get_next_server( @@ -1645,8 +1646,10 @@ async def get_next_server( key = self._local.routing_key(request) # Send the disagg request id as the sole cross-process request key; the # coordinator keys its pending-request state by it for /finish. 
+ req_id = (None if self._role == "generation" + else self._request_id(request)) payload = {"role": self._role, "routing_key": key, - "req_id": self._request_id(request), + "req_id": req_id, "exclude_server": exclude_server} async with self.session.post( f"{self._coordinator_url}/select", json=payload, @@ -1657,6 +1660,11 @@ async def get_next_server( f"{await resp.text()}") body = await resp.json() info = body.get("info") or {} + if self._role == "generation": + coordinator_req_id = body.get("req_id") + if coordinator_req_id is None: + raise ValueError("coordinator did not return a generation disagg_request_id") + request.disaggregated_params.disagg_request_id = coordinator_req_id return body["server"], info async def finish_request(self, diff --git a/tests/unittest/disaggregated/test_coordinator_worker.py b/tests/unittest/disaggregated/test_coordinator_worker.py index b7fd0ec9be5a..2c20be058197 100644 --- a/tests/unittest/disaggregated/test_coordinator_worker.py +++ b/tests/unittest/disaggregated/test_coordinator_worker.py @@ -243,5 +243,37 @@ async def drive(): f"conv-A must be sticky, got first={first} repeats={repeats}" +def test_coordinator_owns_generation_disagg_request_id(): + """Generation routing receives its request ID from the coordinator.""" + from tensorrt_llm.serve.router import CoordinatorDelegatingRouter + + with _FakeWorker() as ctx0, _FakeWorker() as gen0: + config = _make_config([ctx0.url], [gen0.url], "round_robin", + "conversation") + with _CoordinatorThread(config) as coord: + assert asyncio.run(_wait_coord_ready(coord.url)) + + async def drive(): + remote = CoordinatorClient(coord.url, config) + assert isinstance(remote.gen_router, + CoordinatorDelegatingRouter) + request = CompletionRequest( + model="m", + prompt="hello", + disaggregated_params=DisaggregatedParams( + request_type="generation_only", + ctx_request_id=123, + disagg_request_id=None, + conversation_id="conv-A")) + await remote.gen_router.get_next_server(request) + assigned_id = 
request.disaggregated_params.disagg_request_id + await remote.gen_router.finish_request(request) + await remote.stop() + return assigned_id + + assigned_id = asyncio.run(drive()) + assert assigned_id is not None and assigned_id != 123 + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__, "-v", "-s"])) diff --git a/tests/unittest/disaggregated/test_disagg_openai_client.py b/tests/unittest/disaggregated/test_disagg_openai_client.py index edc09809732b..9093cd9e2302 100644 --- a/tests/unittest/disaggregated/test_disagg_openai_client.py +++ b/tests/unittest/disaggregated/test_disagg_openai_client.py @@ -376,7 +376,10 @@ def _make_client(self, session, **kwargs): async def test_retry_regenerates_disagg_id(self): session = AsyncMock(spec=aiohttp.ClientSession) ids = iter(range(1000, 2000)) - client = self._make_client(session, disagg_id_generator=lambda: next(ids)) + async def next_id(): + return next(ids) + + client = self._make_client(session, disagg_id_generator=next_id) session.post.side_effect = [ aiohttp.ClientError("transient"), From 229770143b5d9273a4f27058f5da3e883c75e63c Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:46:16 -0700 Subject: [PATCH 3/7] Revert " generate request id from coordinator" This reverts commit 027ba10a236c67bdf2874022a6867edb4025be03. 
--- tensorrt_llm/serve/coordinator_server.py | 6 ---- tensorrt_llm/serve/disagg_coordinator.py | 25 +-------------- tensorrt_llm/serve/openai_client.py | 6 ++-- tensorrt_llm/serve/openai_disagg_server.py | 8 ++--- tensorrt_llm/serve/router.py | 20 ++++-------- .../disaggregated/test_coordinator_worker.py | 32 ------------------- .../test_disagg_openai_client.py | 5 +-- 7 files changed, 15 insertions(+), 87 deletions(-) diff --git a/tensorrt_llm/serve/coordinator_server.py b/tensorrt_llm/serve/coordinator_server.py index 02b72fa0cd58..a7c7750d4d36 100644 --- a/tensorrt_llm/serve/coordinator_server.py +++ b/tensorrt_llm/serve/coordinator_server.py @@ -60,8 +60,6 @@ async def lifespan(app: FastAPI): self.app = FastAPI(lifespan=lifespan) self.app.add_api_route("/select", self.select, methods=["POST"]) self.app.add_api_route("/finish", self.finish, methods=["POST"]) - self.app.add_api_route("/disagg_request_id", - self.disagg_request_id, methods=["GET"]) self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"]) self.app.add_api_route("/health", self.health, methods=["GET"]) @@ -100,10 +98,6 @@ async def finish(self, raw_req: Request) -> Response: body.get("success", True)) return JSONResponse(content={}) - async def disagg_request_id(self) -> Response: - return JSONResponse(content={"disagg_request_id": - await self._coordinator.get_disagg_request_id()}) - async def cluster_info(self) -> Response: return JSONResponse(content=await self._coordinator.cluster_info()) diff --git a/tensorrt_llm/serve/disagg_coordinator.py b/tensorrt_llm/serve/disagg_coordinator.py index 05f41c1f0d4e..7ef705fb32fc 100644 --- a/tensorrt_llm/serve/disagg_coordinator.py +++ b/tensorrt_llm/serve/disagg_coordinator.py @@ -42,8 +42,7 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, MetadataServerConfig, ServerRole, - get_ctx_gen_server_addrs, - get_global_disagg_request_id) + get_ctx_gen_server_addrs) from tensorrt_llm.logger import logger from 
tensorrt_llm.serve.cluster_storage import (ClusterStorage, WatchEventType, create_cluster_storage) @@ -94,12 +93,6 @@ async def stop(self) -> None: ... - - @abstractmethod - async def get_disagg_request_id(self) -> int: - ... - - class DisaggCoordinatorService(DisaggCoordinator): """In-process coordinator owning the ctx/gen routers and all cluster state. @@ -169,15 +162,9 @@ def set_clients(self, ctx_client: OpenAIClient, async def select(self, role: str, routing_key, req_id, exclude_server: Optional[str]) -> Tuple[str, dict, Optional[str]]: router = self._router_for_role(role) - if req_id is None: - # The coordinator owns IDs absent from generation requests. - req_id = get_global_disagg_request_id(self._config.node_id) return await router.get_next_server_by_key(routing_key, req_id=req_id, exclude_server=exclude_server) - async def get_disagg_request_id(self) -> int: - return get_global_disagg_request_id(self._config.node_id) - async def finish(self, role: str, req_id, success: bool = True) -> None: await self._router_for_role(role).finish_request_by_id(req_id, success) @@ -409,16 +396,6 @@ async def _await_coordinator(self) -> Dict[str, Any]: f"{self._remote_url} (attempt {attempt}, {last_err})") await asyncio.sleep(2.0) - async def get_disagg_request_id(self) -> int: - async with self.session.get( - f"{self._remote_url}/disagg_request_id", - timeout=self._request_timeout_s) as resp: - if resp.status != 200: - raise RuntimeError( - f"coordinator /disagg_request_id returned {resp.status}") - body = await resp.json() - return body["disagg_request_id"] - async def is_ready(self) -> bool: try: async with self.session.get( diff --git a/tensorrt_llm/serve/openai_client.py b/tensorrt_llm/serve/openai_client.py index 6552cd2ab25e..265a771daf6c 100644 --- a/tensorrt_llm/serve/openai_client.py +++ b/tensorrt_llm/serve/openai_client.py @@ -16,7 +16,7 @@ import asyncio import traceback from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Awaitable, 
Callable, Dict, List, Optional, Tuple, Type +from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Tuple, Type import aiohttp @@ -103,7 +103,7 @@ def __init__( max_retries: int = 1, retry_interval_sec: int = 1, session: Optional[aiohttp.ClientSession] = None, - disagg_id_generator: Optional[Callable[[], Awaitable[int]]] = None, + disagg_id_generator: Optional[Callable[[], int]] = None, ): self._router = router self._role = role @@ -180,7 +180,7 @@ async def _post_with_retry( if attempt > 0 and self._disagg_id_generator is not None: dp = getattr(request, "disaggregated_params", None) if dp is not None and getattr(dp, "disagg_request_id", None) is not None: - dp.disagg_request_id = await self._disagg_id_generator() + dp.disagg_request_id = self._disagg_id_generator() # Serialize once on the orchestrator's single event-loop thread: # model_dump_json (pydantic-core) is ~2.3x faster than # model_dump(mode="json") + aiohttp json= (json.dumps). Decodes to diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 71fd54e62291..06c43efbad4f 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -32,7 +32,8 @@ from tensorrt_llm.executor.executor import CppExecutorError from tensorrt_llm.llmapi import tracing from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, - MetadataServerConfig, ServerRole) + MetadataServerConfig, ServerRole, + get_global_disagg_request_id) from tensorrt_llm.logger import logger from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer, create_cluster_storage) @@ -153,11 +154,10 @@ async def validation_exception_handler(_, exc): self.register_routes() def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) -> OpenAIClient: - async def disagg_id_generator(): - return await self._coordinator.get_disagg_request_id() + node_id = self._config.node_id client = OpenAIHttpClient( router, 
role, self._req_timeout_secs, max_retries, - disagg_id_generator=disagg_id_generator) + disagg_id_generator=lambda: get_global_disagg_request_id(node_id)) self._perf_metrics_collector.add_client(client) return client diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 2040e66e197b..153b4693e84e 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -490,7 +490,6 @@ async def get_next_server( exclude_server: Optional[str] = None) -> tuple[str, dict]: '''Select server by request and return some intermediate information, exclude_server is a server to exclude from the selection''' - @abstractmethod async def finish_request(self, request: OpenAIRequest, @@ -1632,11 +1631,11 @@ def _request_id(self, request: OpenAIRequest) -> int: routing, so a missing id is a bug -- assert, don't paper over.""" dp = request.disaggregated_params assert dp is not None, "delegated routing requires disaggregated_params" - rid = dp.disagg_request_id - if self._role != "generation": - assert rid is not None, ( - f"delegated {self._role} routing requires a disagg request id " - f"(disagg_request_id/ctx_request_id) on the request") + rid = (dp.disagg_request_id if self._role == "context" + else dp.ctx_request_id) + assert rid is not None, ( + f"delegated {self._role} routing requires a disagg request id " + f"(disagg_request_id/ctx_request_id) on the request") return rid async def get_next_server( @@ -1646,10 +1645,8 @@ async def get_next_server( key = self._local.routing_key(request) # Send the disagg request id as the sole cross-process request key; the # coordinator keys its pending-request state by it for /finish. 
- req_id = (None if self._role == "generation" - else self._request_id(request)) payload = {"role": self._role, "routing_key": key, - "req_id": req_id, + "req_id": self._request_id(request), "exclude_server": exclude_server} async with self.session.post( f"{self._coordinator_url}/select", json=payload, @@ -1660,11 +1657,6 @@ async def get_next_server( f"{await resp.text()}") body = await resp.json() info = body.get("info") or {} - if self._role == "generation": - coordinator_req_id = body.get("req_id") - if coordinator_req_id is None: - raise ValueError("coordinator did not return a generation disagg_request_id") - request.disaggregated_params.disagg_request_id = coordinator_req_id return body["server"], info async def finish_request(self, diff --git a/tests/unittest/disaggregated/test_coordinator_worker.py b/tests/unittest/disaggregated/test_coordinator_worker.py index 2c20be058197..b7fd0ec9be5a 100644 --- a/tests/unittest/disaggregated/test_coordinator_worker.py +++ b/tests/unittest/disaggregated/test_coordinator_worker.py @@ -243,37 +243,5 @@ async def drive(): f"conv-A must be sticky, got first={first} repeats={repeats}" -def test_coordinator_owns_generation_disagg_request_id(): - """Generation routing receives its request ID from the coordinator.""" - from tensorrt_llm.serve.router import CoordinatorDelegatingRouter - - with _FakeWorker() as ctx0, _FakeWorker() as gen0: - config = _make_config([ctx0.url], [gen0.url], "round_robin", - "conversation") - with _CoordinatorThread(config) as coord: - assert asyncio.run(_wait_coord_ready(coord.url)) - - async def drive(): - remote = CoordinatorClient(coord.url, config) - assert isinstance(remote.gen_router, - CoordinatorDelegatingRouter) - request = CompletionRequest( - model="m", - prompt="hello", - disaggregated_params=DisaggregatedParams( - request_type="generation_only", - ctx_request_id=123, - disagg_request_id=None, - conversation_id="conv-A")) - await remote.gen_router.get_next_server(request) - assigned_id = 
request.disaggregated_params.disagg_request_id - await remote.gen_router.finish_request(request) - await remote.stop() - return assigned_id - - assigned_id = asyncio.run(drive()) - assert assigned_id is not None and assigned_id != 123 - - if __name__ == "__main__": raise SystemExit(pytest.main([__file__, "-v", "-s"])) diff --git a/tests/unittest/disaggregated/test_disagg_openai_client.py b/tests/unittest/disaggregated/test_disagg_openai_client.py index 9093cd9e2302..edc09809732b 100644 --- a/tests/unittest/disaggregated/test_disagg_openai_client.py +++ b/tests/unittest/disaggregated/test_disagg_openai_client.py @@ -376,10 +376,7 @@ def _make_client(self, session, **kwargs): async def test_retry_regenerates_disagg_id(self): session = AsyncMock(spec=aiohttp.ClientSession) ids = iter(range(1000, 2000)) - async def next_id(): - return next(ids) - - client = self._make_client(session, disagg_id_generator=next_id) + client = self._make_client(session, disagg_id_generator=lambda: next(ids)) session.post.side_effect = [ aiohttp.ClientError("transient"), From 4effcbeb31766420449b655c9144d488cb264cca Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:51:12 -0700 Subject: [PATCH 4/7] Reapply " generate request id from coordinator" This reverts commit 229770143b5d9273a4f27058f5da3e883c75e63c. 
--- tensorrt_llm/serve/coordinator_server.py | 6 ++++ tensorrt_llm/serve/disagg_coordinator.py | 25 ++++++++++++++- tensorrt_llm/serve/openai_client.py | 6 ++-- tensorrt_llm/serve/openai_disagg_server.py | 8 ++--- tensorrt_llm/serve/router.py | 20 ++++++++---- .../disaggregated/test_coordinator_worker.py | 32 +++++++++++++++++++ .../test_disagg_openai_client.py | 5 ++- 7 files changed, 87 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/serve/coordinator_server.py b/tensorrt_llm/serve/coordinator_server.py index a7c7750d4d36..02b72fa0cd58 100644 --- a/tensorrt_llm/serve/coordinator_server.py +++ b/tensorrt_llm/serve/coordinator_server.py @@ -60,6 +60,8 @@ async def lifespan(app: FastAPI): self.app = FastAPI(lifespan=lifespan) self.app.add_api_route("/select", self.select, methods=["POST"]) self.app.add_api_route("/finish", self.finish, methods=["POST"]) + self.app.add_api_route("/disagg_request_id", + self.disagg_request_id, methods=["GET"]) self.app.add_api_route("/cluster_info", self.cluster_info, methods=["GET"]) self.app.add_api_route("/health", self.health, methods=["GET"]) @@ -98,6 +100,10 @@ async def finish(self, raw_req: Request) -> Response: body.get("success", True)) return JSONResponse(content={}) + async def disagg_request_id(self) -> Response: + return JSONResponse(content={"disagg_request_id": + await self._coordinator.get_disagg_request_id()}) + async def cluster_info(self) -> Response: return JSONResponse(content=await self._coordinator.cluster_info()) diff --git a/tensorrt_llm/serve/disagg_coordinator.py b/tensorrt_llm/serve/disagg_coordinator.py index 7ef705fb32fc..05f41c1f0d4e 100644 --- a/tensorrt_llm/serve/disagg_coordinator.py +++ b/tensorrt_llm/serve/disagg_coordinator.py @@ -42,7 +42,8 @@ from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, MetadataServerConfig, ServerRole, - get_ctx_gen_server_addrs) + get_ctx_gen_server_addrs, + get_global_disagg_request_id) from tensorrt_llm.logger import logger from 
tensorrt_llm.serve.cluster_storage import (ClusterStorage, WatchEventType, create_cluster_storage) @@ -93,6 +94,12 @@ async def stop(self) -> None: ... + + @abstractmethod + async def get_disagg_request_id(self) -> int: + ... + + class DisaggCoordinatorService(DisaggCoordinator): """In-process coordinator owning the ctx/gen routers and all cluster state. @@ -162,9 +169,15 @@ def set_clients(self, ctx_client: OpenAIClient, async def select(self, role: str, routing_key, req_id, exclude_server: Optional[str]) -> Tuple[str, dict, Optional[str]]: router = self._router_for_role(role) + if req_id is None: + # The coordinator owns IDs absent from generation requests. + req_id = get_global_disagg_request_id(self._config.node_id) return await router.get_next_server_by_key(routing_key, req_id=req_id, exclude_server=exclude_server) + async def get_disagg_request_id(self) -> int: + return get_global_disagg_request_id(self._config.node_id) + async def finish(self, role: str, req_id, success: bool = True) -> None: await self._router_for_role(role).finish_request_by_id(req_id, success) @@ -396,6 +409,16 @@ async def _await_coordinator(self) -> Dict[str, Any]: f"{self._remote_url} (attempt {attempt}, {last_err})") await asyncio.sleep(2.0) + async def get_disagg_request_id(self) -> int: + async with self.session.get( + f"{self._remote_url}/disagg_request_id", + timeout=self._request_timeout_s) as resp: + if resp.status != 200: + raise RuntimeError( + f"coordinator /disagg_request_id returned {resp.status}") + body = await resp.json() + return body["disagg_request_id"] + async def is_ready(self) -> bool: try: async with self.session.get( diff --git a/tensorrt_llm/serve/openai_client.py b/tensorrt_llm/serve/openai_client.py index 265a771daf6c..6552cd2ab25e 100644 --- a/tensorrt_llm/serve/openai_client.py +++ b/tensorrt_llm/serve/openai_client.py @@ -16,7 +16,7 @@ import asyncio import traceback from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, Callable, 
Dict, List, Optional, Tuple, Type +from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Optional, Tuple, Type import aiohttp @@ -103,7 +103,7 @@ def __init__( max_retries: int = 1, retry_interval_sec: int = 1, session: Optional[aiohttp.ClientSession] = None, - disagg_id_generator: Optional[Callable[[], int]] = None, + disagg_id_generator: Optional[Callable[[], Awaitable[int]]] = None, ): self._router = router self._role = role @@ -180,7 +180,7 @@ async def _post_with_retry( if attempt > 0 and self._disagg_id_generator is not None: dp = getattr(request, "disaggregated_params", None) if dp is not None and getattr(dp, "disagg_request_id", None) is not None: - dp.disagg_request_id = self._disagg_id_generator() + dp.disagg_request_id = await self._disagg_id_generator() # Serialize once on the orchestrator's single event-loop thread: # model_dump_json (pydantic-core) is ~2.3x faster than # model_dump(mode="json") + aiohttp json= (json.dumps). Decodes to diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py index 06c43efbad4f..71fd54e62291 100644 --- a/tensorrt_llm/serve/openai_disagg_server.py +++ b/tensorrt_llm/serve/openai_disagg_server.py @@ -32,8 +32,7 @@ from tensorrt_llm.executor.executor import CppExecutorError from tensorrt_llm.llmapi import tracing from tensorrt_llm.llmapi.disagg_utils import (DisaggServerConfig, - MetadataServerConfig, ServerRole, - get_global_disagg_request_id) + MetadataServerConfig, ServerRole) from tensorrt_llm.logger import logger from tensorrt_llm.serve.cluster_storage import (HttpClusterStorageServer, create_cluster_storage) @@ -154,10 +153,11 @@ async def validation_exception_handler(_, exc): self.register_routes() def _create_client(self, router: Router, role: ServerRole, max_retries: int = 1) -> OpenAIClient: - node_id = self._config.node_id + async def disagg_id_generator(): + return await self._coordinator.get_disagg_request_id() client = OpenAIHttpClient( router, 
role, self._req_timeout_secs, max_retries, - disagg_id_generator=lambda: get_global_disagg_request_id(node_id)) + disagg_id_generator=disagg_id_generator) self._perf_metrics_collector.add_client(client) return client diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 153b4693e84e..2040e66e197b 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -490,6 +490,7 @@ async def get_next_server( exclude_server: Optional[str] = None) -> tuple[str, dict]: '''Select server by request and return some intermediate information, exclude_server is a server to exclude from the selection''' + @abstractmethod async def finish_request(self, request: OpenAIRequest, @@ -1631,11 +1632,11 @@ def _request_id(self, request: OpenAIRequest) -> int: routing, so a missing id is a bug -- assert, don't paper over.""" dp = request.disaggregated_params assert dp is not None, "delegated routing requires disaggregated_params" - rid = (dp.disagg_request_id if self._role == "context" - else dp.ctx_request_id) - assert rid is not None, ( - f"delegated {self._role} routing requires a disagg request id " - f"(disagg_request_id/ctx_request_id) on the request") + rid = dp.disagg_request_id + if self._role != "generation": + assert rid is not None, ( + f"delegated {self._role} routing requires a disagg request id " + f"(disagg_request_id/ctx_request_id) on the request") return rid async def get_next_server( @@ -1645,8 +1646,10 @@ async def get_next_server( key = self._local.routing_key(request) # Send the disagg request id as the sole cross-process request key; the # coordinator keys its pending-request state by it for /finish. 
+ req_id = (None if self._role == "generation" + else self._request_id(request)) payload = {"role": self._role, "routing_key": key, - "req_id": self._request_id(request), + "req_id": req_id, "exclude_server": exclude_server} async with self.session.post( f"{self._coordinator_url}/select", json=payload, @@ -1657,6 +1660,11 @@ async def get_next_server( f"{await resp.text()}") body = await resp.json() info = body.get("info") or {} + if self._role == "generation": + coordinator_req_id = body.get("req_id") + if coordinator_req_id is None: + raise ValueError("coordinator did not return a generation disagg_request_id") + request.disaggregated_params.disagg_request_id = coordinator_req_id return body["server"], info async def finish_request(self, diff --git a/tests/unittest/disaggregated/test_coordinator_worker.py b/tests/unittest/disaggregated/test_coordinator_worker.py index b7fd0ec9be5a..2c20be058197 100644 --- a/tests/unittest/disaggregated/test_coordinator_worker.py +++ b/tests/unittest/disaggregated/test_coordinator_worker.py @@ -243,5 +243,37 @@ async def drive(): f"conv-A must be sticky, got first={first} repeats={repeats}" +def test_coordinator_owns_generation_disagg_request_id(): + """Generation routing receives its request ID from the coordinator.""" + from tensorrt_llm.serve.router import CoordinatorDelegatingRouter + + with _FakeWorker() as ctx0, _FakeWorker() as gen0: + config = _make_config([ctx0.url], [gen0.url], "round_robin", + "conversation") + with _CoordinatorThread(config) as coord: + assert asyncio.run(_wait_coord_ready(coord.url)) + + async def drive(): + remote = CoordinatorClient(coord.url, config) + assert isinstance(remote.gen_router, + CoordinatorDelegatingRouter) + request = CompletionRequest( + model="m", + prompt="hello", + disaggregated_params=DisaggregatedParams( + request_type="generation_only", + ctx_request_id=123, + disagg_request_id=None, + conversation_id="conv-A")) + await remote.gen_router.get_next_server(request) + assigned_id = 
request.disaggregated_params.disagg_request_id + await remote.gen_router.finish_request(request) + await remote.stop() + return assigned_id + + assigned_id = asyncio.run(drive()) + assert assigned_id is not None and assigned_id != 123 + + if __name__ == "__main__": raise SystemExit(pytest.main([__file__, "-v", "-s"])) diff --git a/tests/unittest/disaggregated/test_disagg_openai_client.py b/tests/unittest/disaggregated/test_disagg_openai_client.py index edc09809732b..9093cd9e2302 100644 --- a/tests/unittest/disaggregated/test_disagg_openai_client.py +++ b/tests/unittest/disaggregated/test_disagg_openai_client.py @@ -376,7 +376,10 @@ def _make_client(self, session, **kwargs): async def test_retry_regenerates_disagg_id(self): session = AsyncMock(spec=aiohttp.ClientSession) ids = iter(range(1000, 2000)) - client = self._make_client(session, disagg_id_generator=lambda: next(ids)) + async def next_id(): + return next(ids) + + client = self._make_client(session, disagg_id_generator=next_id) session.post.side_effect = [ aiohttp.ClientError("transient"), From e197aeaa65999aa98da1ac78bda54fb521632b7b Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Thu, 2 Jul 2026 22:59:18 -0700 Subject: [PATCH 5/7] remove llm_id changes Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- tensorrt_llm/llmapi/llm.py | 16 +++++++--------- tensorrt_llm/llmapi/llm_args.py | 6 ------ tensorrt_llm/serve/openai_server.py | 3 --- 3 files changed, 7 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 727152fdf4fe..995298b3060b 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -169,10 +169,7 @@ def __init__(self, self._executor_cls = kwargs.pop("executor_cls", GenerationExecutor) self._orchestrator_type = kwargs.get("orchestrator_type", None) - hostname = socket.gethostname() - pid = os.getpid() - timestamp = int(time.time() * 1000) - self._llm_id = 
f"{hostname}-{pid}-{timestamp}" + self._llm_id = None self._disaggregated_params: Optional[dict] = None log_level = logger.level @@ -221,8 +218,6 @@ def __init__(self, revision=revision, tokenizer_revision=tokenizer_revision, **kwargs) - if hasattr(self.args, 'llm_id'): - self.args.llm_id = self._llm_id except Exception as e: logger.error( @@ -277,9 +272,6 @@ def __init__(self, self.llm_build_stats = LlmBuildStats() self._build_model() - if self._executor is not None: - self._executor.llm_id = self._llm_id - except Exception: if self.mpi_session is not None: self.mpi_session.shutdown() @@ -319,6 +311,12 @@ def __init__(self, @property @set_api_status("beta") def llm_id(self) -> str: + if self._llm_id is None: + hostname = socket.gethostname() + pid = os.getpid() + timestamp = int(time.time() * 1000) + self._llm_id = f"{hostname}-{pid}-{timestamp}" + return self._llm_id @property diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 124f328af867..11ff3e3acc4b 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -4230,12 +4230,6 @@ class TorchLlmArgs(BaseLlmArgs): status="prototype", ) - llm_id: Optional[str] = Field( - default=None, - description="Stable instance identifier propagated to all ranks via MPI broadcast.", - status="prototype", - ) - # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 8cd63f8723a4..61738cb74b5f 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -1926,9 +1926,6 @@ async def update_weights(self, async def get_server_info(self) -> JSONResponse: content = {"disaggregated_params": self.generator.disaggregated_params} - llm_id = getattr(self.generator, "llm_id", None) - if llm_id is not None: - content["worker_id"] = llm_id args = getattr(self.generator, "args", None) if args is not None: if args.max_batch_size is not 
None: From bad1f69bb900f67dccb5d592be2b7ccf3509b4e5 Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Fri, 3 Jul 2026 00:25:33 -0700 Subject: [PATCH 6/7] [None][fix] Disagg: issue disagg_request_id from the coordinator, not per-worker The disagg request id was minted locally in each fleet worker via get_global_disagg_request_id(node_id). With num_workers>1 all workers share the same node_id and each keeps its own counter starting at 0, so the snowflake ids (timestamp, machine_id, counter) collide across workers. The ctx->gen KV-cache transceiver keys transfers by disagg id, so colliding ids make transfers clash and never complete: the gen engine's IndexMapper fills with stuck DISAGG_GENERATION_TRANS_IN_PROGRESS requests (all slots in use), new requests can't allocate KV and retry forever, and fleet throughput collapses (~2 req/s). Both _send_disagg_request_ctx_first and _gen_first now fetch the id from the single coordinator (await self._coordinator.get_disagg_request_id()) -- owner issues in-process, delegating fleet workers fetch over HTTP (/disagg_request_id, already wired). Single issuer => globally unique ids. Also rename the service's self._cluster -> self._coordinator to match coordinator_server, and drop the now unused get_global_disagg_request_id import. 
Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- tensorrt_llm/serve/openai_disagg_service.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index 69819bf3be62..6040c8d8f770 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -22,7 +22,6 @@ ConditionalDisaggConfig, DisaggServerConfig, ServerRole, - get_global_disagg_request_id, ) from tensorrt_llm.logger import logger from tensorrt_llm.serve.openai_client import OpenAIClient @@ -61,7 +60,7 @@ def __init__( # finish_request uniformly -- so serving is identical whether the router # is the real one (single-process) or a CoordinatorDelegatingRouter that # forwards placement to a remote coordinator (worker). - self._cluster = coordinator + self._coordinator = coordinator self._ctx_router = coordinator.ctx_router self._gen_router = coordinator.gen_router self._client_factory = client_factory @@ -129,7 +128,7 @@ async def _send_disagg_request_ctx_first( need_ctx = need_ctx and not await self._check_gen_only_disagg(request) ctx_response = None gen_req = request - disagg_request_id = get_global_disagg_request_id(self._config.node_id) + disagg_request_id = await self._coordinator.get_disagg_request_id() if need_ctx: ctx_req = self._get_ctx_request(request, disagg_request_id) # ctx generator is empty @@ -417,7 +416,7 @@ async def is_ready(self) -> bool: # Per-request readiness gate for the /v1/ handlers (the server's /health # and /cluster_info hook the coordinator directly). Cluster topology # (cluster_info) is the coordinator's concern, not the request service's. 
- return await self._cluster.is_ready() + return await self._coordinator.is_ready() @property def conditional_disagg_config(self) -> Optional[ConditionalDisaggConfig]: @@ -435,14 +434,14 @@ async def setup(self) -> None: self._gen_router, ServerRole.GENERATION, self._config.max_retries ) - if hasattr(self._cluster, "set_clients"): - self._cluster.set_clients(self._ctx_client, self._gen_client) - await self._cluster.start() + if hasattr(self._coordinator, "set_clients"): + self._coordinator.set_clients(self._ctx_client, self._gen_client) + await self._coordinator.start() async def teardown(self) -> None: await self._ctx_client.shutdown() await self._gen_client.shutdown() - await self._cluster.stop() + await self._coordinator.stop() async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None: if ctx_response: @@ -480,7 +479,9 @@ async def _send_disagg_request_gen_first( ctx_server, gen_server = None, None ctx_server_info = None ctx_req, gen_req = None, None - disagg_request_id = get_global_disagg_request_id(self._config.node_id) + # Single-issuer disagg id (see _send_disagg_request_ctx_first): fetch from + # the coordinator so fleet workers never mint colliding ids. + disagg_request_id = await self._coordinator.get_disagg_request_id() if need_ctx: ctx_server, ctx_server_info = await self._ctx_router.get_next_server( request) From 46501daa600d1028716d2eff951ff346809b4d7a Mon Sep 17 00:00:00 2001 From: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> Date: Fri, 3 Jul 2026 02:05:05 -0700 Subject: [PATCH 7/7] [None][fix] Disagg fleet: don't re-issue the gen disagg id at routing time CoordinatorDelegatingRouter.get_next_server was overwriting the generation request's disagg_request_id with a fresh coordinator-issued id (sent req_id=None to /select, then wrote body["req_id"] back onto the request). But the ctx worker already registered its KV-cache transfer TxSession under the id the request carried from the ctx phase. 
Overwriting it makes the gen transceiver wait on a key the ctx side never registered: the transfer never completes, gen requests stay DISAGG_GENERATION_TRANS_IN_PROGRESS, the gen IndexMapper fills (No free IndexMapper slots), and fleet throughput collapses to ~2 req/s. (Single-process num_workers=1 never hit this: no delegating router, id never rewritten.) Restore the last-good behavior: _request_id returns disagg_request_id for context and ctx_request_id for generation (the inherited ctx id), and get_next_server sends that id as the /select key WITHOUT rewriting the request. Placement never changes the disagg id, so the ctx<->gen KV transfer key stays consistent. Signed-off-by: Lizhi Zhou <1432185+reasonsolo@users.noreply.github.com> --- tensorrt_llm/serve/router.py | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py index 2040e66e197b..39f32b5f804d 100644 --- a/tensorrt_llm/serve/router.py +++ b/tensorrt_llm/serve/router.py @@ -1627,16 +1627,20 @@ def _on_servers_updated(self, old_servers, new_servers): def _request_id(self, request: OpenAIRequest) -> int: """The request's disagg id -- the sole cross-process key for select/finish. - Context requests carry disagg_request_id; generation requests inherit it - as ctx_request_id. OpenAIDisaggregatedService always sets it before - routing, so a missing id is a bug -- assert, don't paper over.""" + Context requests carry disagg_request_id; generation requests inherit the + SAME id as ctx_request_id (set by OpenAIDisaggregatedService before + routing). 
This id is what the ctx worker registered its KV-transfer + TxSession under, so the gen request MUST keep it -- never re-issue a new + id here, or the gen transceiver waits on a key the ctx side never + registered (transfer never completes -> DISAGG_GENERATION_TRANS_IN_PROGRESS + fills the gen IndexMapper and throughput collapses).""" dp = request.disaggregated_params assert dp is not None, "delegated routing requires disaggregated_params" - rid = dp.disagg_request_id - if self._role != "generation": - assert rid is not None, ( - f"delegated {self._role} routing requires a disagg request id " - f"(disagg_request_id/ctx_request_id) on the request") + rid = (dp.disagg_request_id if self._role == "context" + else dp.ctx_request_id) + assert rid is not None, ( + f"delegated {self._role} routing requires a disagg request id " + f"(disagg_request_id/ctx_request_id) on the request") return rid async def get_next_server( @@ -1644,12 +1648,12 @@ async def get_next_server( request: OpenAIRequest, exclude_server: Optional[str] = None) -> tuple[str, dict]: key = self._local.routing_key(request) - # Send the disagg request id as the sole cross-process request key; the - # coordinator keys its pending-request state by it for /finish. - req_id = (None if self._role == "generation" - else self._request_id(request)) + # Send the request's existing disagg id as the sole cross-process key; the + # coordinator keys its pending-request state by it for /finish. Placement + # (server selection) must NOT change the id -- the ctx<->gen KV transfer is + # keyed by it and was registered ctx-side before this call. 
payload = {"role": self._role, "routing_key": key, - "req_id": req_id, + "req_id": self._request_id(request), "exclude_server": exclude_server} async with self.session.post( f"{self._coordinator_url}/select", json=payload, @@ -1660,11 +1664,6 @@ async def get_next_server( f"{await resp.text()}") body = await resp.json() info = body.get("info") or {} - if self._role == "generation": - coordinator_req_id = body.get("req_id") - if coordinator_req_id is None: - raise ValueError("coordinator did not return a generation disagg_request_id") - request.disaggregated_params.disagg_request_id = coordinator_req_id return body["server"], info async def finish_request(self,