Skip to content

Commit d4be9b9

Browse files
authored
[Logging I/O] Post inference hooks as background tasks (#422)
* changes for forwarder to run locally * forwarder hooks as background tasks and testing code * hooks for celery forwarder * revert local changes for testing * revert unnecessary things * remove space * remove print statement + fix unit test * move logic to after_return * load json response in handler * add temp unit test for post inference hooks handler * add another temp unit test for json handling * not cover handle line for now
1 parent e6e9111 commit d4be9b9

File tree

5 files changed

+149
-55
lines changed

5 files changed

+149
-55
lines changed

model-engine/model_engine_server/inference/forwarding/celery_forwarder.py

Lines changed: 42 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from celery import Celery, Task, states
66
from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME
77
from model_engine_server.common.dtos.model_endpoints import BrokerType
8+
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
89
from model_engine_server.core.celery import TaskVisibility, celery_app
910
from model_engine_server.core.config import infra_config
1011
from model_engine_server.core.loggers import logger_name, make_logger
@@ -25,45 +26,6 @@ class ErrorResponse(TypedDict):
2526
error_metadata: str
2627

2728

28-
class ErrorHandlingTask(Task):
29-
"""Sets a 'custom' field with error in the Task response for FAILURE.
30-
31-
Used when services are ran via the Celery backend.
32-
"""
33-
34-
def after_return(
35-
self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo
36-
) -> None:
37-
"""Handler that ensures custom error response information is available whenever a Task fails.
38-
39-
Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value
40-
:param:`retval` is an `Exception`, this handler extracts information from the `Exception`
41-
and constructs a custom error response JSON value (see :func:`error_response` for details).
42-
43-
This handler then re-propagates the Celery-required exception information (`"exc_type"` and
44-
`"exc_message"`) while adding this new error response information under the `"custom"` key.
45-
"""
46-
if status == states.FAILURE and isinstance(retval, Exception):
47-
logger.warning(f"Setting custom error response for failed task {task_id}")
48-
49-
info: dict = raw_celery_response(self.backend, task_id)
50-
result: dict = info["result"]
51-
err: Exception = retval
52-
53-
error_payload = error_response("Internal failure", err)
54-
55-
# Inspired by pattern from:
56-
# https://www.distributedpython.com/2018/09/28/celery-task-states/
57-
self.update_state(
58-
state=states.FAILURE,
59-
meta={
60-
"exc_type": result["exc_type"],
61-
"exc_message": result["exc_message"],
62-
"custom": json.dumps(error_payload, indent=False),
63-
},
64-
)
65-
66-
6729
def raw_celery_response(backend, task_id: str) -> Dict[str, Any]:
6830
key_info: str = backend.get_key_for_task(task_id)
6931
info_as_str: str = backend.get(key_info)
@@ -103,6 +65,47 @@ def create_celery_service(
10365
else None,
10466
)
10567

68+
class ErrorHandlingTask(Task):
69+
"""Sets a 'custom' field with error in the Task response for FAILURE.
70+
71+
Used when services are ran via the Celery backend.
72+
"""
73+
74+
def after_return(
75+
self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo
76+
) -> None:
77+
"""Handler that ensures custom error response information is available whenever a Task fails.
78+
79+
Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value
80+
:param:`retval` is an `Exception`, this handler extracts information from the `Exception`
81+
and constructs a custom error response JSON value (see :func:`error_response` for details).
82+
83+
This handler then re-propagates the Celery-required exception information (`"exc_type"` and
84+
`"exc_message"`) while adding this new error response information under the `"custom"` key.
85+
"""
86+
if status == states.FAILURE and isinstance(retval, Exception):
87+
logger.warning(f"Setting custom error response for failed task {task_id}")
88+
89+
info: dict = raw_celery_response(self.backend, task_id)
90+
result: dict = info["result"]
91+
err: Exception = retval
92+
93+
error_payload = error_response("Internal failure", err)
94+
95+
# Inspired by pattern from:
96+
# https://www.distributedpython.com/2018/09/28/celery-task-states/
97+
self.update_state(
98+
state=states.FAILURE,
99+
meta={
100+
"exc_type": result["exc_type"],
101+
"exc_message": result["exc_message"],
102+
"custom": json.dumps(error_payload, indent=False),
103+
},
104+
)
105+
request_params = args[0]
106+
request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params)
107+
forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore
108+
106109
# See documentation for options:
107110
# https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options
108111
@app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True)

model-engine/model_engine_server/inference/forwarding/forwarding.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import sseclient
1010
import yaml
1111
from fastapi.responses import JSONResponse
12-
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
1312
from model_engine_server.core.loggers import logger_name, make_logger
1413
from model_engine_server.inference.common import get_endpoint_config
1514
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
@@ -126,7 +125,6 @@ class Forwarder(ModelEngineSerializationMixin):
126125
forward_http_status: bool
127126

128127
def __call__(self, json_payload: Any) -> Any:
129-
request_obj = EndpointPredictV1Request.parse_obj(json_payload)
130128
json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload)
131129
json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload
132130

@@ -163,8 +161,6 @@ def __call__(self, json_payload: Any) -> Any:
163161
if self.wrap_response:
164162
response = self.get_response_payload(using_serialize_results_as_string, response)
165163

166-
# TODO: we actually want to do this after we've returned the response.
167-
self.post_inference_hooks_handler.handle(request_obj, response)
168164
if self.forward_http_status:
169165
return JSONResponse(content=response, status_code=response_raw.status_code)
170166
else:

model-engine/model_engine_server/inference/forwarding/http_forwarder.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import subprocess
55
from functools import lru_cache
66

7-
from fastapi import Depends, FastAPI
7+
from fastapi import BackgroundTasks, Depends, FastAPI
88
from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter
99
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
1010
from model_engine_server.core.loggers import logger_name, make_logger
@@ -70,11 +70,20 @@ def load_streaming_forwarder():
7070
@app.post("/predict")
7171
def predict(
7272
request: EndpointPredictV1Request,
73+
background_tasks: BackgroundTasks,
7374
forwarder=Depends(load_forwarder),
7475
limiter=Depends(get_concurrency_limiter),
7576
):
7677
with limiter:
77-
return forwarder(request.dict())
78+
try:
79+
response = forwarder(request.dict())
80+
background_tasks.add_task(
81+
forwarder.post_inference_hooks_handler.handle, request, response
82+
)
83+
return response
84+
except Exception:
85+
logger.error(f"Failed to decode payload from: {request}")
86+
raise
7887

7988

8089
@app.post("/stream")

model-engine/model_engine_server/inference/post_inference_hooks.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import json
12
from abc import ABC, abstractmethod
2-
from typing import Any, Dict, List, Optional
3+
from typing import Any, Dict, List, Optional, Union
34

45
import requests
6+
from fastapi.responses import JSONResponse
57
from model_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK
68
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
79
from model_engine_server.core.loggers import logger_name, make_logger
@@ -108,13 +110,17 @@ def __init__(
108110
def handle(
109111
self,
110112
request_payload: EndpointPredictV1Request,
111-
response: Dict[str, Any],
113+
response: Union[Dict[str, Any], JSONResponse],
112114
task_id: Optional[str] = None,
113115
):
116+
if isinstance(response, JSONResponse):
117+
loaded_response = json.loads(response.body)
118+
else:
119+
loaded_response = response
114120
for hook_name, hook in self._hooks.items():
115121
self._monitoring_metrics_gateway.emit_attempted_post_inference_hook(hook_name)
116122
try:
117-
hook.handle(request_payload, response, task_id)
123+
hook.handle(request_payload, loaded_response, task_id) # pragma: no cover
118124
self._monitoring_metrics_gateway.emit_successful_post_inference_hook(hook_name)
119125
except Exception:
120126
logger.exception(f"Hook {hook_name} failed.")

model-engine/tests/unit/inference/test_http_forwarder.py

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,23 @@
11
import threading
2-
import time
2+
from dataclasses import dataclass
3+
from typing import Mapping
4+
from unittest import mock
35

46
import pytest
7+
from fastapi import BackgroundTasks
8+
from fastapi.responses import JSONResponse
59
from model_engine_server.common.dtos.tasks import EndpointPredictV1Request
10+
from model_engine_server.inference.forwarding.forwarding import Forwarder
611
from model_engine_server.inference.forwarding.http_forwarder import (
712
MultiprocessingConcurrencyLimiter,
813
predict,
914
)
15+
from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import (
16+
DatadogInferenceMonitoringMetricsGateway,
17+
)
18+
from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler
19+
20+
PAYLOAD: Mapping[str, str] = {"hello": "world"}
1021

1122

1223
class ExceptionCapturedThread(threading.Thread):
@@ -26,21 +37,90 @@ def join(self):
2637
raise self.ex
2738

2839

29-
def mock_forwarder(dict):
30-
time.sleep(1)
31-
return dict
40+
def mocked_get(*args, **kwargs): # noqa
41+
@dataclass
42+
class mocked_static_status_code:
43+
status_code: int = 200
44+
45+
return mocked_static_status_code()
46+
47+
48+
def mocked_post(*args, **kwargs): # noqa
49+
@dataclass
50+
class mocked_static_json:
51+
status_code: int = 200
52+
53+
def json(self) -> dict:
54+
return PAYLOAD # type: ignore
55+
56+
return mocked_static_json()
57+
58+
59+
@pytest.fixture
60+
def post_inference_hooks_handler():
61+
handler = PostInferenceHooksHandler(
62+
endpoint_name="test_endpoint_name",
63+
bundle_name="test_bundle_name",
64+
post_inference_hooks=[],
65+
user_id="test_user_id",
66+
billing_queue="billing_queue",
67+
billing_tags=[],
68+
default_callback_url=None,
69+
default_callback_auth=None,
70+
monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(),
71+
)
72+
return handler
73+
3274

75+
@pytest.fixture
76+
def mock_request():
77+
return EndpointPredictV1Request(
78+
url="test_url",
79+
return_pickled=False,
80+
args={"x": 1},
81+
)
3382

34-
def test_http_service_429():
83+
84+
@mock.patch("requests.post", mocked_post)
85+
@mock.patch("requests.get", mocked_get)
86+
def test_http_service_429(mock_request, post_inference_hooks_handler):
87+
mock_forwarder = Forwarder(
88+
"ignored",
89+
model_engine_unwrap=True,
90+
serialize_results_as_string=False,
91+
post_inference_hooks_handler=post_inference_hooks_handler,
92+
wrap_response=True,
93+
forward_http_status=True,
94+
)
3595
limiter = MultiprocessingConcurrencyLimiter(1, True)
3696
t1 = ExceptionCapturedThread(
37-
target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter)
97+
target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter)
3898
)
3999
t2 = ExceptionCapturedThread(
40-
target=predict, args=(EndpointPredictV1Request(), mock_forwarder, limiter)
100+
target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter)
41101
)
42102
t1.start()
43103
t2.start()
44104
t1.join()
45105
with pytest.raises(Exception): # 429 thrown
46106
t2.join()
107+
108+
109+
def test_handler_response(post_inference_hooks_handler):
110+
try:
111+
post_inference_hooks_handler.handle(
112+
request_payload=mock_request, response=PAYLOAD, task_id="test_task_id"
113+
)
114+
except Exception as e:
115+
pytest.fail(f"Unexpected exception: {e}")
116+
117+
118+
def test_handler_json_response(post_inference_hooks_handler):
119+
try:
120+
post_inference_hooks_handler.handle(
121+
request_payload=mock_request,
122+
response=JSONResponse(content=PAYLOAD),
123+
task_id="test_task_id",
124+
)
125+
except Exception as e:
126+
pytest.fail(f"Unexpected exception: {e}")

0 commit comments

Comments
 (0)