4
4
# Copyright (c) 2023 Oracle and/or its affiliates.
5
5
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
7
+ import base64
7
8
import json
8
9
import os
9
10
import tempfile
11
+ from copy import deepcopy
10
12
from typing import Any , Dict , List , Optional
11
13
12
14
import fsspec
13
15
import yaml
14
16
from langchain import llms
15
- from langchain .llms import loading
17
+ from langchain .chains import RetrievalQA
16
18
from langchain .chains .loading import load_chain_from_config
19
+ from langchain .llms import loading
20
+ from langchain .load import dumpd
17
21
from langchain .load .load import Reviver
18
22
from langchain .load .serializable import Serializable
23
+ from langchain .vectorstores import FAISS , OpenSearchVectorSearch
24
+ from opensearchpy .client import OpenSearch
19
25
20
26
from ads .common .auth import default_signer
21
27
from ads .common .object_storage_details import ObjectStorageDetails
22
- from ads .llm import GenerativeAI , ModelDeploymentVLLM , ModelDeploymentTGI
28
+ from ads .llm import GenerativeAI , ModelDeploymentTGI , ModelDeploymentVLLM
23
29
from ads .llm .chain import GuardrailSequence
24
30
from ads .llm .guardrails .base import CustomGuardrailBase
25
31
from ads .llm .patch import RunnableParallel , RunnableParallelSerializer
26
32
27
-
28
33
# This is a temp solution for supporting custom LLM in legacy load_chain
29
34
__lc_llm_dict = llms .get_type_to_cls_dict ()
30
35
__lc_llm_dict [GenerativeAI .__name__ ] = lambda : GenerativeAI
@@ -39,11 +44,129 @@ def __new_type_to_cls_dict():
39
44
llms .get_type_to_cls_dict = __new_type_to_cls_dict
40
45
loading .get_type_to_cls_dict = __new_type_to_cls_dict
41
46
47
+
48
+ class OpenSearchVectorDBSerializer :
49
+ """
50
+ Serializer for OpenSearchVectorSearch class
51
+ """
52
+ @staticmethod
53
+ def type ():
54
+ return OpenSearchVectorSearch .__name__
55
+
56
+ @staticmethod
57
+ def load (config : dict , ** kwargs ):
58
+ config ["kwargs" ]["embedding_function" ] = load (
59
+ config ["kwargs" ]["embedding_function" ], ** kwargs
60
+ )
61
+ return OpenSearchVectorSearch (
62
+ ** config ["kwargs" ],
63
+ http_auth = (
64
+ os .environ .get ("OCI_OPENSEARCH_USERNAME" , None ),
65
+ os .environ .get ("OCI_OPENSEARCH_PASSWORD" , None ),
66
+ ),
67
+ verify_certs = True if os .environ .get ("OCI_OPENSEARCH_VERIFY_CERTS" , None ).lower () == "true" else False ,
68
+ ca_certs = os .environ .get ("OCI_OPENSEARCH_CA_CERTS" , None ),
69
+ )
70
+
71
+ @staticmethod
72
+ def save (obj ):
73
+ serialized = dumpd (obj )
74
+ serialized ["type" ] = "constructor"
75
+ serialized ["_type" ] = OpenSearchVectorDBSerializer .type ()
76
+ kwargs = {}
77
+ for key , val in obj .__dict__ .items ():
78
+ if key == "client" :
79
+ if isinstance (val , OpenSearch ):
80
+ client_info = val .transport .hosts [0 ]
81
+ opensearch_url = (
82
+ f"https://{ client_info ['host' ]} :{ client_info ['port' ]} "
83
+ )
84
+ kwargs .update ({"opensearch_url" : opensearch_url })
85
+ else :
86
+ raise NotImplementedError ("Only support OpenSearch client." )
87
+ continue
88
+ kwargs [key ] = dump (val )
89
+ serialized ["kwargs" ] = kwargs
90
+ return serialized
91
+
92
+
93
+ class FaissSerializer :
94
+ """
95
+ Serializer for OpenSearchVectorSearch class
96
+ """
97
+ @staticmethod
98
+ def type ():
99
+ return FAISS .__name__
100
+
101
+ @staticmethod
102
+ def load (config : dict , ** kwargs ):
103
+ embedding_function = load (config ["embedding_function" ], ** kwargs )
104
+ decoded_pkl = base64 .b64decode (json .loads (config ["vectordb" ]))
105
+ return FAISS .deserialize_from_bytes (
106
+ embeddings = embedding_function , serialized = decoded_pkl
107
+ ) # Load the index
108
+
109
+ @staticmethod
110
+ def save (obj ):
111
+ serialized = {}
112
+ serialized ["_type" ] = FaissSerializer .type ()
113
+ pkl = obj .serialize_to_bytes ()
114
+ # Encoding bytes to a base64 string
115
+ encoded_pkl = base64 .b64encode (pkl ).decode ('utf-8' )
116
+ # Serializing the base64 string
117
+ serialized ["vectordb" ] = json .dumps (encoded_pkl )
118
+ serialized ["embedding_function" ] = dump (obj .__dict__ ["embedding_function" ])
119
+ return serialized
120
+
121
+ # Mapping class to vector store serialization functions
122
+ vectordb_serialization = {"OpenSearchVectorSearch" : OpenSearchVectorDBSerializer , "FAISS" : FaissSerializer }
123
+
124
+
125
+ class RetrievalQASerializer :
126
+ """
127
+ Serializer for RetrieverQA class
128
+ """
129
+ @staticmethod
130
+ def type ():
131
+ return "retrieval_qa"
132
+
133
+ @staticmethod
134
+ def load (config : dict , ** kwargs ):
135
+ config_param = deepcopy (config )
136
+ retriever_kwargs = config_param .pop ("retriever_kwargs" )
137
+ vectordb_serializer = vectordb_serialization [config_param ["vectordb" ]["class" ]]
138
+ vectordb = vectordb_serializer .load (config_param .pop ("vectordb" ), ** kwargs )
139
+ retriever = vectordb .as_retriever (** retriever_kwargs )
140
+ return load_chain_from_config (config = config_param , retriever = retriever )
141
+
142
+ @staticmethod
143
+ def save (obj ):
144
+ serialized = obj .dict ()
145
+ retriever_kwargs = {}
146
+ for key , val in obj .retriever .__dict__ .items ():
147
+ if key not in ["tags" , "metadata" , "vectorstore" ]:
148
+ retriever_kwargs [key ] = val
149
+ serialized ["retriever_kwargs" ] = retriever_kwargs
150
+ serialized ["vectordb" ] = {"class" : obj .retriever .vectorstore .__class__ .__name__ }
151
+
152
+ vectordb_serializer = vectordb_serialization [serialized ["vectordb" ]["class" ]]
153
+ serialized ["vectordb" ].update (
154
+ vectordb_serializer .save (obj .retriever .vectorstore )
155
+ )
156
+
157
+ if serialized ["vectordb" ]["class" ] not in vectordb_serialization :
158
+ raise NotImplementedError (
159
+ f"VectorDBSerializer for { serialized ['vectordb' ]['class' ]} is not implemented."
160
+ )
161
+ return serialized
162
+
163
+
42
164
# Mapping class to custom serialization functions
43
165
custom_serialization = {
44
166
GuardrailSequence : GuardrailSequence .save ,
45
167
CustomGuardrailBase : CustomGuardrailBase .save ,
46
168
RunnableParallel : RunnableParallelSerializer .save ,
169
+ RetrievalQA : RetrievalQASerializer .save ,
47
170
}
48
171
49
172
# Mapping _type to custom deserialization functions
@@ -52,6 +175,7 @@ def __new_type_to_cls_dict():
52
175
GuardrailSequence .type (): GuardrailSequence .load ,
53
176
CustomGuardrailBase .type (): CustomGuardrailBase .load ,
54
177
RunnableParallelSerializer .type (): RunnableParallelSerializer .load ,
178
+ RetrievalQASerializer .type (): RetrievalQASerializer .load ,
55
179
}
56
180
57
181
0 commit comments