Skip to content

Commit 586f730

Browse files
committed
WIP: add test and fix missing changes
Signed-off-by: Gene Su <[email protected]>
1 parent 7416b32 commit 586f730

File tree

5 files changed

+209
-1
lines changed

5 files changed

+209
-1
lines changed

python/ray/serve/_private/deployment_state.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -955,7 +955,7 @@ def get_routing_stats(self) -> Dict[str, Any]:
955955
logger.warning(
956956
"Didn't receive routing stats response for replica "
957957
f"{self._replica_id} after "
958-
f"{self.request_routing_stats_timeout_s}s, retrying."
958+
f"{self.request_routing_stats_timeout_s}s, retrying. {self._routing_stats=}"
959959
)
960960
self._record_routing_stats_ref = None
961961

@@ -2235,6 +2235,7 @@ def check_and_update_replicas(self):
22352235
},
22362236
)
22372237
routing_stats = replica.get_routing_stats()
2238+
print(f"in check_and_update_replicas {routing_stats=}")
22382239
replica.record_routing_stats(routing_stats)
22392240
else:
22402241
logger.warning(

python/ray/serve/api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,8 @@ def deployment(
335335
health_check_timeout_s: Default[float] = DEFAULT.VALUE,
336336
logging_config: Default[Union[Dict, LoggingConfig, None]] = DEFAULT.VALUE,
337337
request_router_class: Default[Union[str, RequestRouter, None]] = DEFAULT.VALUE,
338+
request_routing_stats_period_s: Default[float] = DEFAULT.VALUE,
339+
request_routing_stats_timeout_s: Default[float] = DEFAULT.VALUE,
338340
) -> Callable[[Callable], Deployment]:
339341
"""Decorator that converts a Python class to a `Deployment`.
340342
@@ -404,6 +406,14 @@ class MyDeployment:
404406
handle created for this deployment will use the routing policy
405407
defined by the request router. Default to Serve's
406408
PowerOfTwoChoicesRequestRouter.
409+
request_routing_stats_period_s: Duration between record routing stats
410+
calls for the replica. Defaults to 10s. The routing stats call is by
411+
default a no-op Actor call to the replica, but you can define your own
412+
request routing stats using the "record_routing_stats" method in your
413+
deployment.
414+
request_routing_stats_timeout_s: Duration in seconds that replicas wait for
415+
a request routing stats method to return before considering it as failed.
416+
Defaults to 30s.
407417
408418
Returns:
409419
`Deployment`
@@ -469,6 +479,8 @@ class MyDeployment:
469479
health_check_period_s=health_check_period_s,
470480
health_check_timeout_s=health_check_timeout_s,
471481
logging_config=logging_config,
482+
request_routing_stats_period_s=request_routing_stats_period_s,
483+
request_routing_stats_timeout_s=request_routing_stats_timeout_s,
472484
)
473485
deployment_config.user_configured_option_names = set(user_configured_option_names)
474486

python/ray/serve/deployment.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,8 @@ def options(
238238
health_check_timeout_s: Default[float] = DEFAULT.VALUE,
239239
logging_config: Default[Union[Dict, LoggingConfig, None]] = DEFAULT.VALUE,
240240
request_router_class: Default[Union[str, RequestRouter, None]] = DEFAULT.VALUE,
241+
request_routing_stats_period_s: Default[float] = DEFAULT.VALUE,
242+
request_routing_stats_timeout_s: Default[float] = DEFAULT.VALUE,
241243
_init_args: Default[Tuple[Any]] = DEFAULT.VALUE,
242244
_init_kwargs: Default[Dict[Any, Any]] = DEFAULT.VALUE,
243245
_internal: bool = False,
@@ -373,6 +375,16 @@ def options(
373375
if request_router_class is not DEFAULT.VALUE:
374376
new_deployment_config.request_router_class = request_router_class
375377

378+
if request_routing_stats_period_s is not DEFAULT.VALUE:
379+
new_deployment_config.request_routing_stats_period_s = (
380+
request_routing_stats_period_s
381+
)
382+
383+
if request_routing_stats_timeout_s is not DEFAULT.VALUE:
384+
new_deployment_config.request_routing_stats_timeout_s = (
385+
request_routing_stats_timeout_s
386+
)
387+
376388
new_replica_config = ReplicaConfig.create(
377389
func_or_class,
378390
init_args=_init_args,
@@ -441,6 +453,8 @@ def deployment_to_schema(d: Deployment) -> DeploymentSchema:
441453
"placement_group_bundles": d._replica_config.placement_group_bundles,
442454
"max_replicas_per_node": d._replica_config.max_replicas_per_node,
443455
"logging_config": d._deployment_config.logging_config,
456+
"request_routing_stats_period_s": d._deployment_config.request_routing_stats_period_s,
457+
"request_routing_stats_timeout_s": d._deployment_config.request_routing_stats_timeout_s,
444458
}
445459

446460
# Let non-user-configured options be set to defaults. If the schema
@@ -501,6 +515,8 @@ def schema_to_deployment(s: DeploymentSchema) -> Deployment:
501515
health_check_period_s=s.health_check_period_s,
502516
health_check_timeout_s=s.health_check_timeout_s,
503517
logging_config=s.logging_config,
518+
request_routing_stats_period_s=s.request_routing_stats_period_s,
519+
request_routing_stats_timeout_s=s.request_routing_stats_timeout_s,
504520
)
505521
deployment_config.user_configured_option_names = (
506522
s._get_user_configured_option_names()

python/ray/serve/schema.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,22 @@ class DeploymentSchema(BaseModel, allow_population_by_field_name=True):
409409
default=DEFAULT.VALUE,
410410
description="The path pointing to the custom request router class to use for this deployment.",
411411
)
412+
request_routing_stats_period_s: float = Field(
413+
default=DEFAULT.VALUE,
414+
description=(
415+
"Frequency at which the controller will record routing stats "
416+
"replicas. Uses a default if null."
417+
),
418+
gt=0,
419+
)
420+
request_routing_stats_timeout_s: float = Field(
421+
default=DEFAULT.VALUE,
422+
description=(
423+
"Timeout that the controller will wait for a response "
424+
"from the replica's record routing stats. Uses a default if null."
425+
),
426+
gt=0,
427+
)
412428

413429
@root_validator
414430
def validate_num_replicas_and_autoscaling_config(cls, values):
@@ -488,6 +504,8 @@ def _deployment_info_to_schema(name: str, info: DeploymentInfo) -> DeploymentSch
488504
health_check_timeout_s=info.deployment_config.health_check_timeout_s,
489505
ray_actor_options=info.replica_config.ray_actor_options,
490506
request_router_class=info.deployment_config.request_router_class,
507+
request_routing_stats_period_s=info.deployment_config.request_routing_stats_period_s,
508+
request_routing_stats_timeout_s=info.deployment_config.request_routing_stats_timeout_s,
491509
)
492510

493511
if info.deployment_config.autoscaling_config is not None:
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
import asyncio
2+
from typing import Any, Dict, Optional
3+
4+
import pytest
5+
6+
import ray
7+
from ray import serve
8+
from ray._private.test_utils import wait_for_condition
9+
from ray.serve._private.common import ReplicaID
10+
from ray.serve.context import _get_internal_replica_context
11+
from ray.serve.handle import DeploymentHandle
12+
13+
14+
@serve.deployment(
    request_routing_stats_period_s=0.1, request_routing_stats_timeout_s=0.1
)
class Patient:
    """Test deployment whose routing stats reporting can be made to fail or hang.

    The deployment is configured with a 0.1s routing stats period and a 0.1s
    routing stats timeout. Tests flip the ``should_fail`` / ``should_hang``
    flags through the setter methods and then observe which routing stats
    are visible for the replica.
    """

    def __init__(self):
        # Stats returned by `record_routing_stats`; replaced wholesale by
        # `set_routing_stats`.
        self.routing_stats = {}
        self.should_hang = False
        self.should_fail = False
        context = _get_internal_replica_context()
        self.replica_id: ReplicaID = context.replica_id

    async def record_routing_stats(self):
        """Return the current routing stats, or hang / raise when told to.

        Raises:
            Exception: if `set_should_fail` has been called.
        """
        if self.should_hang:
            # Await instead of calling `time.sleep`: a blocking sleep inside
            # a coroutine stalls the event loop it runs on, so nothing else
            # scheduled on that loop could make progress. Awaiting suspends
            # only this call.
            await asyncio.sleep(10000)

        if self.should_fail:
            raise Exception("intended to fail")

        return self.routing_stats

    def __call__(self, *args) -> ReplicaID:
        """Return the ID of the replica that handled the request."""
        return self.replica_id

    def set_routing_stats(self, routing_stats: Dict[str, Any]):
        """Replace the stats that `record_routing_stats` returns."""
        print(f"set_routing_stats {routing_stats=}")
        self.routing_stats = routing_stats

    def set_should_fail(self):
        """Make subsequent `record_routing_stats` calls raise."""
        self.should_fail = True

    def set_should_hang(self):
        """Make subsequent `record_routing_stats` calls sleep for a long time."""
        self.should_hang = True
48+
49+
50+
def check_routing_stats_recorded(
    handle: DeploymentHandle,
    expected_stats: Dict[str, Any],
    replica_id: Optional[ReplicaID] = None,
) -> bool:
    """Assert that a replica known to the handle's router has the expected stats.

    Args:
        handle: Deployment handle whose request router's replica map is read.
        expected_stats: Routing stats the selected replica must report.
        replica_id: Replica to inspect. When falsy, the first replica in the
            router's replica map is used.

    Returns:
        True when the stats match.

    Raises:
        AssertionError: if the selected replica's routing stats differ from
            `expected_stats`.
    """
    replicas_by_id = handle._router._asyncio_router.request_router._replicas
    target_running_replica = (
        replicas_by_id[replica_id]
        if replica_id
        else next(iter(replicas_by_id.values()))
    )
    assert (
        target_running_replica.routing_stats == expected_stats
    ), f"{target_running_replica.routing_stats=} != {expected_stats=}"
    return True
64+
65+
66+
@pytest.mark.parametrize("use_class", [True, False])
def test_no_user_defined_method(serve_instance, use_class):
    """Routing stats are empty when the deployment defines no stats method.

    Covers both a class-based and a function-based deployment.
    """
    if use_class:

        @serve.deployment
        class A:
            def __call__(self, *args):
                return ray.get_runtime_context().current_actor

    else:

        @serve.deployment
        def A(*args):
            return ray.get_runtime_context().current_actor

    handle = serve.run(A.bind())
    # Send one request and wait for it before inspecting the router's
    # replica map; the result itself is not needed.
    handle.remote().result()

    request_router = handle._router._asyncio_router.request_router
    replicas = list(request_router._replicas.values())
    assert len(replicas) == 1
    assert replicas[0].routing_stats == {}
87+
88+
89+
@pytest.mark.asyncio
async def test_user_defined_method_fails(serve_instance):
    """Previously recorded stats stay visible after the user method starts failing."""
    expected_stats = {"foo": "bar"}
    h = serve.run(Patient.bind())
    await h.set_routing_stats.remote(expected_stats)
    replica_id = await h.remote()

    # The same check runs before and after the failure is injected.
    check_kwargs = {
        "handle": h,
        "expected_stats": expected_stats,
        "replica_id": replica_id,
    }

    # Ensure the routing stats are recorded correctly before the failure.
    wait_for_condition(check_routing_stats_recorded, **check_kwargs)

    await h.set_should_fail.remote()
    await asyncio.gather(*(h.remote() for _ in range(100)))

    # After the failure, the previous routing stats should still be accessible.
    wait_for_condition(check_routing_stats_recorded, **check_kwargs)
115+
116+
117+
# @pytest.mark.asyncio
118+
# async def test_user_defined_method_hangs(serve_instance):
119+
# """Check the behavior when a user-defined method hangs."""
120+
# expected_stats = {"foo": "bar"}
121+
# h = serve.run(Patient.bind())
122+
# await h.set_routing_stats.remote(expected_stats)
123+
# replica_id = await h.remote()
124+
#
125+
# # Ensure the routing stats are recorded correctly before the failure
126+
# wait_for_condition(check_routing_stats_recorded, handle=h, expected_stats=expected_stats, replica_id=replica_id)
127+
#
128+
# print("A")
129+
# await h.set_should_hang.remote()
130+
# print("B")
131+
# await asyncio.gather(*[h.remote() for _ in range(100)])
132+
# print("C")
133+
# # After the failure the previous routing stats should still accessible
134+
# wait_for_condition(check_routing_stats_recorded, handle=h, expected_stats=expected_stats, replica_id=replica_id)
135+
#
136+
#
137+
# @pytest.mark.asyncio
138+
# async def test_multiple_replicas(serve_instance):
139+
# h = serve.run(Patient.options(num_replicas=2).bind())
140+
# actors = {
141+
# a._actor_id for a in await asyncio.gather(*[h.remote() for _ in range(100)])
142+
# }
143+
# assert len(actors) == 2
144+
#
145+
# await h.set_should_fail.remote()
146+
#
147+
# await async_wait_for_condition(
148+
# check_new_actor_started, handle=h, original_actors=actors
149+
# )
150+
#
151+
# new_actors = {
152+
# a._actor_id for a in await asyncio.gather(*[h.remote() for _ in range(100)])
153+
# }
154+
# assert len(new_actors) == 2
155+
# assert len(new_actors.intersection(actors)) == 1
156+
157+
158+
if __name__ == "__main__":
    import sys

    # Run this file's tests verbosely with output capturing disabled, and
    # exit with pytest's return code.
    exit_code = pytest.main(["-v", "-s", __file__])
    sys.exit(exit_code)

0 commit comments

Comments
 (0)