Skip to content

Commit c4b8bd1

Browse files
Always return output for completions sync response (#412)
* Always return output for completions sync response
* fix tests and dependencies
* more tests
* add more tests
* revert dependency changes
1 parent 82c9303 commit c4b8bd1

File tree

7 files changed

+442
-88
lines changed

7 files changed

+442
-88
lines changed

model-engine/model_engine_server/api/llms_v1.py

Lines changed: 5 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -332,12 +332,14 @@ async def create_completion_sync_task(
332332
metric_metadata,
333333
)
334334
return response
335-
except UpstreamServiceError:
335+
except UpstreamServiceError as exc:
336336
request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID)
337-
logger.exception(f"Upstream service error for request {request_id}")
337+
logger.exception(
338+
f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}"
339+
)
338340
raise HTTPException(
339341
status_code=500,
340-
detail=f"Upstream service error for request_id {request_id}.",
342+
detail=f"Upstream service error for request_id {request_id}",
341343
)
342344
except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc:
343345
raise HTTPException(

model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py

Lines changed: 25 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1555,9 +1555,11 @@ async def execute(
15551555
),
15561556
)
15571557
else:
1558-
return CompletionSyncV1Response(
1559-
request_id=request_id,
1560-
output=None,
1558+
raise UpstreamServiceError(
1559+
status_code=500,
1560+
content=predict_result.traceback.encode("utf-8")
1561+
if predict_result.traceback is not None
1562+
else b"",
15611563
)
15621564
elif (
15631565
endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE
@@ -1589,9 +1591,11 @@ async def execute(
15891591
)
15901592

15911593
if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
1592-
return CompletionSyncV1Response(
1593-
request_id=request_id,
1594-
output=None,
1594+
raise UpstreamServiceError(
1595+
status_code=500,
1596+
content=predict_result.traceback.encode("utf-8")
1597+
if predict_result.traceback is not None
1598+
else b"",
15951599
)
15961600

15971601
output = json.loads(predict_result.result["result"])
@@ -1628,9 +1632,11 @@ async def execute(
16281632
)
16291633

16301634
if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
1631-
return CompletionSyncV1Response(
1632-
request_id=request_id,
1633-
output=None,
1635+
raise UpstreamServiceError(
1636+
status_code=500,
1637+
content=predict_result.traceback.encode("utf-8")
1638+
if predict_result.traceback is not None
1639+
else b"",
16341640
)
16351641

16361642
output = json.loads(predict_result.result["result"])
@@ -1670,9 +1676,11 @@ async def execute(
16701676
)
16711677

16721678
if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
1673-
return CompletionSyncV1Response(
1674-
request_id=request_id,
1675-
output=None,
1679+
raise UpstreamServiceError(
1680+
status_code=500,
1681+
content=predict_result.traceback.encode("utf-8")
1682+
if predict_result.traceback is not None
1683+
else b"",
16761684
)
16771685

16781686
output = json.loads(predict_result.result["result"])
@@ -1706,9 +1714,11 @@ async def execute(
17061714
)
17071715

17081716
if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None:
1709-
return CompletionSyncV1Response(
1710-
request_id=request_id,
1711-
output=None,
1717+
raise UpstreamServiceError(
1718+
status_code=500,
1719+
content=predict_result.traceback.encode("utf-8")
1720+
if predict_result.traceback is not None
1721+
else b"",
17121722
)
17131723

17141724
output = json.loads(predict_result.result["result"])

model-engine/requirements.txt

Lines changed: 11 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -54,7 +54,9 @@ boto3==1.28.1
5454
# celery
5555
# kombu
5656
boto3-stubs[essential]==1.26.67
57-
# via -r model-engine/requirements.in
57+
# via
58+
# -r model-engine/requirements.in
59+
# boto3-stubs
5860
botocore==1.31.1
5961
# via
6062
# -r model-engine/requirements.in
@@ -71,15 +73,15 @@ cachetools==5.3.1
7173
cattrs==23.1.2
7274
# via ddtrace
7375
celery[redis,sqs,tblib]==5.3.1
74-
# via -r model-engine/requirements.in
76+
# via
77+
# -r model-engine/requirements.in
78+
# celery
7579
certifi==2023.7.22
7680
# via
7781
# datadog-api-client
7882
# kubernetes
7983
# kubernetes-asyncio
8084
# requests
81-
cffi==1.15.1
82-
# via cryptography
8385
charset-normalizer==3.2.0
8486
# via
8587
# aiohttp
@@ -107,8 +109,6 @@ commonmark==0.9.1
107109
# via rich
108110
croniter==1.4.1
109111
# via -r model-engine/requirements.in
110-
cryptography==41.0.3
111-
# via secretstorage
112112
dataclasses-json==0.5.9
113113
# via -r model-engine/requirements.in
114114
datadog==0.47.0
@@ -127,7 +127,7 @@ docutils==0.20.1
127127
# via readme-renderer
128128
envier==0.4.0
129129
# via ddtrace
130-
exceptiongroup==1.1.3
130+
exceptiongroup==1.2.0
131131
# via
132132
# anyio
133133
# cattrs
@@ -185,7 +185,7 @@ importlib-metadata==6.8.0
185185
# keyring
186186
# quart
187187
# twine
188-
importlib-resources==6.1.0
188+
importlib-resources==6.1.1
189189
# via
190190
# alembic
191191
# jsonschema
@@ -195,10 +195,6 @@ itsdangerous==2.1.2
195195
# via quart
196196
jaraco-classes==3.3.0
197197
# via keyring
198-
jeepney==0.8.0
199-
# via
200-
# keyring
201-
# secretstorage
202198
jinja2==3.0.3
203199
# via
204200
# -r model-engine/requirements.in
@@ -300,8 +296,6 @@ pyasn1==0.5.0
300296
# rsa
301297
pyasn1-modules==0.3.0
302298
# via google-auth
303-
pycparser==2.21
304-
# via cffi
305299
pycurl==7.45.2
306300
# via
307301
# -r model-engine/requirements.in
@@ -326,7 +320,7 @@ python-dateutil==2.8.2
326320
# pg8000
327321
python-multipart==0.0.6
328322
# via -r model-engine/requirements.in
329-
pyyaml==6.0
323+
pyyaml==6.0.1
330324
# via
331325
# huggingface-hub
332326
# kubeconfig
@@ -379,8 +373,6 @@ safetensors==0.4.0
379373
# via transformers
380374
scramp==1.4.4
381375
# via pg8000
382-
secretstorage==3.3.3
383-
# via keyring
384376
sentencepiece==0.1.99
385377
# via -r model-engine/requirements.in
386378
sh==1.14.3
@@ -409,6 +401,7 @@ sqlalchemy[asyncio]==2.0.4
409401
# via
410402
# -r model-engine/requirements.in
411403
# alembic
404+
# sqlalchemy
412405
sse-starlette==1.6.1
413406
# via -r model-engine/requirements.in
414407
sseclient-py==1.7.2
@@ -525,8 +518,4 @@ zipp==3.16.0
525518
# importlib-resources
526519

527520
# The following packages are considered to be unsafe in a requirements file:
528-
setuptools==68.0.0
529-
# via
530-
# gunicorn
531-
# kubernetes
532-
# kubernetes-asyncio
521+
# setuptools

model-engine/tests/unit/api/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@ def get_test_client(
106106
fake_file_system_gateway_contents=None,
107107
fake_trigger_repository_contents=None,
108108
fake_cron_job_gateway_contents=None,
109+
fake_sync_inference_content=None,
109110
) -> TestClient:
110111
if fake_docker_image_batch_job_gateway_contents is None:
111112
fake_docker_image_batch_job_gateway_contents = {}
@@ -131,6 +132,8 @@ def get_test_client(
131132
fake_trigger_repository_contents = {}
132133
if fake_cron_job_gateway_contents is None:
133134
fake_cron_job_gateway_contents = {}
135+
if fake_sync_inference_content is None:
136+
fake_sync_inference_content = {}
134137
app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper(
135138
fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists,
136139
fake_model_bundle_repository_contents=fake_model_bundle_repository_contents,
@@ -145,6 +148,7 @@ def get_test_client(
145148
fake_file_system_gateway_contents=fake_file_system_gateway_contents,
146149
fake_trigger_repository_contents=fake_trigger_repository_contents,
147150
fake_cron_job_gateway_contents=fake_cron_job_gateway_contents,
151+
fake_sync_inference_content=fake_sync_inference_content,
148152
)
149153
app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[
150154
get_external_interfaces

model-engine/tests/unit/api/test_llms.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import pytest
66
from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response
7+
from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus
78
from model_engine_server.domain.entities import ModelEndpoint
89

910

@@ -102,14 +103,30 @@ def test_completion_sync_success(
102103
fake_batch_job_record_repository_contents={},
103104
fake_batch_job_progress_gateway_contents={},
104105
fake_docker_image_batch_job_bundle_repository_contents={},
106+
fake_sync_inference_content=SyncEndpointPredictV1Response(
107+
status=TaskStatus.SUCCESS,
108+
result={
109+
"result": """{
110+
"text": "output",
111+
"count_prompt_tokens": 1,
112+
"count_output_tokens": 1
113+
}"""
114+
},
115+
traceback=None,
116+
),
105117
)
106118
response_1 = client.post(
107119
f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}",
108120
auth=("no_user", ""),
109121
json=completion_sync_request,
110122
)
111123
assert response_1.status_code == 200
112-
assert response_1.json()["output"] is None
124+
assert response_1.json()["output"] == {
125+
"text": "output",
126+
"num_completion_tokens": 1,
127+
"num_prompt_tokens": 1,
128+
"tokens": None,
129+
}
113130
assert response_1.json().keys() == {"output", "request_id"}
114131

115132

0 commit comments

Comments
 (0)