Skip to content

Commit 5f52085

Browse files
committed
Updated pr.
1 parent fced37d commit 5f52085

File tree

4 files changed

+14
-36
lines changed

4 files changed

+14
-36
lines changed

ads/llm/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
ChatOCIModelDeploymentTGI,
1313
ChatOCIModelDeploymentVLLM,
1414
)
15+
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
16+
OCIDataScienceEmbedding,
17+
)
1518
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
1619
OCIModelDeploymentTGI,
1720
OCIModelDeploymentVLLM,

ads/llm/langchain/plugins/embeddings/oci_data_science_model_deployment_endpoint.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88
import requests
99
from langchain_core.embeddings import Embeddings
1010
from langchain_core.language_models.llms import create_base_retry_decorator
11-
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
12-
from langchain_core.utils import get_from_dict_or_env
11+
from langchain_core.pydantic_v1 import BaseModel, Field
1312

1413
DEFAULT_HEADER = {
1514
"Content-Type": "application/json",
@@ -29,16 +28,16 @@ def _create_retry_decorator(llm) -> Callable[[Any], Any]:
2928
return decorator
3029

3130

32-
class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
31+
class OCIDataScienceEmbedding(BaseModel, Embeddings):
3332
"""Embedding model deployed on OCI Data Science Model Deployment.
3433
3534
Example:
3635
3736
.. code-block:: python
3837
39-
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings
38+
from ads.llm import OCIDataScienceEmbedding
4039
41-
embeddings = OCIModelDeploymentEndpointEmbeddings(
40+
embeddings = OCIDataScienceEmbedding(
4241
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
4342
)
4443
""" # noqa: E501
@@ -64,28 +63,6 @@ class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
6463
max_retries: int = 1
6564
"""The maximum number of retries to make when generating."""
6665

67-
@root_validator()
68-
def validate_environment( # pylint: disable=no-self-argument
69-
cls, values: Dict
70-
) -> Dict:
71-
"""Validate that python package exists in environment."""
72-
try:
73-
import ads
74-
75-
except ImportError as ex:
76-
raise ImportError(
77-
"Could not import ads python package. "
78-
"Please install it with `pip install oracle_ads`."
79-
) from ex
80-
if not values.get("auth"):
81-
values["auth"] = ads.common.auth.default_signer()
82-
values["endpoint"] = get_from_dict_or_env(
83-
values,
84-
"endpoint",
85-
"OCI_LLM_ENDPOINT",
86-
)
87-
return values
88-
8966
@property
9067
def _identifying_params(self) -> Mapping[str, Any]:
9168
"""Get the identifying parameters."""

docs/source/user_guide/large_language_model/langchain_models.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,11 +135,11 @@ You can also use embedding model that's hosted on a `OCI Data Science Model Depl
135135

136136
.. code-block:: python3
137137
138-
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings
138+
from ads.llm import OCIDataScienceEmbedding
139139
140140
# Create an instance of OCI Model Deployment Endpoint
141141
# Replace the endpoint uri with your own
142-
embeddings = OCIModelDeploymentEndpointEmbeddings(
142+
embeddings = OCIDataScienceEmbedding(
143143
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict",
144144
)
145145

tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,10 @@
77
"""Test OCI Data Science Model Deployment Endpoint."""
88

99
from unittest.mock import MagicMock, patch
10-
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
11-
OCIModelDeploymentEndpointEmbeddings,
12-
)
10+
from ads.llm import OCIDataScienceEmbedding
1311

1412

15-
@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry")
13+
@patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry")
1614
def test_embed_documents(mock_embed_with_retry) -> None:
1715
"""Test valid call to oci model deployment endpoint."""
1816
expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
@@ -26,15 +24,15 @@ def test_embed_documents(mock_embed_with_retry) -> None:
2624
endpoint = "https://MD_OCID/predict"
2725
documents = ["Hello", "World"]
2826

29-
embeddings = OCIModelDeploymentEndpointEmbeddings(
27+
embeddings = OCIDataScienceEmbedding(
3028
endpoint=endpoint,
3129
)
3230

3331
output = embeddings.embed_documents(documents)
3432
assert output == expected_output
3533

3634

37-
@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry")
35+
@patch("ads.llm.OCIDataScienceEmbedding._embed_with_retry")
3836
def test_embed_query(mock_embed_with_retry) -> None:
3937
"""Test valid call to oci model deployment endpoint."""
4038
expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
@@ -48,7 +46,7 @@ def test_embed_query(mock_embed_with_retry) -> None:
4846
endpoint = "https://MD_OCID/predict"
4947
query = "Hello world"
5048

51-
embeddings = OCIModelDeploymentEndpointEmbeddings(
49+
embeddings = OCIDataScienceEmbedding(
5250
endpoint=endpoint,
5351
)
5452

0 commit comments

Comments
 (0)