14
14
from langchain import llms
15
15
from langchain .llms import loading
16
16
from langchain .chains .loading import load_chain_from_config
17
- from langchain .load .load import load as __lc_load
17
+ from langchain .load .load import Reviver , load as __lc_load
18
18
from langchain .load .serializable import Serializable
19
19
20
20
from ads .common .auth import default_signer
@@ -76,14 +76,32 @@ def load(
76
76
Returns:
77
77
Revived LangChain objects.
78
78
"""
79
+ # Add ADS as valid namespace
79
80
if not valid_namespaces :
80
81
valid_namespaces = []
81
82
if "ads" not in valid_namespaces :
82
83
valid_namespaces .append ("ads" )
83
84
85
+ reviver = Reviver (secrets_map , valid_namespaces )
86
+
87
+ def _load (obj : Any ) -> Any :
88
+ if isinstance (obj , dict ):
89
+ if "_type" in obj and obj ["_type" ] in custom_deserialization :
90
+ if valid_namespaces :
91
+ kwargs ["valid_namespaces" ] = valid_namespaces
92
+ if secrets_map :
93
+ kwargs ["secret_map" ] = secrets_map
94
+ return custom_deserialization [obj ["_type" ]](obj , ** kwargs )
95
+ # Need to revive leaf nodes before reviving this node
96
+ loaded_obj = {k : _load (v ) for k , v in obj .items ()}
97
+ return reviver (loaded_obj )
98
+ if isinstance (obj , list ):
99
+ return [_load (o ) for o in obj ]
100
+ return obj
101
+
84
102
if isinstance (obj , dict ) and "_type" in obj :
85
103
obj_type = obj ["_type" ]
86
- # Check if the object requires a custom function to load.
104
+ # Check if the object has custom load function .
87
105
if obj_type in custom_deserialization :
88
106
if valid_namespaces :
89
107
kwargs ["valid_namespaces" ] = valid_namespaces
@@ -93,7 +111,7 @@ def load(
93
111
# Legacy chain
94
112
return load_chain_from_config (obj , ** kwargs )
95
113
96
- return __lc_load (obj , secrets_map = secrets_map , valid_namespaces = valid_namespaces )
114
+ return _load (obj )
97
115
98
116
99
117
def load_from_yaml (
@@ -144,11 +162,30 @@ def default(obj: Any) -> Any:
144
162
TypeError
145
163
If the object is not LangChain serializable.
146
164
"""
165
+ for super_class , save_fn in custom_serialization .items ():
166
+ if isinstance (obj , super_class ):
167
+ return save_fn (obj )
147
168
if isinstance (obj , Serializable ) and obj .is_lc_serializable ():
148
169
return obj .to_json ()
149
170
raise TypeError (f"Serialization of { type (obj )} is not supported." )
150
171
151
172
173
+ def __save (obj ):
174
+ """Calls the legacy save method to save the object to temp json
175
+ then load it into a dictionary.
176
+ """
177
+ try :
178
+ temp_file = tempfile .NamedTemporaryFile (
179
+ mode = "w" , encoding = "utf-8" , suffix = ".json" , delete = False
180
+ )
181
+ temp_file .close ()
182
+ obj .save (temp_file .name )
183
+ with open (temp_file .name , "r" , encoding = "utf-8" ) as f :
184
+ return json .load (f )
185
+ finally :
186
+ os .unlink (temp_file .name )
187
+
188
+
152
189
def dump (obj : Any ) -> Dict [str , Any ]:
153
190
"""Return a json dict representation of an object.
154
191
@@ -167,14 +204,14 @@ def dump(obj: Any) -> Dict[str, Any]:
167
204
):
168
205
# The object is not is_lc_serializable.
169
206
# However, it supports the legacy save() method.
170
- try :
171
- temp_file = tempfile . NamedTemporaryFile (
172
- mode = "w" , encoding = "utf-8" , suffix = ".json" , delete = False
173
- )
174
- temp_file . close ()
175
- obj . save ( temp_file . name )
176
- with open ( temp_file . name , "r" , encoding = "utf-8" ) as f :
177
- return json . load ( f )
178
- finally :
179
- os . unlink ( temp_file . name )
180
- return json . loads ( json . dumps ( obj , default = default ))
207
+ return __save ( obj )
208
+ # The object is is_lc_serializable.
209
+ # However, some properties may not be serializable
210
+ # Here we try to dump the object and fallback to the save() method
211
+ # if there is an error.
212
+ try :
213
+ return json . loads ( json . dumps ( obj , default = default ))
214
+ except TypeError as ex :
215
+ if isinstance ( obj , Serializable ) and hasattr ( obj , "save" ) :
216
+ return __save ( obj )
217
+ raise ex
0 commit comments