|
19 | 19 | from langchain.load import dumpd
|
20 | 20 | from langchain.load.load import load as lc_load
|
21 | 21 | from langchain.load.serializable import Serializable
|
22 |
| -from langchain.vectorstores import OpenSearchVectorSearch |
| 22 | +from langchain.vectorstores import OpenSearchVectorSearch, FAISS |
23 | 23 | from opensearchpy.client import OpenSearch
|
24 | 24 |
|
25 | 25 | from ads.common.auth import default_signer
|
26 | 26 | from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
|
27 | 27 | from ads.llm.chain import GuardrailSequence
|
28 | 28 | from ads.llm.guardrails.base import CustomGuardrailBase
|
29 | 29 | from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
|
| 30 | +import json |
| 31 | +import base64 |
| 32 | + |
30 | 33 |
|
31 | 34 | # This is a temp solution for supporting custom LLM in legacy load_chain
|
32 | 35 | __lc_llm_dict = llms.get_type_to_cls_dict()
|
@@ -59,8 +62,8 @@ def load(config: dict, **kwargs):
|
59 | 62 | return OpenSearchVectorSearch(
|
60 | 63 | **config["kwargs"],
|
61 | 64 | http_auth=(
|
62 |
| - os.environ.get("oci_opensearch_username"), |
63 |
| - os.environ.get("oci_opensearch_password"), |
| 65 | + os.environ.get("oci_opensearch_username", None), |
| 66 | + os.environ.get("oci_opensearch_password", None), |
64 | 67 | ),
|
65 | 68 | verify_certs=os.environ.get("oci_opensearch_verify_certs", False),
|
66 | 69 | ca_certs=os.environ.get("oci_opensearch_ca_certs", None),
|
@@ -88,8 +91,36 @@ def save(obj):
|
88 | 91 | return serialized
|
89 | 92 |
|
90 | 93 |
|
class FaissSerializer:
    """Serializer for the LangChain ``FAISS`` vector store.

    ``save`` turns a ``FAISS`` instance into a JSON-friendly dict and
    ``load`` rebuilds the instance from such a dict. The store itself is
    carried as a JSON-encoded base64 string of the bytes returned by
    ``serialize_to_bytes``; the embedding function is serialized separately.
    """

    @staticmethod
    def type():
        """Return the ``FAISS`` class name, used as the ``_type`` tag."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a ``FAISS`` vector store from a dict produced by ``save``.

        Args:
            config: Mapping with the keys ``"embedding_function"`` (the
                serialized embedding object) and ``"vectordb"`` (a
                JSON-encoded base64 string holding the serialized store).
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            Whatever ``FAISS.deserialize_from_bytes`` returns for the
            decoded payload and the restored embedding function.

        Raises:
            KeyError: If ``config`` lacks one of the required keys.
        """
        # NOTE(review): inside a static method the bare name ``load`` resolves
        # to a module-level ``load`` (not visible in this excerpt), not to
        # ``FaissSerializer.load`` itself -- confirm that helper exists.
        embedding_function = load(config["embedding_function"])

        # Reverse of ``save``: JSON string -> base64 text -> raw bytes.
        decoded_pkl = base64.b64decode(json.loads(config["vectordb"]))

        # SECURITY NOTE(review): the payload is presumably a pickle (see the
        # ``pkl`` naming and the LangChain FAISS docs), so deserializing it can
        # execute arbitrary code. Only load configs from trusted sources.
        # Newer LangChain releases may additionally require
        # ``allow_dangerous_deserialization=True`` -- verify against the
        # pinned version before changing this call.
        return FAISS.deserialize_from_bytes(
            embeddings=embedding_function, serialized=decoded_pkl
        )

    @staticmethod
    def save(obj):
        """Serialize a ``FAISS`` vector store into a JSON-friendly dict.

        Args:
            obj: The vector store to serialize. It must provide
                ``serialize_to_bytes()`` and have an ``embedding_function``
                entry in its instance ``__dict__``.

        Returns:
            A dict with the keys ``"_type"``, ``"vectordb"`` and
            ``"embedding_function"``, in that order.
        """
        pkl = obj.serialize_to_bytes()
        # Raw bytes are not JSON-serializable, so carry them as base64 text.
        encoded_pkl = base64.b64encode(pkl).decode("utf-8")
        return {
            "_type": FaissSerializer.type(),
            # Kept JSON-encoded on top of base64 because ``load`` applies
            # ``json.loads`` before decoding; changing one side alone would
            # break previously saved configs.
            "vectordb": json.dumps(encoded_pkl),
            # NOTE(review): ``dump`` is expected to be a module-level helper
            # defined elsewhere in this file -- not visible in this excerpt.
            "embedding_function": dump(obj.__dict__["embedding_function"]),
        }
| 121 | + |
91 | 122 | # Mapping class to vector store serialization functions
|
92 |
| -vectordb_serialization = {"OpenSearchVectorSearch": OpenSearchVectorDBSerializer} |
# Registry keyed by vector store class name; each value is a serializer
# class whose save() is looked up by that name when a retriever is dumped.
vectordb_serialization = {
    "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
    "FAISS": FaissSerializer,
}
93 | 124 |
|
94 | 125 |
|
95 | 126 | class RetrieverQASerializer:
|
@@ -118,6 +149,7 @@ def save(obj):
|
118 | 149 | retriever_kwargs[key] = val
|
119 | 150 | serialized["retriever_kwargs"] = retriever_kwargs
|
120 | 151 | serialized["vectordb"] = {"class": obj.retriever.vectorstore.__class__.__name__}
|
| 152 | + |
121 | 153 | vectordb_serializer = vectordb_serialization[serialized["vectordb"]["class"]]
|
122 | 154 | serialized["vectordb"].update(
|
123 | 155 | vectordb_serializer.save(obj.retriever.vectorstore)
|
|
0 commit comments