Skip to content

Commit 19269a3

Browse files
committed
Merge branch 'main' into feature/ad_merlion
2 parents 4679082 + a37882c commit 19269a3

File tree

71 files changed

+19740
-1914
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

71 files changed

+19740
-1914
lines changed

THIRD_PARTY_LICENSES.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ langchain
157157
* Source code: https://github.com/langchain-ai/langchain
158158
* Project home: https://www.langchain.com/
159159

160+
langchain-community
161+
* Copyright (c) 2023 LangChain, Inc.
162+
* License: MIT license
163+
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/community
164+
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/community
165+
166+
langchain-openai
167+
* Copyright (c) 2023 LangChain, Inc.
168+
* License: MIT license
169+
* Source code: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
170+
* Project home: https://github.com/langchain-ai/langchain/tree/master/libs/partners/openai
171+
160172
lightgbm
161173
* Copyright (c) 2023 Microsoft Corporation
162174
* License: MIT license

ads/aqua/config/evaluation/evaluation_service_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def search_shapes(
224224

225225
class Config:
226226
extra = "ignore"
227+
protected_namespaces = ()
227228

228229

229230
class EvaluationServiceConfig(Serializable):

ads/aqua/extension/model_handler.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ads.aqua.extension.base_handler import AquaAPIhandler
1414
from ads.aqua.extension.errors import Errors
1515
from ads.aqua.model import AquaModelApp
16-
from ads.aqua.model.constants import ModelTask
1716
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
1817
from ads.aqua.ui import ModelFormat
1918

@@ -68,7 +67,7 @@ def read(self, model_id):
6867
return self.finish(AquaModelApp().get(model_id))
6968

7069
@handle_exceptions
71-
def delete(self):
70+
def delete(self, id=""):
7271
"""Handles DELETE request for clearing cache"""
7372
url_parse = urlparse(self.request.path)
7473
paths = url_parse.path.strip("/")
@@ -177,10 +176,8 @@ def _find_matching_aqua_model(model_id: str) -> Optional[AquaModelSummary]:
177176

178177
return None
179178

180-
181-
182179
@handle_exceptions
183-
def get(self,*args, **kwargs):
180+
def get(self, *args, **kwargs):
184181
"""
185182
Finds a list of matching models from hugging face based on query string provided from users.
186183
@@ -194,13 +191,11 @@ def get(self,*args, **kwargs):
194191
Returns the matching model ids string
195192
"""
196193

197-
query=self.get_argument("query",default=None)
194+
query = self.get_argument("query", default=None)
198195
if not query:
199-
raise HTTPError(400,Errors.MISSING_REQUIRED_PARAMETER.format("query"))
200-
models=list_hf_models(query)
201-
return self.finish({"models":models})
202-
203-
196+
raise HTTPError(400, Errors.MISSING_REQUIRED_PARAMETER.format("query"))
197+
models = list_hf_models(query)
198+
return self.finish({"models": models})
204199

205200
@handle_exceptions
206201
def post(self, *args, **kwargs):
@@ -234,16 +229,17 @@ def post(self, *args, **kwargs):
234229
"Please verify the model's status on the Hugging Face Model Hub or select a different model."
235230
)
236231

237-
# Check pipeline_tag, it should be `text-generation`
238-
if (
239-
not hf_model_info.pipeline_tag
240-
or hf_model_info.pipeline_tag.lower() != ModelTask.TEXT_GENERATION
241-
):
242-
raise AquaRuntimeError(
243-
f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
244-
f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
245-
"Please select a model with a compatible pipeline tag."
246-
)
232+
# Commented the validation below to let users to register any model type.
233+
# # Check pipeline_tag, it should be `text-generation`
234+
# if not (
235+
# hf_model_info.pipeline_tag
236+
# and hf_model_info.pipeline_tag.lower() in ModelTask
237+
# ):
238+
# raise AquaRuntimeError(
239+
# f"Unsupported pipeline tag for the chosen model: '{hf_model_info.pipeline_tag}'. "
240+
# f"AQUA currently supports the following tasks only: {', '.join(ModelTask.values())}. "
241+
# "Please select a model with a compatible pipeline tag."
242+
# )
247243

248244
# Check if it is a service/verified model
249245
aqua_model_info: AquaModelSummary = self._find_matching_aqua_model(

ads/aqua/model/constants.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -9,6 +8,7 @@
98
109
This module contains constants/enums used in Aqua Model.
1110
"""
11+
1212
from ads.common.extended_enum import ExtendedEnumMeta
1313

1414

@@ -21,6 +21,8 @@ class ModelCustomMetadataFields(str, metaclass=ExtendedEnumMeta):
2121

2222
class ModelTask(str, metaclass=ExtendedEnumMeta):
2323
TEXT_GENERATION = "text-generation"
24+
IMAGE_TEXT_TO_TEXT = "image-text-to-text"
25+
IMAGE_TO_TEXT = "image-to-text"
2426

2527

2628
class FineTuningMetricCategories(str, metaclass=ExtendedEnumMeta):

ads/llm/__init__.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,16 @@
66

77
try:
88
import langchain
9-
from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
10-
from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
11-
from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
12-
from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
9+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10+
OCIModelDeploymentVLLM,
11+
OCIModelDeploymentTGI,
12+
)
13+
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
14+
ChatOCIModelDeployment,
15+
ChatOCIModelDeploymentVLLM,
16+
ChatOCIModelDeploymentTGI,
17+
)
18+
from ads.llm.chat_template import ChatTemplates
1319
except ImportError as ex:
1420
if ex.name == "langchain":
1521
raise ImportError(

ads/llm/chat_template.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import os
9+
10+
11+
class ChatTemplates:
12+
"""Contains chat templates."""
13+
14+
@staticmethod
15+
def _read_template(filename):
16+
with open(
17+
os.path.join(os.path.dirname(__file__), "templates", filename),
18+
mode="r",
19+
encoding="utf-8",
20+
) as f:
21+
return f.read()
22+
23+
@staticmethod
24+
def mistral():
25+
"""Chat template for auto tool calling with Mistral model deploy with vLLM."""
26+
return ChatTemplates._read_template("tool_chat_template_mistral_parallel.jinja")
27+
28+
@staticmethod
29+
def hermes():
30+
"""Chat template for auto tool calling with Hermes model deploy with vLLM."""
31+
return ChatTemplates._read_template("tool_chat_template_hermes.jinja")

ads/llm/guardrails/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from typing import Any, List, Dict, Tuple
1515
from langchain.schema.prompt import PromptValue
1616
from langchain.tools.base import BaseTool, ToolException
17-
from langchain.pydantic_v1 import BaseModel, root_validator
17+
from pydantic import BaseModel, model_validator
1818

1919

2020
class RunInfo(BaseModel):
@@ -190,7 +190,8 @@ class Config:
190190
This is used by the ``apply_filter()`` method.
191191
"""
192192

193-
@root_validator
193+
@model_validator(mode="before")
194+
@classmethod
194195
def default_name(cls, values):
195196
"""Sets the default name of the guardrail."""
196197
if not values.get("name"):

ads/llm/guardrails/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
import evaluate
9-
from langchain.pydantic_v1 import root_validator
9+
from pydantic.v1 import root_validator
1010
from .base import Guardrail
1111

1212

ads/llm/langchain/plugins/base.py

Lines changed: 0 additions & 118 deletions
This file was deleted.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

0 commit comments

Comments
 (0)