Skip to content

Commit ed0e207

Browse files
committed
Updated pr.
1 parent b643faa commit ed0e207

File tree

6 files changed

+282
-8
lines changed

6 files changed

+282
-8
lines changed

ads/llm/__init__.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*--
32

4-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
54
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
65

76
try:
87
import langchain
9-
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
10-
OCIModelDeploymentVLLM,
11-
OCIModelDeploymentTGI,
12-
)
8+
9+
from ads.llm.chat_template import ChatTemplates
1310
from ads.llm.langchain.plugins.chat_models.oci_data_science import (
1411
ChatOCIModelDeployment,
15-
ChatOCIModelDeploymentVLLM,
1612
ChatOCIModelDeploymentTGI,
13+
ChatOCIModelDeploymentVLLM,
14+
)
15+
from ads.llm.langchain.plugins.embeddings.oci_data_science_model_deployment_endpoint import (
16+
OCIModelDeploymentEndpointEmbeddings,
17+
)
18+
from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
19+
OCIModelDeploymentTGI,
20+
OCIModelDeploymentVLLM,
1721
)
18-
from ads.llm.chat_template import ChatTemplates
1922
except ImportError as ex:
2023
if ex.name == "langchain":
2124
raise ImportError(
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
from typing import Any, Callable, Dict, List, Mapping, Optional
7+
8+
import requests
9+
from langchain_core.embeddings import Embeddings
10+
from langchain_core.language_models.llms import create_base_retry_decorator
11+
from langchain_core.utils import get_from_dict_or_env
12+
from pydantic import BaseModel, Field, model_validator
13+
14+
DEFAULT_HEADER = {
15+
"Content-Type": "application/json",
16+
}
17+
18+
19+
class TokenExpiredError(Exception):
20+
pass
21+
22+
23+
def _create_retry_decorator(llm) -> Callable[[Any], Any]:
24+
"""Creates a retry decorator."""
25+
errors = [requests.exceptions.ConnectTimeout, TokenExpiredError]
26+
decorator = create_base_retry_decorator(
27+
error_types=errors, max_retries=llm.max_retries
28+
)
29+
return decorator
30+
31+
32+
class OCIModelDeploymentEndpointEmbeddings(BaseModel, Embeddings):
33+
"""Embedding model deployed on OCI Data Science Model Deployment.
34+
35+
Example:
36+
37+
.. code-block:: python
38+
39+
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings
40+
41+
embeddings = OCIModelDeploymentEndpointEmbeddings(
42+
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<md_ocid>/predict",
43+
)
44+
""" # noqa: E501
45+
46+
auth: dict = Field(default_factory=dict, exclude=True)
47+
"""ADS auth dictionary for OCI authentication:
48+
https://accelerated-data-science.readthedocs.io/en/latest/user_guide/cli/authentication.html.
49+
This can be generated by calling `ads.common.auth.api_keys()`
50+
or `ads.common.auth.resource_principal()`. If this is not
51+
provided then the `ads.common.default_signer()` will be used."""
52+
53+
endpoint: str = ""
54+
"""The uri of the endpoint from the deployed Model Deployment model."""
55+
56+
model_kwargs: Optional[Dict] = None
57+
"""Keyword arguments to pass to the model."""
58+
59+
endpoint_kwargs: Optional[Dict] = None
60+
"""Optional attributes (except for headers) passed to the request.post
61+
function.
62+
"""
63+
64+
max_retries: int = 1
65+
"""The maximum number of retries to make when generating."""
66+
67+
@model_validator(mode="before")
68+
def validate_environment( # pylint: disable=no-self-argument
69+
cls, values: Dict
70+
) -> Dict:
71+
"""Validate that python package exists in environment."""
72+
try:
73+
import ads
74+
75+
except ImportError as ex:
76+
raise ImportError(
77+
"Could not import ads python package. "
78+
"Please install it with `pip install oracle_ads`."
79+
) from ex
80+
if not values.get("auth"):
81+
values["auth"] = ads.common.auth.default_signer()
82+
values["endpoint"] = get_from_dict_or_env(
83+
values,
84+
"endpoint",
85+
"OCI_LLM_ENDPOINT",
86+
)
87+
return values
88+
89+
@property
90+
def _identifying_params(self) -> Mapping[str, Any]:
91+
"""Get the identifying parameters."""
92+
_model_kwargs = self.model_kwargs or {}
93+
return {
94+
**{"endpoint": self.endpoint},
95+
**{"model_kwargs": _model_kwargs},
96+
}
97+
98+
def _embed_with_retry(self, **kwargs) -> Any:
99+
"""Use tenacity to retry the call."""
100+
retry_decorator = _create_retry_decorator(self)
101+
102+
@retry_decorator
103+
def _completion_with_retry(**kwargs: Any) -> Any:
104+
try:
105+
response = requests.post(self.endpoint, **kwargs)
106+
response.raise_for_status()
107+
return response
108+
except requests.exceptions.HTTPError as http_err:
109+
if response.status_code == 401 and self._refresh_signer():
110+
raise TokenExpiredError() from http_err
111+
else:
112+
raise ValueError(
113+
f"Server error: {str(http_err)}. Message: {response.text}"
114+
) from http_err
115+
except Exception as e:
116+
raise ValueError(f"Error occurs by inference endpoint: {str(e)}") from e
117+
118+
return _completion_with_retry(**kwargs)
119+
120+
def _embedding(self, texts: List[str]) -> List[List[float]]:
121+
"""Call out to OCI Data Science Model Deployment Endpoint.
122+
123+
Args:
124+
texts: A list of texts to embed.
125+
126+
Returns:
127+
A list of list of floats representing the embeddings, or None if an
128+
error occurs.
129+
"""
130+
_model_kwargs = self.model_kwargs or {}
131+
body = self._construct_request_body(texts, _model_kwargs)
132+
request_kwargs = self._construct_request_kwargs(body)
133+
response = self._embed_with_retry(**request_kwargs)
134+
return self._proceses_response(response)
135+
136+
def _construct_request_kwargs(self, body: Any) -> dict:
137+
"""Constructs the request kwargs as a dictionary."""
138+
from ads.model.common.utils import _is_json_serializable
139+
140+
_endpoint_kwargs = self.endpoint_kwargs or {}
141+
headers = _endpoint_kwargs.pop("headers", DEFAULT_HEADER)
142+
return (
143+
dict(
144+
headers=headers,
145+
json=body,
146+
auth=self.auth.get("signer"),
147+
**_endpoint_kwargs,
148+
)
149+
if _is_json_serializable(body)
150+
else dict(
151+
headers=headers,
152+
data=body,
153+
auth=self.auth.get("signer"),
154+
**_endpoint_kwargs,
155+
)
156+
)
157+
158+
def _construct_request_body(self, texts: List[str], params: dict) -> Any:
159+
"""Constructs the request body."""
160+
return {"input": texts}
161+
162+
def _proceses_response(self, response: requests.Response) -> List[List[float]]:
163+
"""Extracts results from requests.Response."""
164+
try:
165+
res_json = response.json()
166+
embeddings = res_json["data"][0]["embedding"]
167+
except Exception as e:
168+
raise ValueError(
169+
f"Error raised by inference API: {e}.\nResponse: {response.text}"
170+
)
171+
return embeddings
172+
173+
def embed_documents(
174+
self,
175+
texts: List[str],
176+
chunk_size: Optional[int] = None,
177+
) -> List[List[float]]:
178+
"""Compute doc embeddings using OCI Data Science Model Deployment Endpoint.
179+
180+
Args:
181+
texts: The list of texts to embed.
182+
chunk_size: The chunk size defines how many input texts will
183+
be grouped together as request. If None, will use the
184+
chunk size specified by the class.
185+
186+
Returns:
187+
List of embeddings, one for each text.
188+
"""
189+
results = []
190+
_chunk_size = (
191+
len(texts) if (not chunk_size or chunk_size > len(texts)) else chunk_size
192+
)
193+
for i in range(0, len(texts), _chunk_size):
194+
response = self._embedding(texts[i : i + _chunk_size])
195+
results.extend(response)
196+
return results
197+
198+
def embed_query(self, text: str) -> List[float]:
199+
"""Compute query embeddings using OCI Data Science Model Deployment Endpoint.
200+
201+
Args:
202+
text: The text to embed.
203+
204+
Returns:
205+
Embeddings for the text.
206+
"""
207+
return self._embedding([text])[0]

docs/source/user_guide/large_language_model/langchain_models.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,26 @@ Chat models takes `chat messages <https://python.langchain.com/docs/concepts/#me
127127
print(chunk.content, end="")
128128
129129
130+
Embedding Models
131+
================
132+
133+
You can also use embedding model that's hosted on a `OCI Data Science Model Deployment <https://docs.oracle.com/en-us/iaas/data-science/using/model-dep-about.htm>`_.
134+
135+
136+
.. code-block:: python3
137+
138+
from langchain_community.embeddings import OCIModelDeploymentEndpointEmbeddings
139+
140+
# Create an instance of OCI Model Deployment Endpoint
141+
# Replace the endpoint uri with your own
142+
embeddings = OCIModelDeploymentEndpointEmbeddings(
143+
endpoint="https://modeldeployment.us-ashburn-1.oci.customer-oci.com/<MD_OCID>/predict",
144+
)
145+
146+
query = "Hello World!"
147+
embeddings.embed_query(query)
148+
149+
130150
Tool Calling
131151
============
132152

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2025 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2025 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
"""Test OCI Data Science Model Deployment Endpoint."""
8+
9+
import responses
10+
from pytest_mock import MockerFixture
11+
from ads.llm import OCIModelDeploymentEndpointEmbeddings
12+
13+
14+
@responses.activate
15+
def test_embedding_call(mocker: MockerFixture) -> None:
16+
"""Test valid call to oci model deployment endpoint."""
17+
endpoint = "https://MD_OCID/predict"
18+
documents = ["Hello", "World"]
19+
expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
20+
responses.add(
21+
responses.POST,
22+
endpoint,
23+
json={
24+
"data": [{"embedding": expected_output}],
25+
},
26+
status=200,
27+
)
28+
mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
29+
30+
embeddings = OCIModelDeploymentEndpointEmbeddings(
31+
endpoint=endpoint,
32+
)
33+
34+
output = embeddings.embed_documents(documents)
35+
assert output == expected_output

0 commit comments

Comments
 (0)