
Commit da73b0d

Logging debug - WIP
1 parent 90ff956 commit da73b0d


7 files changed: +826 −682 lines


truss-chains/examples/streaming/streaming_chain.py

Lines changed: 39 additions & 20 deletions
@@ -1,7 +1,9 @@
 import asyncio
+import logging
 import time
 from typing import AsyncIterator
 
+import fastapi
 import pydantic
 
 import truss_chains as chains
@@ -38,17 +40,30 @@ class ConsumerOutput(pydantic.BaseModel):
 class Generator(chains.ChainletBase):
     """Example that streams fully structured pydantic items with header and footer."""
 
-    async def run_remote(self, cause_error: bool) -> AsyncIterator[bytes]:
-        print("Entering Generator")
+    async def run_remote(
+        self, cause_pre_stream_error: bool, cause_mid_stream_error: bool
+    ) -> AsyncIterator[bytes]:
+        logging.info("Entering Generator")
+        if cause_pre_stream_error:
+            logging.info("Raise Pre Stream")
+            raise fastapi.HTTPException(
+                status_code=fastapi.status.HTTP_400_BAD_REQUEST,
+                detail="Error pre stream.",
+            )
+        logging.info("Starting streamer.")
+
         streamer = streaming.stream_writer(STREAM_TYPES)
         header = Header(time=time.time(), msg="Start.")
         yield streamer.yield_header(header)
         for i in range(1, 5):
             data = MyDataChunk(words=[chr(x + 70) * x for x in range(1, i + 1)])
-            print("Yield")
+            logging.info("Yield")
             yield streamer.yield_item(data)
-            if cause_error and i > 2:
-                raise RuntimeError("Test Error")
+            if cause_mid_stream_error and i > 2:
+                raise fastapi.HTTPException(
+                    status_code=fastapi.status.HTTP_501_NOT_IMPLEMENTED,
+                    detail="Error mid stream",
+                )
             await asyncio.sleep(0.05)
 
         end_time = time.time()
@@ -74,16 +89,20 @@ class Consumer(chains.ChainletBase):
 
     def __init__(
         self,
-        generator=chains.depends(Generator),
-        string_generator=chains.depends(StringGenerator),
+        # generator=chains.depends(Generator),
+        # string_generator=chains.depends(StringGenerator),
     ):
-        self._generator = generator
-        self._string_generator = string_generator
+        # self._generator = generator
+        # self._string_generator = string_generator
+        pass
 
-    async def run_remote(self, cause_error: bool) -> ConsumerOutput:
+    async def run_remote(
+        self, cause_pre_stream_error: bool, cause_mid_stream_error: bool
+    ) -> ConsumerOutput:
         print("Entering Consumer")
         reader = streaming.stream_reader(
-            STREAM_TYPES, self._generator.run_remote(cause_error)
+            STREAM_TYPES,
+            self._generator.run_remote(cause_pre_stream_error, cause_mid_stream_error),
         )
         print("Consuming...")
         header = await reader.read_header()
@@ -92,15 +111,15 @@ async def run_remote(self, cause_error: bool) -> ConsumerOutput:
             print(f"Read: {data}")
             chunks.append(data)
 
-        footer = await reader.read_footer()
-        strings = []
-        async for part in self._string_generator.run_remote():
-            strings.append(part)
-
-        print("Exiting Consumer")
-        return ConsumerOutput(
-            header=header, chunks=chunks, footer=footer, strings="".join(strings)
-        )
+        # footer = await reader.read_footer()
+        # strings = []
+        # async for part in self._string_generator.run_remote():
+        #     strings.append(part)
+        #
+        # print("Exiting Consumer")
+        # return ConsumerOutput(
+        #     header=header, chunks=chunks, footer=footer, strings="".join(strings)
+        # )
 
 
 if __name__ == "__main__":
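
The split into `cause_pre_stream_error` and `cause_mid_stream_error` matters because only an error raised before the first `yield` can still be mapped to an HTTP status; once streaming has begun, an exception can only truncate the stream. A minimal, framework-free sketch of that distinction (illustrative only, not part of the commit):

import asyncio
from typing import AsyncIterator


async def generate(pre_error: bool, mid_error: bool) -> AsyncIterator[int]:
    if pre_error:
        # Raised before the first `yield`: no bytes were sent yet, so a
        # server could still turn this into an HTTP error response.
        raise ValueError("pre-stream error")
    for i in range(4):
        yield i
        if mid_error and i > 1:
            # Raised after yields began: the status line is already sent.
            raise ValueError("mid-stream error")


async def main() -> None:
    gen = generate(pre_error=True, mid_error=False)
    try:
        # Creating the generator runs nothing; the first __anext__ does,
        # so pre-stream errors surface here, before any item is produced.
        await gen.__anext__()
    except ValueError as e:
        print(f"caught before any item was produced: {e}")


asyncio.run(main())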

truss-chains/tests/test_e2e.py

Lines changed: 18 additions & 12 deletions
@@ -163,18 +163,24 @@ def test_streaming_chain():
     assert service is not None
     time.sleep(1.0)  # Wait for models to be ready.
 
-    response = service.run_remote({"cause_error": False})
-    assert response.status_code == 200
-    print(response.json())
-    result = response.json()
-    print(result)
-    assert result["header"]["msg"] == "Start."
-    assert result["chunks"][0]["words"] == ["G"]
-    assert result["chunks"][1]["words"] == ["G", "HH"]
-    assert result["chunks"][2]["words"] == ["G", "HH", "III"]
-    assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"]
-    assert result["footer"]["duration_sec"] > 0
-    assert result["strings"] == "First second last."
+    response = service.run_remote(
+        {"cause_pre_stream_error": "hell", "cause_mid_stream_error": False}
+    )
+    print(response.status_code)
+    print(response.content)
+    assert False
+
+    # assert response.status_code == 200
+    # print(response.json())
+    # result = response.json()
+    # print(result)
+    # assert result["header"]["msg"] == "Start."
+    # assert result["chunks"][0]["words"] == ["G"]
+    # assert result["chunks"][1]["words"] == ["G", "HH"]
+    # assert result["chunks"][2]["words"] == ["G", "HH", "III"]
+    # assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"]
+    # assert result["footer"]["duration_sec"] > 0
+    # assert result["strings"] == "First second last."
 
     # TODO: build error handling for stream reader.
     # response = service.run_remote({"cause_error": True})

truss/templates/server/model_wrapper.py

Lines changed: 31 additions & 11 deletions
@@ -656,6 +656,7 @@ async def preprocess(
         )
         return await self._execute_user_model_fn(inputs, request, descriptor)
 
+    # TODO: can we eliminate this bloat layer?
     async def _predict(
         self, inputs: Any, request: starlette.requests.Request
     ) -> Union[OutputType, Any]:
@@ -696,6 +697,11 @@ async def _write_response_to_queue(
                 f"Exception while generating streamed response: {str(e)}",
                 exc_info=errors.filter_traceback(self.model_file_name),
             )
+            # Since this runs in a task, we *do not* raise the exception, just
+            # log the error, close the queue and finish the task.
+            # It's not possible to signal an error to the client (e.g. via HTTP
+            # status) after streaming has begun unless introducing a schema for
+            # error messages on the stream itself - but here we are unopinionated.
         finally:
             await queue.put(SENTINEL)
 
@@ -711,14 +717,27 @@ async def _stream_with_background_task(
         streaming_read_timeout = self._config.get("runtime", {}).get(
             "streaming_read_timeout", STREAMING_RESPONSE_QUEUE_READ_TIMEOUT_SECS
         )
-        async_generator = _force_async_generator(generator)
         # To ensure that a partial read from a client does not keep the semaphore
         # claimed, we write all the data from the stream to the queue as it is produced,
         # irrespective of how fast it is consumed.
        # We then return a new generator that reads from the queue, and then
         # exits the semaphore block.
         response_queue: asyncio.Queue = asyncio.Queue()
 
+        # In order to catch errors before the first `yield` (e.g. user implemented
+        # input validation), we get the first chunk here and raise the error if needed.
+        with tracing.section_as_event(span, "await_first_element"):
+            try:
+                async_generator = _force_async_generator(generator)
+                first_chunk = await async_generator.__anext__()
+                await response_queue.put(first_chunk)
+            except StopAsyncIteration:
+                cleanup_fn()
+                return (chunk async for chunk in [])  # Empty dummy generator.
+            except Exception as e:
+                cleanup_fn()
+                # print("CALL STACK:\n" + "".join(traceback.format_stack()))
+                raise e
         # `write_response_to_queue` keeps running the background until completion.
         gen_task = asyncio.create_task(
             self._write_response_to_queue(response_queue, async_generator, span)
@@ -727,8 +746,6 @@ async def _stream_with_background_task(
         gen_task.add_done_callback(lambda _: cleanup_fn())
 
         # The gap between responses in a stream must be < streaming_read_timeout
-        # TODO: this whole buffering might be superfluous and sufficiently done by
-        # by the FastAPI server already. See `test_limit_concurrency_with_sse`.
         async def _buffered_response_generator() -> AsyncGenerator[bytes, None]:
             # `span` is tied to the "producer" `gen_task` which might complete before
             # "consume" part here finishes, therefore a dedicated span is required.
@@ -854,20 +871,23 @@ async def predict(
             # exactly handle that case we would need to apply `detach_context`
             # around each `next`-invocation that consumes the generator, which is
             # prohibitive.
+            # TODO: predict has exception interception via `_execute_user_model_fn`,
+            # but all the other parts of the flow don't have that...
+            # why is the stack trace above here missing?
             predict_result = await self._predict(preprocess_result, request)
 
             if inspect.isgenerator(predict_result) or inspect.isasyncgen(
                 predict_result
             ):
                 if self.model_descriptor.postprocess:
-                    with errors.intercept_exceptions(
-                        self._logger, self.model_file_name
-                    ):
-                        raise errors.ModelDefinitionError(
-                            "If the predict function returns a generator (streaming), "
-                            "you cannot use postprocessing. Include all processing in "
-                            "the predict method."
-                        )
+                    # with errors.intercept_exceptions(
+                    #     self._logger, self.model_file_name
+                    # ):
+                    raise errors.ModelDefinitionError(
+                        "If the predict function returns a generator (streaming), "
+                        "you cannot use postprocessing. Include all processing in "
+                        "the predict method."
+                    )
 
             return await self._handle_generator_response(
                 request,
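
The new `await_first_element` block implements a general pattern: pull the first chunk eagerly so that pre-stream errors still propagate to the caller (and can become an HTTP status), then hand the remaining generator to a background task that drains it into a queue regardless of how fast the client reads. A simplified standalone sketch of that pattern, using local stand-ins (`SENTINEL`, plain `print` logging, no tracing or cleanup callbacks) rather than the server's actual helpers:

import asyncio
from typing import AsyncGenerator, AsyncIterator

SENTINEL = object()  # Marks end-of-stream on the queue.


async def _drain_to_queue(gen: AsyncIterator[bytes], queue: asyncio.Queue) -> None:
    try:
        async for chunk in gen:
            await queue.put(chunk)
    except Exception as e:
        # Runs inside a task after streaming began: only log, never raise -
        # the HTTP status can no longer be changed for the client.
        print(f"error while streaming: {e}")
    finally:
        await queue.put(SENTINEL)


async def stream_with_buffering(
    gen: AsyncIterator[bytes],
) -> AsyncGenerator[bytes, None]:
    queue: asyncio.Queue = asyncio.Queue()
    # Pre-stream errors (e.g. input validation before the first `yield`)
    # surface right here and propagate to the caller. The real code also
    # catches StopAsyncIteration to return an empty stream gracefully.
    first_chunk = await gen.__anext__()
    await queue.put(first_chunk)
    # The producer task keeps filling the queue even if the consumer is slow;
    # the real code additionally attaches a cleanup callback to this task.
    gen_task = asyncio.create_task(_drain_to_queue(gen, queue))

    async def buffered() -> AsyncGenerator[bytes, None]:
        while (chunk := await queue.get()) is not SENTINEL:
            yield chunk
        await gen_task  # Producer already finished; this just reaps the task.

    return buffered()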

truss/templates/server/truss_server.py

Lines changed: 0 additions & 3 deletions
@@ -179,9 +179,6 @@ async def _execute_request(
         request: Request,
         body_raw: bytes,
     ) -> Response:
-        """
-        Executes a predictive endpoint
-        """
         self.check_healthy()
         trace_ctx = otel_propagate.extract(request.headers) or None
         # This is the top-level span in the truss-server, so we set the context here.

truss/tests/helpers.py

Lines changed: 50 additions & 1 deletion
@@ -1,7 +1,15 @@
+import contextlib
+import json
+import tempfile
+import textwrap
 from pathlib import Path
+from typing import Iterator, Optional
 
+from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
+from truss.truss_handle.truss_handle import TrussHandle
 
-def create_truss(truss_dir: Path, config_contents: str, model_contents: str):
+
+def _create_truss(truss_dir: Path, config_contents: str, model_contents: str):
     truss_dir.mkdir(exist_ok=True)  # Ensure the 'truss' directory exists
     truss_model_dir = truss_dir / "model"
     truss_model_dir.mkdir(parents=True, exist_ok=True)
@@ -12,3 +20,44 @@ def create_truss(truss_dir: Path, config_contents: str, model_contents: str):
         file.write(config_contents)
     with open(model_file, "w", encoding="utf-8") as file:
         file.write(model_contents)
+
+
+@contextlib.contextmanager
+def temp_truss(model_src: str, config_src: str = "") -> Iterator[TrussHandle]:
+    with ensure_kill_all(), tempfile.TemporaryDirectory(dir=".") as tmp_work_dir:
+        truss_dir = Path(tmp_work_dir, "truss")
+        _create_truss(truss_dir, config_src, textwrap.dedent(model_src))
+        yield TrussHandle(truss_dir)
+
+
+DEFAULT_LOG_ERROR = "Internal Server Error"
+
+
+def _log_contains_line(
+    line: dict, message: str, level: str, error: Optional[str] = None
+):
+    return (
+        line["levelname"] == level
+        and message in line["message"]
+        and (error is None or error in line["exc_info"])
+    )
+
+
+def assert_logs_contain_error(
+    logs: str, error: Optional[str], message=DEFAULT_LOG_ERROR
+):
+    loglines = [json.loads(line) for line in logs.splitlines()]
+    assert any(
+        _log_contains_line(line, message, "ERROR", error) for line in loglines
+    ), (
+        f"Did not find expected error in logs.\nExpected error: {error}\n"
+        f"Expected message: {message}\nActual logs:\n{loglines}"
+    )
+
+
+def assert_logs_contain(logs: str, message: str, level: str = "INFO"):
+    loglines = [json.loads(line) for line in logs.splitlines()]
+    assert any(_log_contains_line(line, message, level) for line in loglines), (
+        f"Did not find expected logs.\n"
+        f"Expected message: {message}\nActual logs:\n{loglines}"
+    )
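
These log assertions expect JSON-line records carrying `levelname`, `message`, and (for errors) `exc_info` fields. A hypothetical usage sketch against a synthetic log string (the record contents here are invented for illustration, matching only the field names the helpers read):

import json

from truss.tests.helpers import assert_logs_contain, assert_logs_contain_error

records = [
    {"levelname": "INFO", "message": "Starting server.", "exc_info": ""},
    {
        "levelname": "ERROR",
        "message": "Internal Server Error",
        "exc_info": "Traceback ...\nValueError: Test error",
    },
]
# The helpers parse one JSON object per line.
logs = "\n".join(json.dumps(record) for record in records)

assert_logs_contain(logs, "Starting server.")
assert_logs_contain_error(logs, error="ValueError: Test error")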
