Skip to content

Commit ac40648

Browse files
authored
Use object storage client for telemetry. (#1033)
2 parents ef20d9a + 2de141a commit ac40648

File tree

3 files changed

+93
-56
lines changed

3 files changed

+93
-56
lines changed

ads/telemetry/base.py

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
import logging
76

8-
from ads import set_auth
7+
import oci
8+
99
from ads.common import oci_client as oc
10-
from ads.common.auth import default_signer
10+
from ads.common.auth import default_signer, resource_principal
1111
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
1212

13-
1413
logger = logging.getLogger(__name__)
14+
15+
1516
class TelemetryBase:
1617
"""Base class for Telemetry Client."""
1718

@@ -25,15 +26,21 @@ def __init__(self, bucket: str, namespace: str = None) -> None:
2526
namespace : str, optional
2627
Namespace of the OCI object storage bucket, by default None.
2728
"""
29+
# Use resource principal as authentication method if available,
30+
# however, do not change the ADS authentication if user configured it by set_auth.
2831
if OCI_RESOURCE_PRINCIPAL_VERSION:
29-
set_auth("resource_principal")
30-
self._auth = default_signer()
31-
self.os_client = oc.OCIClientFactory(**self._auth).object_storage
32+
self._auth = resource_principal()
33+
else:
34+
self._auth = default_signer()
35+
self.os_client: oci.object_storage.ObjectStorageClient = oc.OCIClientFactory(
36+
**self._auth
37+
).object_storage
3238
self.bucket = bucket
3339
self._namespace = namespace
3440
self._service_endpoint = None
35-
logger.debug(f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}")
36-
41+
logger.debug(
42+
f"Initialized Telemetry. Namespace: {self.namespace}, Bucket: {self.bucket}"
43+
)
3744

3845
@property
3946
def namespace(self) -> str:
@@ -58,5 +65,5 @@ def service_endpoint(self):
5865
Tenancy-specific endpoint.
5966
"""
6067
if not self._service_endpoint:
61-
self._service_endpoint = self.os_client.base_client.endpoint
68+
self._service_endpoint = str(self.os_client.base_client.endpoint)
6269
return self._service_endpoint

ads/telemetry/client.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65

76
import logging
87
import threading
8+
import traceback
99
import urllib.parse
10-
import requests
11-
from requests import Response
12-
from .base import TelemetryBase
10+
from typing import Optional
11+
12+
import oci
13+
1314
from ads.config import DEBUG_TELEMETRY
1415

16+
from .base import TelemetryBase
1517

1618
logger = logging.getLogger(__name__)
1719

@@ -32,7 +34,7 @@ class TelemetryClient(TelemetryBase):
3234
>>> import traceback
3335
>>> from ads.telemetry.client import TelemetryClient
3436
>>> AQUA_BUCKET = os.environ.get("AQUA_BUCKET", "service-managed-models")
35-
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "ociodscdev")
37+
>>> AQUA_BUCKET_NS = os.environ.get("AQUA_BUCKET_NS", "namespace")
3638
>>> telemetry = TelemetryClient(bucket=AQUA_BUCKET, namespace=AQUA_BUCKET_NS)
3739
>>> telemetry.record_event_async(category="aqua/service/model", action="create") # records create action
3840
>>> telemetry.record_event_async(category="aqua/service/model/create", action="shape", detail="VM.GPU.A10.1")
@@ -45,7 +47,7 @@ def _encode_user_agent(**kwargs):
4547

4648
def record_event(
4749
self, category: str = None, action: str = None, detail: str = None, **kwargs
48-
) -> Response:
50+
) -> Optional[int]:
4951
"""Send a head request to generate an event record.
5052
5153
Parameters
@@ -62,23 +64,41 @@ def record_event(
6264
6365
Returns
6466
-------
65-
Response
67+
int
68+
The status code for the telemetry request.
69+
200: The the object exists for the telemetry request
70+
404: The the object does not exist for the telemetry request.
71+
Note that for telemetry purpose, the object does not need to be exist.
72+
`None` will be returned if the telemetry request failed.
6673
"""
6774
try:
6875
if not category or not action:
6976
raise ValueError("Please specify the category and the action.")
7077
if detail:
7178
category, action = f"{category}/{action}", detail
79+
# Here `endpoint`` is for debugging purpose
80+
# For some federated/domain users, the `endpoint` may not be a valid URL
7281
endpoint = f"{self.service_endpoint}/n/{self.namespace}/b/{self.bucket}/o/telemetry/{category}/{action}"
73-
headers = {"User-Agent": self._encode_user_agent(**kwargs)}
7482
logger.debug(f"Sending telemetry to endpoint: {endpoint}")
75-
signer = self._auth["signer"]
76-
response = requests.head(endpoint, auth=signer, headers=headers)
77-
logger.debug(f"Telemetry status code: {response.status_code}")
78-
return response
83+
84+
self.os_client.base_client.user_agent = self._encode_user_agent(**kwargs)
85+
try:
86+
response: oci.response.Response = self.os_client.head_object(
87+
namespace_name=self.namespace,
88+
bucket_name=self.bucket,
89+
object_name=f"telemetry/{category}/{action}",
90+
)
91+
logger.debug(f"Telemetry status: {response.status}")
92+
return response.status
93+
except oci.exceptions.ServiceError as ex:
94+
if ex.status == 404:
95+
return ex.status
96+
raise ex
7997
except Exception as e:
8098
if DEBUG_TELEMETRY:
8199
logger.error(f"There is an error recording telemetry: {e}")
100+
traceback.print_exc()
101+
return None
82102

83103
def record_event_async(
84104
self, category: str = None, action: str = None, detail: str = None, **kwargs
Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,59 +1,69 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4+
from unittest.mock import patch
55

6-
7-
from unittest.mock import patch, PropertyMock
6+
import oci
87

98
from ads.telemetry.client import TelemetryClient
109

11-
class TestTelemetryClient:
12-
"""Contains unittests for TelemetryClient."""
10+
TEST_CONFIG = {
11+
"tenancy": "ocid1.tenancy.oc1..unique_ocid",
12+
"user": "ocid1.user.oc1..unique_ocid",
13+
"fingerprint": "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00",
14+
"key_file": "<path>/<to>/<key_file>",
15+
"region": "test_region",
16+
}
1317

14-
endpoint = "https://objectstorage.us-ashburn-1.oraclecloud.com"
18+
EXPECTED_ENDPOINT = "https://objectstorage.test_region.oraclecloud.com"
1519

16-
def mocked_requests_head(*args, **kwargs):
17-
class MockResponse:
18-
def __init__(self, status_code):
19-
self.status_code = status_code
2020

21-
return MockResponse(200)
21+
class TestTelemetryClient:
22+
"""Contains unittests for TelemetryClient."""
2223

23-
@patch('requests.head', side_effect=mocked_requests_head)
24-
@patch('ads.telemetry.client.TelemetryClient.service_endpoint', new_callable=PropertyMock,
25-
return_value=endpoint)
26-
def test_telemetry_client_record_event(self, mock_endpoint, mock_head):
27-
"""Tests TelemetryClient.record_event() with category/action and path, respectively.
28-
"""
24+
@patch("oci.base_client.BaseClient.request")
25+
@patch("oci.signer.Signer")
26+
def test_telemetry_client_record_event(self, signer, request_call):
27+
"""Tests TelemetryClient.record_event() with category/action and path, respectively."""
2928
data = {
3029
"cmd": "ads aqua model list",
3130
"category": "aqua/service/model",
3231
"action": "list",
3332
"bucket": "test_bucket",
3433
"namespace": "test_namespace",
35-
"value": {
36-
"keyword": "test_service_model_name_or_id"
37-
}
34+
"value": {"keyword": "test_service_model_name_or_id"},
3835
}
3936
category = data["category"]
4037
action = data["action"]
4138
bucket = data["bucket"]
4239
namespace = data["namespace"]
4340
value = data["value"]
44-
expected_endpoint = f"{self.endpoint}/n/{namespace}/b/{bucket}/o/telemetry/{category}/{action}"
4541

46-
telemetry = TelemetryClient(bucket=bucket, namespace=namespace)
42+
with patch("oci.config.from_file", return_value=TEST_CONFIG):
43+
telemetry = TelemetryClient(bucket=bucket, namespace=namespace)
4744
telemetry.record_event(category=category, action=action)
4845
telemetry.record_event(category=category, action=action, **value)
4946

50-
expected_headers = [
51-
{'User-Agent': ''},
52-
{'User-Agent': 'keyword=test_service_model_name_or_id'}
47+
expected_agent_headers = [
48+
"",
49+
"keyword=test_service_model_name_or_id",
5350
]
54-
i = 0
55-
for call_args in mock_head.call_args_list:
56-
args, kwargs = call_args
57-
assert all(endpoint == expected_endpoint for endpoint in args)
58-
assert kwargs['headers'] == expected_headers[i]
59-
i += 1
51+
52+
assert len(request_call.call_args_list) == 2
53+
expected_url = f"{EXPECTED_ENDPOINT}/n/{namespace}/b/{bucket}/o/telemetry/{category}/{action}"
54+
55+
# Event #1, no user-agent
56+
args, _ = request_call.call_args_list[0]
57+
request: oci.request.Request = args[0]
58+
operation = args[2]
59+
assert request.url == expected_url
60+
assert operation == "head_object"
61+
assert request.header_params["user-agent"] == expected_agent_headers[0]
62+
63+
# Event #2, with user-agent
64+
args, _ = request_call.call_args_list[1]
65+
request: oci.request.Request = args[0]
66+
operation = args[2]
67+
assert request.url == expected_url
68+
assert operation == "head_object"
69+
assert request.header_params["user-agent"] == expected_agent_headers[1]

0 commit comments

Comments
 (0)