Skip to content

Commit 3e7fd13

Browse files
authored
Improvements on handling error raised in AQUA
1 parent 58cb61f commit 3e7fd13

File tree

11 files changed

+266
-51
lines changed

11 files changed

+266
-51
lines changed

ads/aqua/decorator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
RequestException,
1818
ServiceError,
1919
)
20+
from tornado.web import HTTPError
2021

2122
from ads.aqua.exception import AquaError
2223
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -58,6 +59,7 @@ def inner_function(self: AquaAPIhandler, *args, **kwargs):
5859
except ServiceError as error:
5960
self.write_error(
6061
status_code=error.status or 500,
62+
message=error.message,
6163
reason=error.message,
6264
service_payload=error.args[0] if error.args else None,
6365
exc_info=sys.exc_info(),
@@ -91,6 +93,12 @@ def inner_function(self: AquaAPIhandler, *args, **kwargs):
9193
service_payload=error.service_payload,
9294
exc_info=sys.exc_info(),
9395
)
96+
except HTTPError as e:
97+
self.write_error(
98+
status_code=e.status_code,
99+
reason=e.log_message,
100+
exc_info=sys.exc_info(),
101+
)
94102
except Exception as ex:
95103
self.write_error(
96104
status_code=500,

ads/aqua/extension/base_handler.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@
88
import traceback
99
import uuid
1010
from dataclasses import asdict, is_dataclass
11+
from http.client import responses
1112
from typing import Any
1213

1314
from notebook.base.handlers import APIHandler
14-
from tornado.web import HTTPError, Application
1515
from tornado import httputil
16-
from ads.telemetry.client import TelemetryClient
17-
from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
16+
from tornado.web import Application, HTTPError
17+
1818
from ads.aqua import logger
19+
from ads.config import AQUA_TELEMETRY_BUCKET, AQUA_TELEMETRY_BUCKET_NS
20+
from ads.telemetry.client import TelemetryClient
1921

2022

2123
class AquaAPIhandler(APIHandler):
@@ -66,12 +68,15 @@ def finish(self, payload=None): # pylint: disable=W0221
6668

6769
def write_error(self, status_code, **kwargs):
6870
"""AquaAPIhandler errors are JSON, not human pages."""
69-
7071
self.set_header("Content-Type", "application/json")
7172
reason = kwargs.get("reason")
7273
self.set_status(status_code, reason=reason)
7374
service_payload = kwargs.get("service_payload", {})
74-
message = self.get_default_error_messages(service_payload, str(status_code))
75+
default_msg = responses.get(status_code, "Unknown HTTP Error")
76+
message = self.get_default_error_messages(
77+
service_payload, str(status_code), kwargs.get("message", default_msg)
78+
)
79+
7580
reply = {
7681
"status": status_code,
7782
"message": message,
@@ -84,7 +89,7 @@ def write_error(self, status_code, **kwargs):
8489
e = exc_info[1]
8590
if isinstance(e, HTTPError):
8691
reply["message"] = e.log_message or message
87-
reply["reason"] = e.reason
92+
reply["reason"] = e.reason if e.reason else reply["reason"]
8893
reply["request_id"] = str(uuid.uuid4())
8994
else:
9095
reply["request_id"] = str(uuid.uuid4())
@@ -102,15 +107,18 @@ def write_error(self, status_code, **kwargs):
102107
self.finish(json.dumps(reply))
103108

104109
@staticmethod
105-
def get_default_error_messages(service_payload: dict, status_code: str):
110+
def get_default_error_messages(
111+
service_payload: dict,
112+
status_code: str,
113+
default_msg: str = "Unknown HTTP Error.",
114+
):
106115
"""Method that maps the error messages based on the operation performed or the status codes encountered."""
107116

108117
messages = {
109118
"400": "Something went wrong with your request.",
110119
"403": "We're having trouble processing your request with the information provided.",
111120
"404": "Authorization Failed: The resource you're looking for isn't accessible.",
112121
"408": "Server is taking too long to response, please try again.",
113-
"500": "An error occurred while creating the resource.",
114122
"create": "Authorization Failed: Could not create resource.",
115123
"get": "Authorization Failed: The resource you're looking for isn't accessible.",
116124
}
@@ -128,7 +136,7 @@ def get_default_error_messages(service_payload: dict, status_code: str):
128136
if status_code in messages:
129137
return messages[status_code]
130138
else:
131-
return "Unknown HTTP Error."
139+
return default_msg
132140

133141

134142
# todo: remove after error handler is implemented

ads/aqua/extension/common_handler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,24 @@
66

77
from importlib import metadata
88

9-
from ads.aqua.extension.base_handler import AquaAPIhandler
109
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
10+
from ads.aqua.decorator import handle_exceptions
1111
from ads.aqua.exception import AquaResourceAccessError
12+
from ads.aqua.extension.base_handler import AquaAPIhandler
1213

1314

1415
class ADSVersionHandler(AquaAPIhandler):
1516
"""The handler to get the current version of the ADS."""
1617

18+
@handle_exceptions
1719
def get(self):
1820
self.finish({"data": metadata.version("oracle_ads")})
1921

2022

2123
class CompatibilityCheckHandler(AquaAPIhandler):
2224
"""The handler to check if the extension is compatible."""
2325

26+
@handle_exceptions
2427
def get(self):
2528
if ODSC_MODEL_COMPARTMENT_OCID:
2629
return self.finish(dict(status="ok"))

ads/aqua/extension/deployment_handler.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77

88
from tornado.web import HTTPError
99

10+
from ads.aqua.decorator import handle_exceptions
1011
from ads.aqua.deployment import AquaDeploymentApp, MDInferenceResponse, ModelParams
1112
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
1213
from ads.config import COMPARTMENT_OCID, PROJECT_OCID
13-
from ads.aqua.decorator import handle_exceptions
1414

1515

1616
class AquaDeploymentHandler(AquaAPIhandler):
@@ -110,12 +110,10 @@ def post(self, *args, **kwargs):
110110
)
111111
)
112112

113-
@handle_exceptions
114113
def read(self, id):
115114
"""Read the information of an Aqua model deployment."""
116115
return self.finish(AquaDeploymentApp().get(model_deployment_id=id))
117116

118-
@handle_exceptions
119117
def list(self):
120118
"""List Aqua models."""
121119
# If default is not specified,
@@ -129,7 +127,6 @@ def list(self):
129127
)
130128
)
131129

132-
@handle_exceptions
133130
def get_deployment_config(self, model_id):
134131
"""Gets the deployment config for Aqua model."""
135132
return self.finish(AquaDeploymentApp().get_deployment_config(model_id=model_id))

ads/aqua/extension/evaluation_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66
from urllib.parse import urlparse
77

8-
from requests import HTTPError
8+
from tornado.web import HTTPError
99

1010
from ads.aqua.decorator import handle_exceptions
1111
from ads.aqua.evaluation import AquaEvaluationApp, CreateAquaEvaluationDetails
12-
from ads.aqua.exception import AquaError
1312
from ads.aqua.extension.base_handler import AquaAPIhandler, Errors
1413
from ads.aqua.extension.utils import validate_function_parameters
1514
from ads.config import COMPARTMENT_OCID

ads/aqua/extension/finetune_handler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def post(self, *args, **kwargs):
5454

5555
self.finish(AquaFineTuningApp().create(CreateFineTuningDetails(**input_data)))
5656

57-
@handle_exceptions
5857
def get_finetuning_config(self, model_id):
5958
"""Gets the finetuning config for Aqua model."""
6059
return self.finish(AquaFineTuningApp().get_finetuning_config(model_id=model_id))

ads/aqua/extension/ui_handler.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class AquaUIHandler(AquaAPIhandler):
3434
HTTPError: For various failure scenarios such as invalid input format, missing data, etc.
3535
"""
3636

37+
@handle_exceptions
3738
def get(self, id=""):
3839
"""Handle GET request."""
3940
url_parse = urlparse(self.request.path)
@@ -76,30 +77,25 @@ def delete(self, id=""):
7677
else:
7778
raise HTTPError(400, f"The request {self.request.path} is invalid.")
7879

79-
@handle_exceptions
8080
def list_log_groups(self, **kwargs):
8181
"""Lists all log groups for the specified compartment or tenancy."""
8282
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
8383
return self.finish(
8484
AquaUIApp().list_log_groups(compartment_id=compartment_id, **kwargs)
8585
)
8686

87-
@handle_exceptions
8887
def list_logs(self, log_group_id: str, **kwargs):
8988
"""Lists the specified log group's log objects."""
9089
return self.finish(AquaUIApp().list_logs(log_group_id=log_group_id, **kwargs))
9190

92-
@handle_exceptions
9391
def list_compartments(self):
9492
"""Lists the compartments in a compartment specified by ODSC_MODEL_COMPARTMENT_OCID env variable."""
9593
return self.finish(AquaUIApp().list_compartments())
9694

97-
@handle_exceptions
9895
def get_default_compartment(self):
9996
"""Returns user compartment ocid."""
10097
return self.finish(AquaUIApp().get_default_compartment())
10198

102-
@handle_exceptions
10399
def list_model_version_sets(self, **kwargs):
104100
"""Lists all model version sets for the specified compartment or tenancy."""
105101

@@ -112,7 +108,6 @@ def list_model_version_sets(self, **kwargs):
112108
)
113109
)
114110

115-
@handle_exceptions
116111
def list_experiments(self, **kwargs):
117112
"""Lists all experiments for the specified compartment or tenancy."""
118113

@@ -125,7 +120,6 @@ def list_experiments(self, **kwargs):
125120
)
126121
)
127122

128-
@handle_exceptions
129123
def list_buckets(self, **kwargs):
130124
"""Lists all model version sets for the specified compartment or tenancy."""
131125
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
@@ -138,23 +132,20 @@ def list_buckets(self, **kwargs):
138132
)
139133
)
140134

141-
@handle_exceptions
142135
def list_job_shapes(self, **kwargs):
143136
"""Lists job shapes available in the specified compartment."""
144137
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
145138
return self.finish(
146139
AquaUIApp().list_job_shapes(compartment_id=compartment_id, **kwargs)
147140
)
148141

149-
@handle_exceptions
150142
def list_vcn(self, **kwargs):
151143
"""Lists the virtual cloud networks (VCNs) in the specified compartment."""
152144
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
153145
return self.finish(
154146
AquaUIApp().list_vcn(compartment_id=compartment_id, **kwargs)
155147
)
156148

157-
@handle_exceptions
158149
def list_subnets(self, **kwargs):
159150
"""Lists the subnets in the specified VCN and the specified compartment."""
160151
compartment_id = self.get_argument("compartment_id", default=COMPARTMENT_OCID)
@@ -165,7 +156,6 @@ def list_subnets(self, **kwargs):
165156
)
166157
)
167158

168-
@handle_exceptions
169159
def get_shape_availability(self, **kwargs):
170160
"""For a given compartmentId, resource limit name, and scope, returns the number of available resources associated
171161
with the given limit."""
@@ -178,7 +168,6 @@ def get_shape_availability(self, **kwargs):
178168
)
179169
)
180170

181-
@handle_exceptions
182171
def is_bucket_versioned(self):
183172
"""For a given compartmentId, resource limit name, and scope, returns the number of available resources associated
184173
with the given limit."""

ads/aqua/extension/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55
from dataclasses import fields
66
from typing import Dict, Optional
7-
from requests import HTTPError
7+
8+
from tornado.web import HTTPError
89

910
from ads.aqua.extension.base_handler import Errors
1011

1112

1213
def validate_function_parameters(data_class, input_data: Dict):
13-
"""Validates if the required parameters are provided in input data."""
14+
"""Validates if the required parameters are provided in input data."""
1415
required_parameters = [
15-
field.name for field in fields(data_class)
16-
if field.type != Optional[field.type]
16+
field.name for field in fields(data_class) if field.type != Optional[field.type]
1717
]
1818

1919
for required_parameter in required_parameters:

0 commit comments

Comments
 (0)