Skip to content

Commit 8b37351

Browse files
Model creation from Object Storage (#871)
2 parents c2e07ac + 981ec42 commit 8b37351

File tree

9 files changed

+279
-893
lines changed

9 files changed

+279
-893
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,6 @@ htmllistparse
114114
* Source code: https://github.com/gumblex/htmllisting-parser
115115
* Project home: https://github.com/gumblex/htmllisting-parser
116116

117-
huggingface_hub
118-
* Copyright 2023-present, the HuggingFace Inc. team.
119-
* License: Apache-2.0 license
120-
* Source code: https://github.com/huggingface/huggingface_hub
121-
* Project home: https://github.com/huggingface/huggingface_hub
122-
123117
ibisframework
124118

125119
* Copyright 2015 Cloudera Inc.

ads/aqua/common/utils.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import os
1111
import random
1212
import re
13-
import shlex
14-
import subprocess
1513
from functools import wraps
1614
from pathlib import Path
1715
from string import Template
@@ -29,7 +27,7 @@
2927
)
3028
from ads.aqua.constants import *
3129
from ads.aqua.data import AquaResourceIdentifier
32-
from ads.common.auth import AuthState, default_signer
30+
from ads.common.auth import default_signer
3331
from ads.common.extended_enum import ExtendedEnumMeta
3432
from ads.common.object_storage_details import ObjectStorageDetails
3533
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
@@ -724,33 +722,6 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
724722
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
725723

726724

727-
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
728-
"""Upload the local folder to the object storage
729-
730-
Args:
731-
os_path (str): object storage URI with prefix. This is the path to upload
732-
local_dir (str): Local directory where the object is downloaded
733-
model_name (str): Name of the huggingface model
734-
Retuns:
735-
str: Object name inside the bucket
736-
"""
737-
os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
738-
if not os_details.is_bucket_versioned():
739-
raise ValueError(f"Version is not enabled at object storage location {os_path}")
740-
auth_state = AuthState()
741-
object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
742-
command = f"oci os object bulk-upload --src-dir {local_dir} --prefix {object_path} -bn {os_details.bucket} -ns {os_details.namespace} --auth {auth_state.oci_iam_type} --profile {auth_state.oci_key_profile} --no-overwrite"
743-
try:
744-
logger.info(f"Running: {command}")
745-
subprocess.check_call(shlex.split(command))
746-
except subprocess.CalledProcessError as e:
747-
logger.error(
748-
f"Error uploading the object. Exit code: {e.returncode} with error {e.stdout}"
749-
)
750-
751-
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
752-
753-
754725
def is_service_managed_container(container):
755726
return container and container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
756727

ads/aqua/extension/common_handler.py

Lines changed: 0 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,7 @@
66

77
from importlib import metadata
88

9-
import huggingface_hub
109
import requests
11-
from huggingface_hub import HfApi
12-
from huggingface_hub.utils import LocalTokenNotFoundError
1310
from tornado.web import HTTPError
1411

1512
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
@@ -58,66 +55,7 @@ def get(self):
5855
)
5956

6057

61-
class NetworkStatusHandler(AquaAPIhandler):
62-
"""Handler to check internet connection."""
63-
64-
@handle_exceptions
65-
def get(self):
66-
requests.get("https://huggingface.com", timeout=2)
67-
return self.finish("success")
68-
69-
70-
class HFLoginHandler(AquaAPIhandler):
71-
"""Handler to login to HF."""
72-
73-
@handle_exceptions
74-
def post(self, *args, **kwargs):
75-
"""Handles post request for the HF login.
76-
77-
Raises
78-
------
79-
HTTPError
80-
Raises HTTPError if inputs are missing or are invalid.
81-
"""
82-
try:
83-
input_data = self.get_json_body()
84-
except Exception:
85-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
86-
87-
if not input_data:
88-
raise HTTPError(400, Errors.NO_INPUT_DATA)
89-
90-
token = input_data.get("token")
91-
92-
if not token:
93-
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("token"))
94-
95-
# Login to HF
96-
huggingface_hub.login(token=token, new_session=False)
97-
98-
return self.finish("success")
99-
100-
101-
class HFUserStatusHandler(AquaAPIhandler):
102-
"""Handler to check if user logged in to the HF."""
103-
104-
@handle_exceptions
105-
def get(self):
106-
try:
107-
HfApi().whoami()
108-
except LocalTokenNotFoundError:
109-
raise AquaRuntimeError(
110-
"You are not logged in. Please log in to Hugging Face using the `huggingface-cli login` command."
111-
"See https://huggingface.co/settings/tokens.",
112-
)
113-
114-
return self.finish("success")
115-
116-
11758
__handlers__ = [
11859
("ads_version", ADSVersionHandler),
11960
("hello", CompatibilityCheckHandler),
120-
("network_status", NetworkStatusHandler),
121-
("hf_login", HFLoginHandler),
122-
("hf_logged_in", HFUserStatusHandler),
12361
]

ads/aqua/extension/model_handler.py

Lines changed: 30 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,11 @@
77
from typing import Optional
88
from urllib.parse import urlparse
99

10-
from huggingface_hub import HfApi
11-
from huggingface_hub.utils import (
12-
GatedRepoError,
13-
HfHubHTTPError,
14-
RepositoryNotFoundError,
15-
RevisionNotFoundError,
16-
)
1710
from tornado.web import HTTPError
18-
19-
from ads.aqua.common.decorator import handle_exceptions
20-
from ads.aqua.common.errors import AquaRuntimeError
2111
from ads.aqua.extension.errors import Errors
12+
from ads.aqua.common.decorator import handle_exceptions
2213
from ads.aqua.extension.base_handler import AquaAPIhandler
2314
from ads.aqua.model import AquaModelApp
24-
from ads.aqua.model.constants import ModelTask
25-
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
2615

2716

2817
class AquaModelHandler(AquaAPIhandler):
@@ -63,101 +52,14 @@ def list(self):
6352
)
6453
)
6554

66-
67-
class AquaModelLicenseHandler(AquaAPIhandler):
68-
"""Handler for Aqua Model license REST APIs."""
69-
70-
@handle_exceptions
71-
def get(self, model_id):
72-
"""Handle GET request."""
73-
74-
model_id = model_id.split("/")[0]
75-
return self.finish(AquaModelApp().load_license(model_id))
76-
77-
78-
class AquaHuggingFaceHandler(AquaAPIhandler):
79-
"""Handler for Aqua Hugging Face REST APIs."""
80-
81-
def _find_matching_aqua_model(self, model_id: str) -> Optional[AquaModelSummary]:
82-
"""
83-
Finds a matching model in AQUA based on the model ID from Hugging Face.
84-
85-
Parameters
86-
----------
87-
model_id (str): The Hugging Face model ID to match.
88-
89-
Returns
90-
-------
91-
Optional[AquaModelSummary]
92-
Returns the matching AquaModelSummary object if found, else None.
93-
"""
94-
# Convert the Hugging Face model ID to lowercase once
95-
model_id_lower = model_id.lower()
96-
97-
aqua_model_app = AquaModelApp()
98-
model_ocid = aqua_model_app._find_matching_aqua_model(model_id=model_id_lower)
99-
if model_ocid:
100-
return aqua_model_app.get(model_ocid, load_model_card=False)
101-
102-
return None
103-
104-
def _format_custom_error_message(self, error: HfHubHTTPError) -> AquaRuntimeError:
105-
"""
106-
Formats a custom error message based on the Hugging Face error response.
107-
108-
Parameters
109-
----------
110-
error (HfHubHTTPError): The caught exception.
111-
112-
Raises
113-
------
114-
AquaRuntimeError: A user-friendly error message.
115-
"""
116-
# Extract the repository URL from the error message if present
117-
match = re.search(r"(https://huggingface.co/[^\s]+)", str(error))
118-
url = match.group(1) if match else "the requested Hugging Face URL."
119-
120-
if isinstance(error, RepositoryNotFoundError):
121-
raise AquaRuntimeError(
122-
reason=f"Failed to access `{url}`. Please check if the provided repository name is correct. "
123-
"If the repo is private, make sure you are authenticated and have a valid HF token registered. "
124-
"To register your token, run this command in your terminal: `huggingface-cli login`",
125-
service_payload={"error": "RepositoryNotFoundError"},
126-
)
127-
128-
if isinstance(error, GatedRepoError):
129-
raise AquaRuntimeError(
130-
reason=f"Access denied to `{url}` "
131-
"This repository is gated. Access is restricted to authorized users. "
132-
"Please request access or check with the repository administrator. "
133-
"If you are trying to access a gated repository, ensure you have a valid HF token registered. "
134-
"To register your token, run this command in your terminal: `huggingface-cli login`",
135-
service_payload={"error": "GatedRepoError"},
136-
)
137-
138-
if isinstance(error, RevisionNotFoundError):
139-
raise AquaRuntimeError(
140-
reason=f"The specified revision could not be found at `{url}` "
141-
"Please check the revision identifier and try again.",
142-
service_payload={"error": "RevisionNotFoundError"},
143-
)
144-
145-
raise AquaRuntimeError(
146-
reason=f"An error occurred while accessing `{url}` "
147-
"Please check your network connection and try again. "
148-
"If you are trying to access a gated repository, ensure you have a valid HF token registered. "
149-
"To register your token, run this command in your terminal: `huggingface-cli login`",
150-
service_payload={"error": "Error"},
151-
)
152-
15355
@handle_exceptions
15456
def post(self, *args, **kwargs):
155-
"""Handles post request for the HF Models APIs
156-
57+
"""
58+
Handles post request for the registering any Aqua model.
15759
Raises
15860
------
15961
HTTPError
160-
Raises HTTPError if inputs are missing or are invalid.
62+
Raises HTTPError if inputs are missing or are invalid
16163
"""
16264
try:
16365
input_data = self.get_json_body()
@@ -167,48 +69,41 @@ def post(self, *args, **kwargs):
16769
if not input_data:
16870
raise HTTPError(400, Errors.NO_INPUT_DATA)
16971

170-
model_id = input_data.get("model_id")
171-
token = input_data.get("token")
72+
# required input parameters
73+
model = input_data.get("model")
74+
if not model:
75+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model"))
76+
os_path = input_data.get("os_path")
77+
if not os_path:
78+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("os_path"))
17279

173-
if not model_id:
174-
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
80+
inference_container = input_data.get("inference_container")
81+
finetuning_container = input_data.get("finetuning_container")
82+
compartment_id = input_data.get("compartment_id")
83+
project_id = input_data.get("project_id")
17584

176-
# Get model info from the HF
177-
try:
178-
hf_model_info = HfApi(token=token).model_info(model_id)
179-
except HfHubHTTPError as err:
180-
raise self._format_custom_error_message(err)
181-
182-
# Check if model is not disabled
183-
if hf_model_info.disabled:
184-
raise AquaRuntimeError(
185-
f"The chosen model '{hf_model_info.id}' is currently disabled and cannot be imported into AQUA. "
186-
"Please verify the model's status on the Hugging Face Model Hub or select a different model."
85+
return self.finish(
86+
AquaModelApp().register(
87+
model=model,
88+
os_path=os_path,
89+
inference_container=inference_container,
90+
finetuning_container=finetuning_container,
91+
compartment_id=compartment_id,
92+
project_id=project_id,
18793
)
94+
)
18895

189-
# Check pipeline_tag, it should be `text-generation`
190-
if (
191-
not hf_model_info.pipeline_tag
192-
or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
193-
):
194-
raise AquaRuntimeError(
195-
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
196-
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
197-
"Please select a model with a compatible pipeline tag."
198-
)
19996

200-
# Check if it is a service/verified model
201-
aqua_model_info: AquaModelSummary = self._find_matching_aqua_model(
202-
model_id=hf_model_info.id
203-
)
97+
class AquaModelLicenseHandler(AquaAPIhandler):
98+
"""Handler for Aqua Model license REST APIs."""
20499

205-
return self.finish(
206-
HFModelSummary(model_info=hf_model_info, aqua_model_info=aqua_model_info)
207-
)
100+
@handle_exceptions
101+
def get(self, model_id):
102+
"""Handle GET request."""
103+
return self.finish(AquaModelApp().load_license(model_id))
208104

209105

210106
__handlers__ = [
211107
("model/?([^/]*)", AquaModelHandler),
212108
("model/?([^/]*)/license", AquaModelLicenseHandler),
213-
("model/hf/search/?([^/]*)", AquaHuggingFaceHandler),
214109
]

ads/aqua/model/entities.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from typing import List, Optional
1515

1616
import oci
17-
from huggingface_hub import hf_api
1817

1918
from ads.aqua import logger
2019
from ads.aqua.app import CLIBuilderMixin
@@ -97,14 +96,6 @@ class HFModelContainerInfo:
9796
finetuning_container: str = None
9897

9998

100-
@dataclass(repr=False)
101-
class HFModelSummary:
102-
"""Represents a summary of Hugging Face model."""
103-
104-
model_info: hf_api.ModelInfo = field(default_factory=hf_api.ModelInfo)
105-
aqua_model_info: Optional[AquaModel] = field(default_factory=AquaModel)
106-
107-
10899
@dataclass(repr=False)
109100
class AquaEvalFTCommon(DataClassSerializable):
110101
"""Represents common fields for evaluation and fine-tuning."""
@@ -266,11 +257,8 @@ def _extract_job_lifecycle_details(self, lifecycle_details):
266257
class ImportModelDetails(CLIBuilderMixin):
267258
model: str
268259
os_path: str
269-
local_dir: Optional[str] = None
270260
inference_container: Optional[str] = None
271-
inference_container_type_smc: Optional[bool] = False
272261
finetuning_container: Optional[str] = None
273-
finetuning_container_type_smc: Optional[bool] = False
274262
compartment_id: Optional[str] = None
275263
project_id: Optional[str] = None
276264

0 commit comments

Comments
 (0)