Skip to content

Commit 3e78ff5

Browse files
committed
fix the enum and adding faiss serializer
1 parent 3d2fa25 commit 3e78ff5

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

ads/llm/langchain/plugins/llm_gen_ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def _identifying_params(self) -> Dict[str, Any]:
6464
return {
6565
**{
6666
"model": self.model,
67-
"task": self.task,
67+
"task": self.task.value,
6868
"client_kwargs": self.client_kwargs,
6969
"endpoint_kwargs": self.endpoint_kwargs,
7070
},

ads/llm/serialize.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,17 @@
1919
from langchain.load import dumpd
2020
from langchain.load.load import load as lc_load
2121
from langchain.load.serializable import Serializable
22-
from langchain.vectorstores import OpenSearchVectorSearch
22+
from langchain.vectorstores import OpenSearchVectorSearch, FAISS
2323
from opensearchpy.client import OpenSearch
2424

2525
from ads.common.auth import default_signer
2626
from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
2727
from ads.llm.chain import GuardrailSequence
2828
from ads.llm.guardrails.base import CustomGuardrailBase
2929
from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
30+
import json
31+
import base64
32+
3033

3134
# This is a temp solution for supporting custom LLM in legacy load_chain
3235
__lc_llm_dict = llms.get_type_to_cls_dict()
@@ -59,8 +62,8 @@ def load(config: dict, **kwargs):
5962
return OpenSearchVectorSearch(
6063
**config["kwargs"],
6164
http_auth=(
62-
os.environ.get("oci_opensearch_username"),
63-
os.environ.get("oci_opensearch_password"),
65+
os.environ.get("oci_opensearch_username", None),
66+
os.environ.get("oci_opensearch_password", None),
6467
),
6568
verify_certs=os.environ.get("oci_opensearch_verify_certs", False),
6669
ca_certs=os.environ.get("oci_opensearch_ca_certs", None),
@@ -88,8 +91,36 @@ def save(obj):
8891
return serialized
8992

9093

94+
class FaissSerializer:
    """
    Serializer for the FAISS vector store class.

    The vector store is persisted as a dictionary with three entries:

    * ``_type``: the class name returned by :meth:`type`.
    * ``vectordb``: the bytes produced by ``obj.serialize_to_bytes()``,
      base64-encoded and then JSON-encoded as a string.
    * ``embedding_function``: the result of calling the module-level
      ``dump`` on the store's embedding function.
    """

    @staticmethod
    def type():
        """Return the name of the vector store class handled by this serializer."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a FAISS vector store from a dictionary produced by ``save``.

        Parameters
        ----------
        config: dict
            Dictionary containing the ``embedding_function`` and ``vectordb``
            entries written by :meth:`save`.
        kwargs:
            Accepted for signature parity with the other vector store
            serializers; not used by this method.

        Returns
        -------
        The object returned by ``FAISS.deserialize_from_bytes``.

        Raises
        ------
        KeyError
            If ``config`` lacks the ``embedding_function`` or ``vectordb`` key.
        """
        # Inside a method body the bare name ``load`` resolves to the
        # module-level ``load`` (expected to be defined elsewhere in this
        # file), not to this static method.
        embedding_function = load(config["embedding_function"])
        # Reverse the two encoding layers applied in ``save``:
        # JSON string -> base64 text -> raw bytes.
        decoded_pkl = base64.b64decode(json.loads(config["vectordb"]))
        # SECURITY NOTE(review): per the LangChain documentation the FAISS
        # byte payload is pickle-based, so deserializing it can execute
        # arbitrary code. Only load configurations from trusted sources.
        return FAISS.deserialize_from_bytes(
            embeddings=embedding_function, serialized=decoded_pkl
        )  # Load the index

    @staticmethod
    def save(obj):
        """Serialize a FAISS vector store into a JSON-compatible dictionary.

        Parameters
        ----------
        obj:
            The vector store to serialize. It must provide
            ``serialize_to_bytes()`` and hold an ``embedding_function``
            entry in its instance ``__dict__``.

        Returns
        -------
        dict
            Dictionary with the ``_type``, ``vectordb`` and
            ``embedding_function`` entries described in the class docstring.
        """
        serialized = {}
        serialized["_type"] = FaissSerializer.type()
        pkl = obj.serialize_to_bytes()
        # Raw bytes cannot be stored in JSON, so encode them as base64 text.
        encoded_pkl = base64.b64encode(pkl).decode("utf-8")
        # The base64 text is additionally JSON-encoded; ``load`` depends on
        # this exact format, so it must stay in sync with ``json.loads`` there.
        serialized["vectordb"] = json.dumps(encoded_pkl)
        serialized["embedding_function"] = dump(obj.__dict__["embedding_function"])
        return serialized
121+
91122
# Mapping class to vector store serialization functions
92-
vectordb_serialization = {"OpenSearchVectorSearch": OpenSearchVectorDBSerializer}
123+
vectordb_serialization = {"OpenSearchVectorSearch": OpenSearchVectorDBSerializer, "FAISS": FaissSerializer}
93124

94125

95126
class RetrieverQASerializer:
@@ -118,6 +149,7 @@ def save(obj):
118149
retriever_kwargs[key] = val
119150
serialized["retriever_kwargs"] = retriever_kwargs
120151
serialized["vectordb"] = {"class": obj.retriever.vectorstore.__class__.__name__}
152+
121153
vectordb_serializer = vectordb_serialization[serialized["vectordb"]["class"]]
122154
serialized["vectordb"].update(
123155
vectordb_serializer.save(obj.retriever.vectorstore)

0 commit comments

Comments
 (0)