14
14
from langchain import llms
15
15
from langchain .llms import loading
16
16
from langchain .chains .loading import load_chain_from_config
17
- from langchain .load .load import load as __lc_load
17
+ from langchain .load .load import Reviver , load as __lc_load
18
18
from langchain .load .serializable import Serializable
19
19
20
20
from ads .common .auth import default_signer
@@ -76,14 +76,32 @@ def load(
76
76
Returns:
77
77
Revived LangChain objects.
78
78
"""
79
+ # Add ADS as valid namespace
79
80
if not valid_namespaces :
80
81
valid_namespaces = []
81
82
if "ads" not in valid_namespaces :
82
83
valid_namespaces .append ("ads" )
83
84
85
+ reviver = Reviver (secrets_map , valid_namespaces )
86
+
87
+ def _load (obj : Any ) -> Any :
88
+ if isinstance (obj , dict ):
89
+ if "_type" in obj and obj ["_type" ] in custom_deserialization :
90
+ if valid_namespaces :
91
+ kwargs ["valid_namespaces" ] = valid_namespaces
92
+ if secrets_map :
93
+ kwargs ["secret_map" ] = secrets_map
94
+ return custom_deserialization [obj ["_type" ]](obj , ** kwargs )
95
+ # Need to revive leaf nodes before reviving this node
96
+ loaded_obj = {k : _load (v ) for k , v in obj .items ()}
97
+ return reviver (loaded_obj )
98
+ if isinstance (obj , list ):
99
+ return [_load (o ) for o in obj ]
100
+ return obj
101
+
84
102
if isinstance (obj , dict ) and "_type" in obj :
85
103
obj_type = obj ["_type" ]
86
- # Check if the object requires a custom function to load.
104
+ # Check if the object has custom load function .
87
105
if obj_type in custom_deserialization :
88
106
if valid_namespaces :
89
107
kwargs ["valid_namespaces" ] = valid_namespaces
@@ -93,7 +111,7 @@ def load(
93
111
# Legacy chain
94
112
return load_chain_from_config (obj , ** kwargs )
95
113
96
- return __lc_load (obj , secrets_map = secrets_map , valid_namespaces = valid_namespaces )
114
+ return _load (obj )
97
115
98
116
99
117
def load_from_yaml (
@@ -144,6 +162,9 @@ def default(obj: Any) -> Any:
144
162
TypeError
145
163
If the object is not LangChain serializable.
146
164
"""
165
+ for super_class , save_fn in custom_serialization .items ():
166
+ if isinstance (obj , super_class ):
167
+ return save_fn (obj )
147
168
if isinstance (obj , Serializable ) and obj .is_lc_serializable ():
148
169
return obj .to_json ()
149
170
raise TypeError (f"Serialization of { type (obj )} is not supported." )
0 commit comments