5
5
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6
6
7
7
8
- import importlib
9
- import importlib .util
10
8
import json
11
9
import logging
12
10
import os
13
11
import pathlib
14
- import sys
15
-
16
- from copy import deepcopy
17
- from typing import Any , List , Optional , ClassVar
12
+ from typing import Any , List , Optional
18
13
19
14
import yaml
20
-
21
- from langchain .chains .loading import load_chain_from_config , type_to_loader_dict
22
15
from langchain .llms .base import LLM
23
16
from langchain .schema .runnable import (
24
17
Runnable ,
25
18
RunnableConfig ,
26
19
RunnableSequence ,
27
20
)
28
- from ads .llm import guardrails
29
- from ads .llm .serialize import load , dump
30
21
from ads .llm .guardrails .base import GuardrailIO , Guardrail , RunInfo , BlockedByGuardrail
31
22
32
23
33
24
logger = logging .getLogger (__name__ )
34
- SPEC_CLASS = "class"
35
- SPEC_PATH = "path"
36
- SPEC_SPEC = "spec"
37
25
SPEC_CHAIN_TYPE = "_type"
38
26
SPEC_CHAIN = "chain"
39
- BUILT_IN = "ads."
40
27
41
28
42
29
class GuardrailSequence (RunnableSequence ):
43
30
"""Represents a sequence of guardrails and other LangChain (non-guardrail) components."""
44
31
45
- CHAIN_TYPE : ClassVar [str ] = "ads_guardrail_sequence"
46
-
47
32
first : Optional [Runnable ] = None
48
33
last : Optional [Runnable ] = None
49
34
@@ -72,8 +57,14 @@ def steps(self) -> List[Runnable[Any, Any]]:
72
57
chain += [self .last ]
73
58
return chain
74
59
60
+ @staticmethod
61
+ def type () -> str :
62
+ """A unique identifier as type for serialization."""
63
+ return "ads_guardrail_sequence"
64
+
75
65
@classmethod
76
66
def from_sequence (cls , sequence : RunnableSequence ):
67
+ """Creates a GuardrailSequence from a LangChain runnable sequence."""
77
68
return cls (first = sequence .first , middle = sequence .middle , last = sequence .last )
78
69
79
70
def __or__ (self , other ) -> "GuardrailSequence" :
@@ -100,7 +91,7 @@ def invoke(self, input: Any, config: RunnableConfig = None) -> GuardrailIO:
100
91
"""
101
92
return self .run (input )
102
93
103
- def _invoke_llm (self , llm , texts , num_generations , ** kwargs ):
94
+ def _invoke_llm (self , llm : LLM , texts : list , num_generations : int , ** kwargs ):
104
95
if num_generations > 1 :
105
96
if len (texts ) > 1 :
106
97
raise NotImplementedError (
@@ -166,6 +157,7 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
166
157
if self .raise_exception :
167
158
raise ex
168
159
obj .data = [ex .message ]
160
+ obj .info .append (ex .info )
169
161
return obj
170
162
171
163
def _save_to_file (self , chain_dict , filename , overwrite = False ):
@@ -187,7 +179,7 @@ def _save_to_file(self, chain_dict, filename, overwrite=False):
187
179
f"{ self .__class__ .__name__ } can only be saved as yaml or json format."
188
180
)
189
181
190
- def save (self , filename : str = None , overwrite : bool = False ):
182
+ def save (self , filename : str = None , overwrite : bool = False ) -> dict :
191
183
"""Serialize the sequence to a dictionary.
192
184
Optionally, save the sequence into a JSON or YAML file.
193
185
@@ -196,16 +188,12 @@ def save(self, filename: str = None, overwrite: bool = False):
196
188
{
197
189
"_type": "ads_guardrail_sequence",
198
190
"chain": [
199
- {
200
- "class": "...",
201
- "path": "...",
202
- "spec": {
203
- ...
204
- }
205
- }
191
+ ...
206
192
]
207
193
}
208
194
195
+ where ``chain`` contains a list of steps.
196
+
209
197
Parameters
210
198
----------
211
199
filename : str
@@ -216,22 +204,13 @@ def save(self, filename: str = None, overwrite: bool = False):
216
204
dict
217
205
The sequence saved as a dictionary.
218
206
"""
207
+ from ads .llm .serialize import dump
208
+
219
209
chain_spec = []
220
210
for step in self .steps :
221
- class_name = step .__class__ .__name__
222
- if step .__module__ .startswith (BUILT_IN ):
223
- path = getattr (step , "path" , None )
224
- else :
225
- path = step .__module__
226
-
227
- logger .debug ("class: %s | module: %s" , class_name , path )
228
- if not hasattr (step , "dict" ):
229
- raise NotImplementedError (f"{ class_name } is not serializable." )
230
- chain_spec .append (
231
- {SPEC_CLASS : class_name , SPEC_PATH : path , SPEC_SPEC : step .dict ()}
232
- )
211
+ chain_spec .append (dump (step ))
233
212
chain_dict = {
234
- SPEC_CHAIN_TYPE : self .CHAIN_TYPE ,
213
+ SPEC_CHAIN_TYPE : self .type () ,
235
214
SPEC_CHAIN : chain_spec ,
236
215
}
237
216
@@ -240,83 +219,8 @@ def save(self, filename: str = None, overwrite: bool = False):
240
219
241
220
return chain_dict
242
221
243
- def __str__ (self ) -> str :
244
- return "\n " .join ([str (step .__class__ ) for step in self .steps ])
245
-
246
- @staticmethod
247
- def _load_class_from_file (module_name , file_path , class_name ):
248
- module_spec = importlib .util .spec_from_file_location (module_name , file_path )
249
- module = importlib .util .module_from_spec (module_spec )
250
- sys .modules [module_name ] = module
251
- module_spec .loader .exec_module (module )
252
- return getattr (module , class_name )
253
-
254
- @staticmethod
255
- def _load_class_from_module (module_name , class_name ):
256
- component_module = importlib .import_module (module_name )
257
- return getattr (component_module , class_name )
258
-
259
- @staticmethod
260
- def load_step (config : dict ):
261
- spec = deepcopy (config .get (SPEC_SPEC , {}))
262
- spec : dict
263
- class_name = config [SPEC_CLASS ]
264
- module_name = config .get (SPEC_PATH )
265
-
266
- if not module_name and "." in class_name :
267
- # The class name is given as a.b.c.MyClass
268
- module_name , class_name = class_name .rsplit ("." , 1 )
269
-
270
- # Load the step with LangChain loader if it matches the "_type".
271
- # Note that some LangChain objects are saved with the "_type" but there is no matching loader.
272
- if (
273
- str (module_name ).startswith ("langchain." )
274
- and SPEC_CHAIN_TYPE in spec
275
- and spec [SPEC_CHAIN_TYPE ] in type_to_loader_dict
276
- ):
277
- return load_chain_from_config (spec )
278
-
279
- # Load the guardrail using spec as kwargs
280
- if hasattr (guardrails , class_name ):
281
- # Built-in guardrail, including custom huggingface guardrail
282
- component_class = getattr (guardrails , class_name )
283
- # Copy the path into spec if it is not already there
284
- if SPEC_PATH in config and SPEC_PATH not in spec :
285
- spec [SPEC_PATH ] = config [SPEC_PATH ]
286
- elif SPEC_PATH in config :
287
- # Custom component
288
- # For custom guardrail, the module name could be a file.
289
- if "://" in module_name :
290
- # TODO: Load module from OCI object storage
291
- #
292
- # component_class = GuardrailSequence._load_class_from_file(
293
- # module_name, temp_file, class_name
294
- # )
295
- raise NotImplementedError (
296
- f"Loading module from { module_name } is not supported."
297
- )
298
- elif os .path .exists (module_name ):
299
- component_class = GuardrailSequence ._load_class_from_file (
300
- module_name , module_name , class_name
301
- )
302
- else :
303
- component_class = GuardrailSequence ._load_class_from_module (
304
- module_name , class_name
305
- )
306
- elif "." in class_name :
307
- # The class name is given as a.b.c.MyClass
308
- module_name , class_name = class_name .rsplit ("." , 1 )
309
- component_class = GuardrailSequence ._load_class_from_module (
310
- module_name , class_name
311
- )
312
- else :
313
- raise ValueError (f"Invalid Guardrail: { class_name } " )
314
-
315
- spec .pop (SPEC_CHAIN_TYPE , None )
316
- return component_class (** spec )
317
-
318
222
@classmethod
319
- def load (cls , chain_dict : dict ) -> "GuardrailSequence" :
223
+ def load (cls , chain_dict : dict , ** kwargs ) -> "GuardrailSequence" :
320
224
"""Loads the sequence from a dictionary config.
321
225
322
226
Parameters
@@ -331,10 +235,15 @@ def load(cls, chain_dict: dict) -> "GuardrailSequence":
331
235
GuardrailSequence
332
236
A GuardrailSequence loaded from the config.
333
237
"""
238
+ from ads .llm .serialize import load
239
+
334
240
chain_spec = chain_dict [SPEC_CHAIN ]
335
241
chain = cls ()
336
242
for config in chain_spec :
337
- guardrail = cls . load_step (config )
338
- # Chain the guardrail
339
- chain |= guardrail
243
+ step = load (config , ** kwargs )
244
+ # Chain the step
245
+ chain |= step
340
246
return chain
247
+
248
+ def __str__ (self ) -> str :
249
+ return "\n " .join ([str (step .__class__ ) for step in self .steps ])
0 commit comments