Skip to content

Commit 1f32089

Browse files
authored
fix(adserver): adserver returns cloudevents compatible response (#5348)
* modify tests to check that adserver returns CE-compatible responses * refactor server post handler to return CE-compatible responses
1 parent 33dc760 commit 1f32089

File tree

4 files changed

+100
-39
lines changed

4 files changed

+100
-39
lines changed

components/alibi-detect-server/adserver/base/storage.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import sys
33
import logging
44
import tempfile
5-
from distutils.util import strtobool
5+
from typing import Optional
66

77
ARTIFACT_DOWNLOAD_LOCATION = os.environ.get("DRIFT_ARTIFACTS_DIR", "/tmp")
88

@@ -18,13 +18,13 @@
1818

1919

2020
class Rclone:
21-
def __init__(self, cfg_file: str = None):
21+
def __init__(self, cfg_file: Optional[str] = None):
2222
self.cfg_file = cfg_file
2323

24-
def copy(self, src: str, dest: str = None):
24+
def copy(self, src: str, dest: Optional[str] = None):
2525
if rclone is None:
2626
raise RuntimeError(
27-
"rclone binary not found - rclone-based storage funcionality disabled"
27+
"rclone binary not found - rclone-based storage functionality disabled"
2828
)
2929

3030
if dest is None:

components/alibi-detect-server/adserver/cm_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
SELDON_PREDICTOR_ID = DEFAULT_LABELS["predictor_name"]
2222

2323

24-
def _load_class_module(module_path: str) -> str:
24+
def _load_class_module(module_path: str):
2525
components = module_path.split(".")
2626
mod = __import__(".".join(components[:-1]))
2727
for comp in components[1:]:
@@ -32,7 +32,7 @@ def _load_class_module(module_path: str) -> str:
3232

3333
class CustomMetricsModel(CEModel): # pylint:disable=c-extension-no-member
3434
def __init__(
35-
self, name: str, storage_uri: str, elasticsearch_uri: str = None, model=None
35+
self, name: str, storage_uri: str, elasticsearch_uri: Optional[str] = None, model=None
3636
):
3737
"""
3838
Custom Metrics Model

components/alibi-detect-server/adserver/server.py

Lines changed: 71 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def __init__(
3939
event_type: str,
4040
event_source: str,
4141
http_port: int = DEFAULT_HTTP_PORT,
42-
reply_url: str = None,
42+
reply_url: Optional[str] = None,
4343
):
4444
"""
4545
CloudEvents server
@@ -146,29 +146,21 @@ def get_request_handler(protocol, request: Dict) -> RequestHandler:
146146
raise Exception(f"Unknown protocol {protocol}")
147147

148148

149-
def sendCloudEvent(event: v1.Event, url: str):
149+
def forward_request(headers, data, url):
150150
"""
151-
Send CloudEvent
151+
Forward request
152152
153153
Parameters
154154
----------
155-
event
156-
CloudEvent to send
155+
headers
156+
Headers to forward
157+
data
158+
Data to forward
157159
url
158-
Url to send event
160+
Url to forward to
159161
160162
"""
161-
http_marshaller = marshaller.NewDefaultHTTPMarshaller()
162-
binary_headers, binary_data = http_marshaller.ToRequest(
163-
event, converters.TypeBinary, json.dumps
164-
)
165-
166-
logging.info("binary CloudEvent")
167-
for k, v in binary_headers.items():
168-
logging.info("{0}: {1}\r\n".format(k, v))
169-
logging.info(binary_data)
170-
171-
response = requests.post(url, headers=binary_headers, data=binary_data)
163+
response = requests.post(url, headers=headers, data=data)
172164
response.raise_for_status()
173165

174166

@@ -252,27 +244,73 @@ def post(self):
252244
else:
253245
logging.error("Metrics returned are invalid: " + str(runtime_metrics))
254246

255-
if response.data is not None:
247+
revent = create_cloud_event(
248+
response.data,
249+
self.event_type,
250+
self.event_source,
251+
event_id=event.EventID(),
252+
extensions=event.Extensions(),
253+
)
256254

255+
if response.data is not None:
257256
# Create event from response if reply_url is active
257+
revent_headers, revent_data = http_marshaller.ToRequest(
258+
revent, converters.TypeBinary, json.dumps
259+
)
260+
258261
if not self.reply_url == "":
259-
if event.EventID() is None or event.EventID() == "":
260-
resp_event_id = uuid.uuid1().hex
261-
else:
262-
resp_event_id = event.EventID()
263-
revent = (
264-
v1.Event()
265-
.SetContentType("application/json")
266-
.SetData(response.data)
267-
.SetEventID(resp_event_id)
268-
.SetSource(self.event_source)
269-
.SetEventType(self.event_type)
270-
.SetExtensions(event.Extensions())
271-
)
272262
logging.debug(json.dumps(revent.Properties()))
273-
sendCloudEvent(revent, self.reply_url)
274-
self.write(json.dumps(response.data))
263+
logging.info("binary CloudEvent")
264+
for k, v in revent_headers.items():
265+
logging.info("{0}: {1}\r\n".format(k, v))
266+
logging.info(revent_data)
267+
forward_request(revent_headers, revent_data, self.reply_url)
268+
269+
self.set_header("Content-Type", "application/json")
270+
for headers in revent_headers:
271+
self.set_header(headers, revent_headers[headers])
272+
self.write(revent_data)
273+
274+
275+
def create_cloud_event(
276+
data: dict,
277+
event_type: str,
278+
event_source: str,
279+
extensions: dict,
280+
event_id: str = None,
281+
) -> v1.Event:
282+
"""
283+
Create a CloudEvent
284+
285+
Parameters
286+
----------
287+
data
288+
The data to send
289+
event_type
290+
The CE event type
291+
event_source
292+
The CE event source
293+
extensions
294+
Any extensions to add
295+
event_id
296+
The event id
297+
Returns
298+
-------
299+
A CloudEvent
275300
301+
"""
302+
if event_id is None or event_id == "":
303+
event_id = uuid.uuid1().hex
304+
305+
event = (
306+
v1.Event()
307+
.SetData(data)
308+
.SetEventID(event_id if event_id else str(uuid.uuid1().hex))
309+
.SetSource(event_source)
310+
.SetEventType(event_type)
311+
.SetExtensions(extensions)
312+
)
313+
return event
276314

277315
class LivenessHandler(tornado.web.RequestHandler):
278316
def get(self):

components/alibi-detect-server/adserver/tests/test_server.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44
from typing import List, Dict, Optional, Union
55
import json
66
import requests_mock
7+
from cloudevents.sdk import converters
8+
from cloudevents.sdk import marshaller
9+
from cloudevents.sdk.event import v1
710

811

912
class TestProtocol(AsyncHTTPTestCase):
@@ -74,11 +77,31 @@ def test_basic(self):
7477
)
7578
self.assertEqual(response.code, 200)
7679
expectedResponse = DummyModel.getResponse().data
80+
# assert that the expected response conforms to the CloudEvent spec
81+
event = v1.Event()
82+
http_marshaller = marshaller.NewDefaultHTTPMarshaller()
83+
try:
84+
event = http_marshaller.FromRequest(
85+
event, response.headers, response.body, json.loads
86+
)
87+
except Exception as e:
88+
assert False, f"Failed to unmarshall data with error: {type(e).__name__}('{e}')"
89+
90+
# assert cloud event properties have been set correctly in response
91+
self.assertEqual(event.Data(), expectedResponse)
92+
self.assertEqual(event.Source(), self.eventSource)
93+
self.assertEqual(event.EventType(), self.eventType)
94+
self.assertEqual(event.ContentType(), "application/json")
95+
self.assertEqual(event.EventID(), "1234")
96+
self.assertEqual(event.CloudEventVersion(), "1.0")
7797
self.assertEqual(response.body.decode("utf-8"), json.dumps(expectedResponse))
98+
99+
# assert requests have been made with the correct headers and data
78100
self.assertEqual(m.request_history[0].json(), expectedResponse)
79101
headers: Dict = m.request_history[0]._request.headers
80102
self.assertEqual(headers["ce-source"], self.eventSource)
81103
self.assertEqual(headers["ce-type"], self.eventType)
104+
self.assertNotIn("ce-datacontenttype", headers)
82105

83106

84107
class TestKFservingV2HttpModel(AsyncHTTPTestCase):

0 commit comments

Comments
 (0)