Skip to content

Commit e6334ce

Browse files
Remove stack trace when model load takes long (#1674)
* Remove stack trace when model load takes long
* Fix server tests and return 503 instead of raising exception
* Don't retry health checks to control server, but retry other endpoints
* Add comment clarifying the different response for health checks
* Refactor method to check if request is a health check
* Clarify model-not-ready message and comments
* Update truss rc
1 parent 31fc8eb commit e6334ce

File tree

4 files changed

+29
-11
lines changed

4 files changed

+29
-11
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.9.96rc003"
3+
version = "0.9.96rc018"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/templates/control/control/endpoints.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,17 @@
55
import httpx
66
from fastapi import APIRouter, WebSocket
77
from fastapi.responses import JSONResponse, StreamingResponse
8-
from helpers.errors import ModelLoadFailed, ModelNotReady
98
from httpx_ws import aconnect_ws
109
from starlette.requests import ClientDisconnect, Request
1110
from starlette.responses import Response
1211
from tenacity import RetryCallState, Retrying, retry_if_exception_type, wait_fixed
1312
from wsproto.events import BytesMessage, TextMessage
1413

14+
from truss.templates.control.control.helpers.errors import (
15+
ModelLoadFailed,
16+
ModelNotReady,
17+
)
18+
1519
INFERENCE_SERVER_START_WAIT_SECS = 60
1620
BASE_RETRY_EXCEPTIONS = (
1721
retry_if_exception_type(httpx.ConnectError)
@@ -65,7 +69,16 @@ async def proxy_http(request: Request):
6569
resp = await client.send(inf_serv_req, stream=True)
6670

6771
if await _is_model_not_ready(resp):
68-
raise ModelNotReady("Model has started running, but not ready yet.")
72+
# If this is a health check request, don't raise an error so that a stack
73+
# trace isn't logged upon deploying a model with a long load time.
74+
if _is_health_check(path):
75+
return JSONResponse(
76+
"The server is live, but the model has not completed loading.",
77+
status_code=503,
78+
)
79+
raise ModelNotReady(
80+
"The server is live, but the model has not completed loading."
81+
)
6982
except (httpx.RemoteProtocolError, httpx.ConnectError) as exp:
7083
# This check is a bit expensive so we don't do it before every request, we
7184
# do it only if request fails with connection error. If the inference server
@@ -99,7 +112,7 @@ def inference_retries(
99112
retry=retry_condition,
100113
stop=_custom_stop_strategy,
101114
wait=wait_fixed(1),
102-
reraise=False,
115+
reraise=True,
103116
):
104117
yield attempt
105118

@@ -216,6 +229,13 @@ def _reroute_if_health_check(path: str) -> str:
216229
return path
217230

218231

232+
def _is_health_check(path: str) -> bool:
233+
"""
234+
Checks if the request path is for the health check endpoint.
235+
"""
236+
return path == "/v1/models/model/loaded"
237+
238+
219239
def _custom_stop_strategy(retry_state: RetryCallState) -> bool:
220240
# Stop after 10 attempts for ModelNotReady
221241
if retry_state.outcome is not None and isinstance(

truss/templates/control/control/helpers/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class InadmissiblePatch(PatchApplicatonError):
3737

3838

3939
class ModelNotReady(Error):
40-
"""Model has started running, but not ready yet."""
40+
"""The server is live, but the model has not completed loading."""
4141

4242
pass
4343

truss/tests/templates/control/control/test_server.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
import httpx
1010
import pytest
11-
from tenacity import RetryError
1211

1312
from truss.truss_handle.patch.custom_types import PatchRequest
1413

@@ -246,11 +245,10 @@ async def mock_send(*args, **kwargs):
246245

247246
app.state.proxy_client.send = AsyncMock(side_effect=mock_send)
248247

249-
with pytest.raises(RetryError):
250-
await client.get("/v1/models/model")
248+
await client.get("/v1/models/model")
251249

252-
# Health check was retried 10 times
253-
assert app.state.proxy_client.send.call_count == 10
250+
# Health check did not retry
251+
assert app.state.proxy_client.send.call_count == 1
254252

255253

256254
@pytest.mark.anyio
@@ -277,7 +275,7 @@ async def test_retries(client, app):
277275

278276
with (
279277
patch("endpoints.INFERENCE_SERVER_START_WAIT_SECS", new=4),
280-
pytest.raises(RetryError),
278+
pytest.raises(httpx.RemoteProtocolError),
281279
):
282280
await client.get("/v1/models/model")
283281

0 commit comments

Comments (0)