Skip to content

Commit ec78e84

Browse files
authored
Refactor custom serialization for langchain components. (#472)
2 parents b776126 + b7ff05b commit ec78e84

File tree

6 files changed

+225
-153
lines changed

6 files changed

+225
-153
lines changed

ads/llm/serialize.py

Lines changed: 3 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,9 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7-
import base64
87
import json
98
import os
109
import tempfile
11-
from copy import deepcopy
1210
from typing import Any, Dict, List, Optional
1311

1412
import fsspec
@@ -20,15 +18,15 @@
2018
from langchain.load import dumpd
2119
from langchain.load.load import Reviver
2220
from langchain.load.serializable import Serializable
23-
from langchain.vectorstores import FAISS, OpenSearchVectorSearch
24-
from opensearchpy.client import OpenSearch
21+
from langchain.schema.runnable import RunnableParallel
2522

2623
from ads.common.auth import default_signer
2724
from ads.common.object_storage_details import ObjectStorageDetails
2825
from ads.llm import GenerativeAI, ModelDeploymentTGI, ModelDeploymentVLLM
2926
from ads.llm.chain import GuardrailSequence
3027
from ads.llm.guardrails.base import CustomGuardrailBase
31-
from ads.llm.patch import RunnableParallel, RunnableParallelSerializer
28+
from ads.llm.serializers.runnable_parallel import RunnableParallelSerializer
29+
from ads.llm.serializers.retrieval_qa import RetrievalQASerializer
3230

3331
# This is a temp solution for supporting custom LLM in legacy load_chain
3432
__lc_llm_dict = llms.get_type_to_cls_dict()
@@ -45,122 +43,6 @@ def __new_type_to_cls_dict():
4543
loading.get_type_to_cls_dict = __new_type_to_cls_dict
4644

4745

48-
class OpenSearchVectorDBSerializer:
    """
    Serializer for OpenSearchVectorSearch class

    Credentials and TLS settings are not written into the serialized config;
    ``load`` reads them from the ``OCI_OPENSEARCH_USERNAME``,
    ``OCI_OPENSEARCH_PASSWORD``, ``OCI_OPENSEARCH_VERIFY_CERTS`` and
    ``OCI_OPENSEARCH_CA_CERTS`` environment variables.
    """

    @staticmethod
    def type():
        """Return the class name used to tag serialized OpenSearch stores."""
        return OpenSearchVectorSearch.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild an OpenSearchVectorSearch from a config produced by ``save``.

        Parameters
        ----------
        config: dict
            Serialized vector store. ``config["kwargs"]`` holds the constructor
            arguments, including the serialized ``embedding_function``.
        kwargs:
            Forwarded to ``load`` when restoring the embedding function.

        Returns
        -------
        OpenSearchVectorSearch
            Vector store built from the config and the environment variables.
        """
        # Work on a shallow copy so the caller's config is left untouched.
        init_kwargs = dict(config["kwargs"])
        init_kwargs["embedding_function"] = load(
            init_kwargs["embedding_function"], **kwargs
        )
        # Default to "" so an unset variable means "do not verify" instead of
        # failing with AttributeError on ``None.lower()``.
        verify_certs = (
            os.environ.get("OCI_OPENSEARCH_VERIFY_CERTS", "").lower() == "true"
        )
        return OpenSearchVectorSearch(
            **init_kwargs,
            http_auth=(
                os.environ.get("OCI_OPENSEARCH_USERNAME", None),
                os.environ.get("OCI_OPENSEARCH_PASSWORD", None),
            ),
            verify_certs=verify_certs,
            ca_certs=os.environ.get("OCI_OPENSEARCH_CA_CERTS", None),
        )

    @staticmethod
    def save(obj):
        """Serialize an OpenSearchVectorSearch instance into a dictionary.

        Parameters
        ----------
        obj:
            The vector store to serialize; every entry of ``obj.__dict__`` is
            visited.

        Returns
        -------
        dict
            The ``dumpd`` output with ``type``, ``_type`` and ``kwargs`` set.

        Raises
        ------
        NotImplementedError
            If ``obj.client`` is not an ``OpenSearch`` client.
        """
        serialized = dumpd(obj)
        serialized["type"] = "constructor"
        serialized["_type"] = OpenSearchVectorDBSerializer.type()
        kwargs = {}
        for key, val in obj.__dict__.items():
            if key != "client":
                kwargs[key] = dump(val)
                continue
            # The client object itself is not stored; only the URL of its
            # first host is kept, which ``load`` passes back to the
            # constructor as ``opensearch_url``.
            if not isinstance(val, OpenSearch):
                raise NotImplementedError("Only support OpenSearch client.")
            client_info = val.transport.hosts[0]
            kwargs["opensearch_url"] = (
                f"https://{client_info['host']}:{client_info['port']}"
            )
        serialized["kwargs"] = kwargs
        return serialized
91-
92-
93-
class FaissSerializer:
    """
    Serializer for FAISS class
    """

    @staticmethod
    def type():
        """Return the class name used to tag serialized FAISS stores."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Restore a FAISS vector store from a config produced by ``save``.

        ``kwargs`` are forwarded to ``load`` when restoring the embedding
        function.
        """
        embeddings = load(config["embedding_function"], **kwargs)
        # Reverse of save(): JSON string -> base64 text -> raw bytes.
        encoded_index = json.loads(config["vectordb"])
        index_bytes = base64.b64decode(encoded_index)
        # NOTE(review): the payload is presumably pickled data - confirm that
        # configs are only loaded from trusted sources.
        return FAISS.deserialize_from_bytes(
            embeddings=embeddings, serialized=index_bytes
        )

    @staticmethod
    def save(obj):
        """Serialize a FAISS vector store into a JSON-friendly dictionary.

        The store's bytes are base64-encoded and then JSON-encoded; the
        embedding function is serialized with ``dump``.
        """
        type_name = FaissSerializer.type()
        index_bytes = obj.serialize_to_bytes()
        encoded_index = base64.b64encode(index_bytes).decode('utf-8')
        return {
            "_type": type_name,
            "vectordb": json.dumps(encoded_index),
            "embedding_function": dump(obj.__dict__["embedding_function"]),
        }
120-
121-
# Registry: vector store class name -> serializer class that handles it.
vectordb_serialization = {
    "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
    "FAISS": FaissSerializer,
}
123-
124-
125-
class RetrievalQASerializer:
    """
    Serializer for RetrievalQA class
    """

    @staticmethod
    def type():
        """Return the chain type string handled by this serializer."""
        return "retrieval_qa"

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a retrieval QA chain from a config produced by ``save``.

        Parameters
        ----------
        config: dict
            Serialized chain containing ``retriever_kwargs`` and ``vectordb``
            entries in addition to the chain's own config. It is deep-copied,
            so the caller's dictionary is not modified.
        kwargs:
            Forwarded to the vector store serializer's ``load``.

        Returns
        -------
        The chain returned by ``load_chain_from_config`` with the restored
        retriever attached.
        """
        config_param = deepcopy(config)
        retriever_kwargs = config_param.pop("retriever_kwargs")
        vectordb_config = config_param.pop("vectordb")
        vectordb_serializer = vectordb_serialization[vectordb_config["class"]]
        vectordb = vectordb_serializer.load(vectordb_config, **kwargs)
        retriever = vectordb.as_retriever(**retriever_kwargs)
        return load_chain_from_config(config=config_param, retriever=retriever)

    @staticmethod
    def save(obj):
        """Serialize a retrieval QA chain, its retriever settings and vector store.

        Parameters
        ----------
        obj:
            Chain exposing ``dict()`` and a ``retriever`` with a ``vectorstore``.

        Returns
        -------
        dict
            ``obj.dict()`` extended with ``retriever_kwargs`` and ``vectordb``.

        Raises
        ------
        NotImplementedError
            If no serializer is registered for the vector store's class.
        """
        vectorstore = obj.retriever.vectorstore
        vectordb_class = vectorstore.__class__.__name__
        # Check support before the lookup so an unsupported store raises
        # NotImplementedError rather than a bare KeyError.
        if vectordb_class not in vectordb_serialization:
            raise NotImplementedError(
                f"VectorDBSerializer for {vectordb_class} is not implemented."
            )
        vectordb_serializer = vectordb_serialization[vectordb_class]

        serialized = obj.dict()
        # Everything on the retriever except these keys is kept as a keyword
        # argument for ``as_retriever`` at load time.
        serialized["retriever_kwargs"] = {
            key: val
            for key, val in obj.retriever.__dict__.items()
            if key not in ("tags", "metadata", "vectorstore")
        }
        serialized["vectordb"] = {
            "class": vectordb_class,
            **vectordb_serializer.save(vectorstore),
        }
        return serialized
162-
163-
16446
# Mapping class to custom serialization functions
16547
custom_serialization = {
16648
GuardrailSequence: GuardrailSequence.save,

ads/llm/serializers/__init__.py

Whitespace-only changes.

ads/llm/serializers/retrieval_qa.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import base64
2+
import json
3+
import os
4+
from copy import deepcopy
5+
from langchain.load import dumpd
6+
from langchain.chains.loading import load_chain_from_config
7+
from langchain.vectorstores import FAISS, OpenSearchVectorSearch
8+
9+
10+
class OpenSearchVectorDBSerializer:
    """
    Serializer for OpenSearchVectorSearch class

    Credentials and TLS settings are not written into the serialized config;
    ``load`` reads them from the ``OCI_OPENSEARCH_USERNAME``,
    ``OCI_OPENSEARCH_PASSWORD``, ``OCI_OPENSEARCH_VERIFY_CERTS`` and
    ``OCI_OPENSEARCH_CA_CERTS`` environment variables.
    """

    @staticmethod
    def type():
        """Return the class name used to tag serialized OpenSearch stores."""
        return OpenSearchVectorSearch.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild an OpenSearchVectorSearch from a config produced by ``save``.

        Parameters
        ----------
        config: dict
            Serialized vector store. ``config["kwargs"]`` holds the constructor
            arguments, including the serialized ``embedding_function``.
        kwargs:
            Forwarded to ``ads.llm.serialize.load`` when restoring the
            embedding function.

        Returns
        -------
        OpenSearchVectorSearch
            Vector store built from the config and the environment variables.
        """
        # Imported inside the function, not at module level.
        from ads.llm.serialize import load

        # Work on a shallow copy so the caller's config is left untouched.
        init_kwargs = dict(config["kwargs"])
        init_kwargs["embedding_function"] = load(
            init_kwargs["embedding_function"], **kwargs
        )
        # Default to "" so an unset variable means "do not verify" instead of
        # failing with AttributeError on ``None.lower()``.
        verify_certs = (
            os.environ.get("OCI_OPENSEARCH_VERIFY_CERTS", "").lower() == "true"
        )
        return OpenSearchVectorSearch(
            **init_kwargs,
            http_auth=(
                os.environ.get("OCI_OPENSEARCH_USERNAME", None),
                os.environ.get("OCI_OPENSEARCH_PASSWORD", None),
            ),
            verify_certs=verify_certs,
            ca_certs=os.environ.get("OCI_OPENSEARCH_CA_CERTS", None),
        )

    @staticmethod
    def save(obj):
        """Serialize an OpenSearchVectorSearch instance into a dictionary.

        Parameters
        ----------
        obj:
            The vector store to serialize; every entry of ``obj.__dict__`` is
            visited.

        Returns
        -------
        dict
            The ``dumpd`` output with ``type``, ``_type`` and ``kwargs`` set.

        Raises
        ------
        NotImplementedError
            If ``obj.client`` is not an ``opensearchpy`` ``OpenSearch`` client.
        """
        from ads.llm.serialize import dump
        from opensearchpy.client import OpenSearch

        serialized = dumpd(obj)
        serialized["type"] = "constructor"
        serialized["_type"] = OpenSearchVectorDBSerializer.type()
        kwargs = {}
        for key, val in obj.__dict__.items():
            if key != "client":
                kwargs[key] = dump(val)
                continue
            # The client object itself is not stored; only the URL of its
            # first host is kept, which ``load`` passes back to the
            # constructor as ``opensearch_url``.
            if not isinstance(val, OpenSearch):
                raise NotImplementedError("Only support OpenSearch client.")
            client_info = val.transport.hosts[0]
            kwargs["opensearch_url"] = (
                f"https://{client_info['host']}:{client_info['port']}"
            )
        serialized["kwargs"] = kwargs
        return serialized
61+
62+
63+
class FaissSerializer:
    """
    Serializer for FAISS class
    """

    @staticmethod
    def type():
        """Return the class name used to tag serialized FAISS stores."""
        return FAISS.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Restore a FAISS vector store from a config produced by ``save``.

        ``kwargs`` are forwarded to ``ads.llm.serialize.load`` when restoring
        the embedding function.
        """
        from ads.llm.serialize import load

        embeddings = load(config["embedding_function"], **kwargs)
        # Reverse of save(): JSON string -> base64 text -> raw bytes.
        encoded_index = json.loads(config["vectordb"])
        index_bytes = base64.b64decode(encoded_index)
        # NOTE(review): the payload is presumably pickled data - confirm that
        # configs are only loaded from trusted sources.
        return FAISS.deserialize_from_bytes(
            embeddings=embeddings, serialized=index_bytes
        )

    @staticmethod
    def save(obj):
        """Serialize a FAISS vector store into a JSON-friendly dictionary.

        The store's bytes are base64-encoded and then JSON-encoded; the
        embedding function is serialized with ``ads.llm.serialize.dump``.
        """
        from ads.llm.serialize import dump

        type_name = FaissSerializer.type()
        index_bytes = obj.serialize_to_bytes()
        encoded_index = base64.b64encode(index_bytes).decode("utf-8")
        return {
            "_type": type_name,
            "vectordb": json.dumps(encoded_index),
            "embedding_function": dump(obj.__dict__["embedding_function"]),
        }
95+
96+
97+
class RetrievalQASerializer:
    """
    Serializer for RetrievalQA class
    """

    # Mapping class to vector store serialization functions
    vectordb_serialization = {
        "OpenSearchVectorSearch": OpenSearchVectorDBSerializer,
        "FAISS": FaissSerializer,
    }

    @staticmethod
    def type():
        """Return the chain type string handled by this serializer."""
        return "retrieval_qa"

    @staticmethod
    def load(config: dict, **kwargs):
        """Rebuild a retrieval QA chain from a config produced by ``save``.

        Parameters
        ----------
        config: dict
            Serialized chain containing ``retriever_kwargs`` and ``vectordb``
            entries in addition to the chain's own config. It is deep-copied,
            so the caller's dictionary is not modified.
        kwargs:
            Forwarded to the vector store serializer's ``load``.

        Returns
        -------
        The chain returned by ``load_chain_from_config`` with the restored
        retriever attached.
        """
        config_param = deepcopy(config)
        retriever_kwargs = config_param.pop("retriever_kwargs")
        vectordb_config = config_param.pop("vectordb")
        vectordb_serializer = RetrievalQASerializer.vectordb_serialization[
            vectordb_config["class"]
        ]
        vectordb = vectordb_serializer.load(vectordb_config, **kwargs)
        retriever = vectordb.as_retriever(**retriever_kwargs)
        return load_chain_from_config(config=config_param, retriever=retriever)

    @staticmethod
    def save(obj):
        """Serialize a retrieval QA chain, its retriever settings and vector store.

        Parameters
        ----------
        obj:
            Chain exposing ``dict()`` and a ``retriever`` with a ``vectorstore``.

        Returns
        -------
        dict
            ``obj.dict()`` extended with ``retriever_kwargs`` and ``vectordb``.

        Raises
        ------
        NotImplementedError
            If no serializer is registered for the vector store's class.
        """
        registry = RetrievalQASerializer.vectordb_serialization
        vectorstore = obj.retriever.vectorstore
        vectordb_class = vectorstore.__class__.__name__
        # Check support before the lookup so an unsupported store raises
        # NotImplementedError rather than a bare KeyError.
        if vectordb_class not in registry:
            raise NotImplementedError(
                f"VectorDBSerializer for {vectordb_class} is not implemented."
            )
        vectordb_serializer = registry[vectordb_class]

        serialized = obj.dict()
        # Everything on the retriever except these keys is kept as a keyword
        # argument for ``as_retriever`` at load time.
        serialized["retriever_kwargs"] = {
            key: val
            for key, val in obj.retriever.__dict__.items()
            if key not in ("tags", "metadata", "vectorstore")
        }
        serialized["vectordb"] = {
            "class": vectordb_class,
            **vectordb_serializer.save(vectorstore),
        }
        return serialized
File renamed without changes.

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
18
import os
29
from unittest import TestCase, mock, SkipTest
310

0 commit comments

Comments
 (0)