Skip to content

Commit b0ad6e5

Browse files
committed
Update embeddings.py to add serialization support.
1 parent 3b04340 commit b0ad6e5

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

ads/llm/langchain/plugins/embeddings.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77
from typing import List, Optional
8-
8+
from langchain.load.serializable import Serializable
99
from langchain.schema.embeddings import Embeddings
1010
from ads.llm.langchain.plugins.base import GenerativeAiClientModel
1111

1212

13-
class GenerativeAIEmbeddings(GenerativeAiClientModel, Embeddings):
13+
class GenerativeAIEmbeddings(GenerativeAiClientModel, Embeddings, Serializable):
1414
"""OCI Generative AI embedding models."""
1515

1616
model: str = "cohere.embed-english-light-v2.0"
@@ -19,6 +19,16 @@ class GenerativeAIEmbeddings(GenerativeAiClientModel, Embeddings):
1919
truncate: Optional[str] = None
2020
"""Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
2121

22+
@classmethod
23+
def get_lc_namespace(cls) -> List[str]:
24+
"""Get the namespace of the LangChain object."""
25+
return ["ads", "llm"]
26+
27+
@classmethod
28+
def is_lc_serializable(cls) -> bool:
29+
"""This class can be serialized with default LangChain serialization."""
30+
return True
31+
2232
def embed_documents(self, texts: List[str]) -> List[List[float]]:
2333
"""Embeds a list of strings.
2434

0 commit comments

Comments
 (0)