4
4
# Copyright (c) 2023 Oracle and/or its affiliates.
5
5
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
7
- import base64
8
7
import json
9
8
import os
10
9
import tempfile
11
- from copy import deepcopy
12
10
from typing import Any , Dict , List , Optional
13
11
14
12
import fsspec
20
18
from langchain .load import dumpd
21
19
from langchain .load .load import Reviver
22
20
from langchain .load .serializable import Serializable
23
- from langchain .vectorstores import FAISS , OpenSearchVectorSearch
24
- from opensearchpy .client import OpenSearch
21
+ from langchain .schema .runnable import RunnableParallel
25
22
26
23
from ads .common .auth import default_signer
27
24
from ads .common .object_storage_details import ObjectStorageDetails
28
25
from ads .llm import GenerativeAI , ModelDeploymentTGI , ModelDeploymentVLLM
29
26
from ads .llm .chain import GuardrailSequence
30
27
from ads .llm .guardrails .base import CustomGuardrailBase
31
- from ads .llm .patch import RunnableParallel , RunnableParallelSerializer
28
+ from ads .llm .serializers .runnable_parallel import RunnableParallelSerializer
29
+ from ads .llm .serializers .retrieval_qa import RetrievalQASerializer
32
30
33
31
# This is a temp solution for supporting custom LLM in legacy load_chain
34
32
__lc_llm_dict = llms .get_type_to_cls_dict ()
@@ -45,122 +43,6 @@ def __new_type_to_cls_dict():
45
43
loading .get_type_to_cls_dict = __new_type_to_cls_dict
46
44
47
45
48
class OpenSearchVectorDBSerializer:
    """
    Serializer for OpenSearchVectorSearch class
    """

    @staticmethod
    def type():
        """Return the type name recorded in the serialized config."""
        return OpenSearchVectorSearch.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild an OpenSearchVectorSearch from its serialized config.

        Connection settings are not taken from ``config``; they are read from
        environment variables:

        * ``OCI_OPENSEARCH_USERNAME`` / ``OCI_OPENSEARCH_PASSWORD`` - HTTP auth.
        * ``OCI_OPENSEARCH_VERIFY_CERTS`` - certificates are verified only when
          this is ``"true"`` (case-insensitive); unset counts as not true.
        * ``OCI_OPENSEARCH_CA_CERTS`` - passed through as ``ca_certs``.

        Parameters
        ----------
        config : dict
            Serialized vector store; ``config["kwargs"]`` holds the constructor
            arguments, including the serialized ``embedding_function``.
        **kwargs
            Forwarded to ``load`` when restoring the embedding function.
        """
        # Work on a shallow copy so the caller's config is left untouched.
        init_kwargs = dict(config["kwargs"])
        init_kwargs["embedding_function"] = load(
            init_kwargs["embedding_function"], **kwargs
        )
        # Default to "" so an unset variable does not fail on None.lower().
        verify_certs = (
            os.environ.get("OCI_OPENSEARCH_VERIFY_CERTS", "").lower() == "true"
        )
        return OpenSearchVectorSearch(
            **init_kwargs,
            http_auth=(
                os.environ.get("OCI_OPENSEARCH_USERNAME", None),
                os.environ.get("OCI_OPENSEARCH_PASSWORD", None),
            ),
            verify_certs=verify_certs,
            ca_certs=os.environ.get("OCI_OPENSEARCH_CA_CERTS", None),
        )

    @staticmethod
    def save(obj):
        """Serialize an OpenSearchVectorSearch instance into a dict.

        Every instance attribute is serialized with ``dump``, except
        ``client``, which is replaced by an ``opensearch_url`` built from the
        first host of the client's transport.

        Raises
        ------
        NotImplementedError
            If ``obj.client`` is not an ``OpenSearch`` client.
        """
        serialized = dumpd(obj)
        serialized["type"] = "constructor"
        serialized["_type"] = OpenSearchVectorDBSerializer.type()
        kwargs = {}
        for key, val in obj.__dict__.items():
            if key != "client":
                kwargs[key] = dump(val)
                continue
            if not isinstance(val, OpenSearch):
                raise NotImplementedError("Only support OpenSearch client.")
            # Store only the endpoint; the URL is always written as https.
            client_info = val.transport.hosts[0]
            kwargs["opensearch_url"] = (
                f"https://{client_info['host']}:{client_info['port']}"
            )
        serialized["kwargs"] = kwargs
        return serialized
91
-
92
-
93
class FaissSerializer:
    """
    Serializer for FAISS class
    """

    @staticmethod
    def type():
        """Return the type name recorded in the serialized config."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a FAISS vector store from its serialized config.

        ``config["vectordb"]`` is a JSON-encoded base64 string holding the
        serialized index, and ``config["embedding_function"]`` is the
        serialized embedding function, restored through ``load``.
        """
        embeddings = load(config["embedding_function"], **kwargs)
        encoded_index = json.loads(config["vectordb"])
        raw_index = base64.b64decode(encoded_index)
        # NOTE(review): the payload looks like a pickle (see ``save``) - confirm
        # and only load configs that come from a trusted source.
        return FAISS.deserialize_from_bytes(
            embeddings=embeddings, serialized=raw_index
        )

    @staticmethod
    def save(obj):
        """Serialize a FAISS vector store into a JSON-friendly dict."""
        raw_index = obj.serialize_to_bytes()
        # Bytes are not JSON serializable, so store them as base64 text.
        encoded_index = base64.b64encode(raw_index).decode("utf-8")
        return {
            "_type": FaissSerializer.type(),
            "vectordb": json.dumps(encoded_index),
            "embedding_function": dump(obj.__dict__["embedding_function"]),
        }
120
-
121
# Registry of vector store serializers, keyed by the vector store class name
vectordb_serialization = {
    "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
    "FAISS": FaissSerializer,
}
123
-
124
-
125
class RetrievalQASerializer:
    """
    Serializer for RetrievalQA class
    """

    @staticmethod
    def type():
        """Return the type name recorded in the serialized config."""
        return "retrieval_qa"

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a retrieval QA chain from its serialized config.

        The input is deep-copied before the ``retriever_kwargs`` and
        ``vectordb`` entries are popped. The vector store is restored by the
        serializer registered for ``config["vectordb"]["class"]``, turned into
        a retriever, and passed to ``load_chain_from_config`` with the
        remaining config.

        Parameters
        ----------
        config : dict
            Serialized chain as produced by ``save``.
        **kwargs
            Forwarded to the vector store serializer's ``load``.
        """
        config_param = deepcopy(config)
        retriever_kwargs = config_param.pop("retriever_kwargs")
        vectordb_config = config_param.pop("vectordb")
        vectordb_serializer = vectordb_serialization[vectordb_config["class"]]
        vectordb = vectordb_serializer.load(vectordb_config, **kwargs)
        retriever = vectordb.as_retriever(**retriever_kwargs)
        return load_chain_from_config(config=config_param, retriever=retriever)

    @staticmethod
    def save(obj):
        """Serialize a retrieval QA chain, its retriever settings and vector store.

        Raises
        ------
        NotImplementedError
            If no serializer is registered for the retriever's vector store
            class.
        """
        serialized = obj.dict()
        # The vector store is serialized separately; tags and metadata are
        # left out of the retriever settings.
        serialized["retriever_kwargs"] = {
            key: val
            for key, val in obj.retriever.__dict__.items()
            if key not in ("tags", "metadata", "vectorstore")
        }

        vectorstore = obj.retriever.vectorstore
        vectordb_class = vectorstore.__class__.__name__
        # Check support before the lookup so callers get the intended
        # NotImplementedError instead of a bare KeyError.
        if vectordb_class not in vectordb_serialization:
            raise NotImplementedError(
                f"VectorDBSerializer for {vectordb_class} is not implemented."
            )

        serialized["vectordb"] = {"class": vectordb_class}
        serialized["vectordb"].update(
            vectordb_serialization[vectordb_class].save(vectorstore)
        )
        return serialized
162
-
163
-
164
46
# Mapping class to custom serialization functions
165
47
custom_serialization = {
166
48
GuardrailSequence : GuardrailSequence .save ,
0 commit comments