Skip to content

Commit 0c17bef

Browse files
authored
AQUA. Adds HF Login handler. (#866)
2 parents 9d1a67b + e23bda8 commit 0c17bef

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

ads/aqua/extension/common_handler.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,16 @@
66

77
from importlib import metadata
88

9+
import huggingface_hub
910
import requests
11+
from tornado.web import HTTPError
1012

1113
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID
1214
from ads.aqua.common.decorator import handle_exceptions
1315
from ads.aqua.common.errors import AquaResourceAccessError
1416
from ads.aqua.common.utils import fetch_service_compartment, known_realm
1517
from ads.aqua.extension.base_handler import AquaAPIhandler
18+
from ads.aqua.extension.errors import Errors
1619

1720

1821
class ADSVersionHandler(AquaAPIhandler):
@@ -62,8 +65,40 @@ def get(self):
6265
return self.finish("success")
6366

6467

68+
class HFLoginHandler(AquaAPIhandler):
69+
"""Handler to login to HF."""
70+
71+
@handle_exceptions
72+
def post(self, *args, **kwargs):
73+
"""Handles post request for the HF login.
74+
75+
Raises
76+
------
77+
HTTPError
78+
Raises HTTPError if inputs are missing or are invalid.
79+
"""
80+
try:
81+
input_data = self.get_json_body()
82+
except Exception:
83+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
84+
85+
if not input_data:
86+
raise HTTPError(400, Errors.NO_INPUT_DATA)
87+
88+
token = input_data.get("token")
89+
90+
if not token:
91+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("token"))
92+
93+
# Login to HF
94+
huggingface_hub.login(token=token, new_session=False)
95+
96+
return self.finish("success")
97+
98+
6599
__handlers__ = [
66100
("ads_version", ADSVersionHandler),
67101
("hello", CompatibilityCheckHandler),
68102
("network_status", NetworkStatusHandler),
103+
("hf_login", HFLoginHandler),
69104
]

ads/aqua/extension/model_handler.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import re
7-
from typing import Optional, Tuple
7+
from typing import Optional
88
from urllib.parse import urlparse
99

1010
from huggingface_hub import HfApi
@@ -168,13 +168,14 @@ def post(self, *args, **kwargs):
168168
raise HTTPError(400, Errors.NO_INPUT_DATA)
169169

170170
model_id = input_data.get("model_id")
171+
token = input_data.get("token")
171172

172173
if not model_id:
173174
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("model_id"))
174175

175176
# Get model info from the HF
176177
try:
177-
hf_model_info = HfApi().model_info(model_id)
178+
hf_model_info = HfApi(token=token).model_info(model_id)
178179
except HfHubHTTPError as err:
179180
raise self._format_custom_error_message(err)
180181

0 commit comments

Comments
 (0)