Skip to content

Commit 579bd43

Browse files
committed
Update chain serialization.
1 parent 3418eea commit 579bd43

File tree

1 file changed

+24
-3
lines changed

1 file changed

+24
-3
lines changed

ads/llm/serialize.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from langchain import llms
1515
from langchain.llms import loading
1616
from langchain.chains.loading import load_chain_from_config
17-
from langchain.load.load import load as __lc_load
17+
from langchain.load.load import Reviver, load as __lc_load
1818
from langchain.load.serializable import Serializable
1919

2020
from ads.common.auth import default_signer
@@ -76,14 +76,32 @@ def load(
7676
Returns:
7777
Revived LangChain objects.
7878
"""
79+
# Add ADS as valid namespace
7980
if not valid_namespaces:
8081
valid_namespaces = []
8182
if "ads" not in valid_namespaces:
8283
valid_namespaces.append("ads")
8384

85+
reviver = Reviver(secrets_map, valid_namespaces)
86+
87+
def _load(obj: Any) -> Any:
88+
if isinstance(obj, dict):
89+
if "_type" in obj and obj["_type"] in custom_deserialization:
90+
if valid_namespaces:
91+
kwargs["valid_namespaces"] = valid_namespaces
92+
if secrets_map:
93+
kwargs["secret_map"] = secrets_map
94+
return custom_deserialization[obj["_type"]](obj, **kwargs)
95+
# Need to revive leaf nodes before reviving this node
96+
loaded_obj = {k: _load(v) for k, v in obj.items()}
97+
return reviver(loaded_obj)
98+
if isinstance(obj, list):
99+
return [_load(o) for o in obj]
100+
return obj
101+
84102
if isinstance(obj, dict) and "_type" in obj:
85103
obj_type = obj["_type"]
86-
# Check if the object requires a custom function to load.
104+
# Check if the object has custom load function.
87105
if obj_type in custom_deserialization:
88106
if valid_namespaces:
89107
kwargs["valid_namespaces"] = valid_namespaces
@@ -93,7 +111,7 @@ def load(
93111
# Legacy chain
94112
return load_chain_from_config(obj, **kwargs)
95113

96-
return __lc_load(obj, secrets_map=secrets_map, valid_namespaces=valid_namespaces)
114+
return _load(obj)
97115

98116

99117
def load_from_yaml(
@@ -144,6 +162,9 @@ def default(obj: Any) -> Any:
144162
TypeError
145163
If the object is not LangChain serializable.
146164
"""
165+
for super_class, save_fn in custom_serialization.items():
166+
if isinstance(obj, super_class):
167+
return save_fn(obj)
147168
if isinstance(obj, Serializable) and obj.is_lc_serializable():
148169
return obj.to_json()
149170
raise TypeError(f"Serialization of {type(obj)} is not supported.")

0 commit comments

Comments
 (0)