7
7
from typing import Optional
8
8
from urllib .parse import urlparse
9
9
10
- from huggingface_hub import HfApi
11
- from huggingface_hub .utils import (
12
- GatedRepoError ,
13
- HfHubHTTPError ,
14
- RepositoryNotFoundError ,
15
- RevisionNotFoundError ,
16
- )
17
10
from tornado .web import HTTPError
18
-
19
- from ads .aqua .common .decorator import handle_exceptions
20
- from ads .aqua .common .errors import AquaRuntimeError
21
11
from ads .aqua .extension .errors import Errors
12
+ from ads .aqua .common .decorator import handle_exceptions
22
13
from ads .aqua .extension .base_handler import AquaAPIhandler
23
14
from ads .aqua .model import AquaModelApp
24
- from ads .aqua .model .constants import ModelTask
25
- from ads .aqua .model .entities import AquaModelSummary , HFModelSummary
26
15
27
16
28
17
class AquaModelHandler (AquaAPIhandler ):
@@ -63,101 +52,14 @@ def list(self):
63
52
)
64
53
)
65
54
66
-
67
- class AquaModelLicenseHandler (AquaAPIhandler ):
68
- """Handler for Aqua Model license REST APIs."""
69
-
70
- @handle_exceptions
71
- def get (self , model_id ):
72
- """Handle GET request."""
73
-
74
- model_id = model_id .split ("/" )[0 ]
75
- return self .finish (AquaModelApp ().load_license (model_id ))
76
-
77
-
78
- class AquaHuggingFaceHandler (AquaAPIhandler ):
79
- """Handler for Aqua Hugging Face REST APIs."""
80
-
81
- def _find_matching_aqua_model (self , model_id : str ) -> Optional [AquaModelSummary ]:
82
- """
83
- Finds a matching model in AQUA based on the model ID from Hugging Face.
84
-
85
- Parameters
86
- ----------
87
- model_id (str): The Hugging Face model ID to match.
88
-
89
- Returns
90
- -------
91
- Optional[AquaModelSummary]
92
- Returns the matching AquaModelSummary object if found, else None.
93
- """
94
- # Convert the Hugging Face model ID to lowercase once
95
- model_id_lower = model_id .lower ()
96
-
97
- aqua_model_app = AquaModelApp ()
98
- model_ocid = aqua_model_app ._find_matching_aqua_model (model_id = model_id_lower )
99
- if model_ocid :
100
- return aqua_model_app .get (model_ocid , load_model_card = False )
101
-
102
- return None
103
-
104
- def _format_custom_error_message (self , error : HfHubHTTPError ) -> AquaRuntimeError :
105
- """
106
- Formats a custom error message based on the Hugging Face error response.
107
-
108
- Parameters
109
- ----------
110
- error (HfHubHTTPError): The caught exception.
111
-
112
- Raises
113
- ------
114
- AquaRuntimeError: A user-friendly error message.
115
- """
116
- # Extract the repository URL from the error message if present
117
- match = re .search (r"(https://huggingface.co/[^\s]+)" , str (error ))
118
- url = match .group (1 ) if match else "the requested Hugging Face URL."
119
-
120
- if isinstance (error , RepositoryNotFoundError ):
121
- raise AquaRuntimeError (
122
- reason = f"Failed to access `{ url } `. Please check if the provided repository name is correct. "
123
- "If the repo is private, make sure you are authenticated and have a valid HF token registered. "
124
- "To register your token, run this command in your terminal: `huggingface-cli login`" ,
125
- service_payload = {"error" : "RepositoryNotFoundError" },
126
- )
127
-
128
- if isinstance (error , GatedRepoError ):
129
- raise AquaRuntimeError (
130
- reason = f"Access denied to `{ url } ` "
131
- "This repository is gated. Access is restricted to authorized users. "
132
- "Please request access or check with the repository administrator. "
133
- "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
134
- "To register your token, run this command in your terminal: `huggingface-cli login`" ,
135
- service_payload = {"error" : "GatedRepoError" },
136
- )
137
-
138
- if isinstance (error , RevisionNotFoundError ):
139
- raise AquaRuntimeError (
140
- reason = f"The specified revision could not be found at `{ url } ` "
141
- "Please check the revision identifier and try again." ,
142
- service_payload = {"error" : "RevisionNotFoundError" },
143
- )
144
-
145
- raise AquaRuntimeError (
146
- reason = f"An error occurred while accessing `{ url } ` "
147
- "Please check your network connection and try again. "
148
- "If you are trying to access a gated repository, ensure you have a valid HF token registered. "
149
- "To register your token, run this command in your terminal: `huggingface-cli login`" ,
150
- service_payload = {"error" : "Error" },
151
- )
152
-
153
55
@handle_exceptions
154
56
def post (self , * args , ** kwargs ):
155
- """Handles post request for the HF Models APIs
156
-
57
+ """
58
+ Handles post request for the registering any Aqua model.
157
59
Raises
158
60
------
159
61
HTTPError
160
- Raises HTTPError if inputs are missing or are invalid.
62
+ Raises HTTPError if inputs are missing or are invalid
161
63
"""
162
64
try :
163
65
input_data = self .get_json_body ()
@@ -167,48 +69,41 @@ def post(self, *args, **kwargs):
167
69
if not input_data :
168
70
raise HTTPError (400 , Errors .NO_INPUT_DATA )
169
71
170
- model_id = input_data .get ("model_id" )
171
- token = input_data .get ("token" )
72
+ # required input parameters
73
+ model = input_data .get ("model" )
74
+ if not model :
75
+ raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("model" ))
76
+ os_path = input_data .get ("os_path" )
77
+ if not os_path :
78
+ raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("os_path" ))
172
79
173
- if not model_id :
174
- raise HTTPError (400 , Errors .MISSING_REQUIRED_PARAMETER .format ("model_id" ))
80
+ inference_container = input_data .get ("inference_container" )
81
+ finetuning_container = input_data .get ("finetuning_container" )
82
+ compartment_id = input_data .get ("compartment_id" )
83
+ project_id = input_data .get ("project_id" )
175
84
176
- # Get model info from the HF
177
- try :
178
- hf_model_info = HfApi (token = token ).model_info (model_id )
179
- except HfHubHTTPError as err :
180
- raise self ._format_custom_error_message (err )
181
-
182
- # Check if model is not disabled
183
- if hf_model_info .disabled :
184
- raise AquaRuntimeError (
185
- f"The chosen model '{ hf_model_info .id } ' is currently disabled and cannot be imported into AQUA. "
186
- "Please verify the model's status on the Hugging Face Model Hub or select a different model."
85
+ return self .finish (
86
+ AquaModelApp ().register (
87
+ model = model ,
88
+ os_path = os_path ,
89
+ inference_container = inference_container ,
90
+ finetuning_container = finetuning_container ,
91
+ compartment_id = compartment_id ,
92
+ project_id = project_id ,
187
93
)
94
+ )
188
95
189
- # Check pipeline_tag, it should be `text-generation`
190
- if (
191
- not hf_model_info .pipeline_tag
192
- or hf_model_info .pipeline_tag .lower () != ModelTask .TEXT_GENERATION
193
- ):
194
- raise AquaRuntimeError (
195
- f"Unsupported pipeline tag for the chosen model: '{ hf_model_info .pipeline_tag } '. "
196
- f"AQUA currently supports the following tasks only: { ', ' .join (ModelTask .values ())} . "
197
- "Please select a model with a compatible pipeline tag."
198
- )
199
96
200
- # Check if it is a service/verified model
201
- aqua_model_info : AquaModelSummary = self ._find_matching_aqua_model (
202
- model_id = hf_model_info .id
203
- )
97
+ class AquaModelLicenseHandler (AquaAPIhandler ):
98
+ """Handler for Aqua Model license REST APIs."""
204
99
205
- return self .finish (
206
- HFModelSummary (model_info = hf_model_info , aqua_model_info = aqua_model_info )
207
- )
100
+ @handle_exceptions
101
+ def get (self , model_id ):
102
+ """Handle GET request."""
103
+ return self .finish (AquaModelApp ().load_license (model_id ))
208
104
209
105
210
106
__handlers__ = [
211
107
("model/?([^/]*)" , AquaModelHandler ),
212
108
("model/?([^/]*)/license" , AquaModelLicenseHandler ),
213
- ("model/hf/search/?([^/]*)" , AquaHuggingFaceHandler ),
214
109
]
0 commit comments