Skip to content

Commit 31956da

Browse files
committed
Update guardrail sequence serialization.
1 parent 97d105a commit 31956da

File tree

5 files changed

+212
-166
lines changed

5 files changed

+212
-166
lines changed

ads/llm/chain.py

Lines changed: 26 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -5,45 +5,30 @@
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

77

8-
import importlib
9-
import importlib.util
108
import json
119
import logging
1210
import os
1311
import pathlib
14-
import sys
15-
16-
from copy import deepcopy
17-
from typing import Any, List, Optional, ClassVar
12+
from typing import Any, List, Optional
1813

1914
import yaml
20-
21-
from langchain.chains.loading import load_chain_from_config, type_to_loader_dict
2215
from langchain.llms.base import LLM
2316
from langchain.schema.runnable import (
2417
Runnable,
2518
RunnableConfig,
2619
RunnableSequence,
2720
)
28-
from ads.llm import guardrails
29-
from ads.llm.serialize import load, dump
3021
from ads.llm.guardrails.base import GuardrailIO, Guardrail, RunInfo, BlockedByGuardrail
3122

3223

3324
logger = logging.getLogger(__name__)
34-
SPEC_CLASS = "class"
35-
SPEC_PATH = "path"
36-
SPEC_SPEC = "spec"
3725
SPEC_CHAIN_TYPE = "_type"
3826
SPEC_CHAIN = "chain"
39-
BUILT_IN = "ads."
4027

4128

4229
class GuardrailSequence(RunnableSequence):
4330
"""Represents a sequence of guardrails and other LangChain (non-guardrail) components."""
4431

45-
CHAIN_TYPE: ClassVar[str] = "ads_guardrail_sequence"
46-
4732
first: Optional[Runnable] = None
4833
last: Optional[Runnable] = None
4934

@@ -72,8 +57,14 @@ def steps(self) -> List[Runnable[Any, Any]]:
7257
chain += [self.last]
7358
return chain
7459

60+
@staticmethod
def type() -> str:
    """Return the identifier stored under ``_type`` when this sequence is serialized."""
    return "ads_guardrail_sequence"
64+
7565
@classmethod
def from_sequence(cls, sequence: RunnableSequence):
    """Build an instance of this class from a LangChain ``RunnableSequence``.

    The ``first``, ``middle`` and ``last`` attributes of ``sequence`` are
    passed to the constructor unchanged.
    """
    head = sequence.first
    body = sequence.middle
    tail = sequence.last
    return cls(first=head, middle=body, last=tail)
7869

7970
def __or__(self, other) -> "GuardrailSequence":
@@ -100,7 +91,7 @@ def invoke(self, input: Any, config: RunnableConfig = None) -> GuardrailIO:
10091
"""
10192
return self.run(input)
10293

103-
def _invoke_llm(self, llm, texts, num_generations, **kwargs):
94+
def _invoke_llm(self, llm: LLM, texts: list, num_generations: int, **kwargs):
10495
if num_generations > 1:
10596
if len(texts) > 1:
10697
raise NotImplementedError(
@@ -166,6 +157,7 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
166157
if self.raise_exception:
167158
raise ex
168159
obj.data = [ex.message]
160+
obj.info.append(ex.info)
169161
return obj
170162

171163
def _save_to_file(self, chain_dict, filename, overwrite=False):
@@ -187,7 +179,7 @@ def _save_to_file(self, chain_dict, filename, overwrite=False):
187179
f"{self.__class__.__name__} can only be saved as yaml or json format."
188180
)
189181

190-
def save(self, filename: str = None, overwrite: bool = False):
182+
def save(self, filename: str = None, overwrite: bool = False) -> dict:
191183
"""Serialize the sequence to a dictionary.
192184
Optionally, save the sequence into a JSON or YAML file.
193185
@@ -196,16 +188,12 @@ def save(self, filename: str = None, overwrite: bool = False):
196188
{
197189
"_type": "ads_guardrail_sequence",
198190
"chain": [
199-
{
200-
"class": "...",
201-
"path": "...",
202-
"spec": {
203-
...
204-
}
205-
}
191+
...
206192
]
207193
}
208194
195+
where ``chain`` contains a list of steps.
196+
209197
Parameters
210198
----------
211199
filename : str
@@ -216,22 +204,13 @@ def save(self, filename: str = None, overwrite: bool = False):
216204
dict
217205
The sequence saved as a dictionary.
218206
"""
207+
from ads.llm.serialize import dump
208+
219209
chain_spec = []
220210
for step in self.steps:
221-
class_name = step.__class__.__name__
222-
if step.__module__.startswith(BUILT_IN):
223-
path = getattr(step, "path", None)
224-
else:
225-
path = step.__module__
226-
227-
logger.debug("class: %s | module: %s", class_name, path)
228-
if not hasattr(step, "dict"):
229-
raise NotImplementedError(f"{class_name} is not serializable.")
230-
chain_spec.append(
231-
{SPEC_CLASS: class_name, SPEC_PATH: path, SPEC_SPEC: step.dict()}
232-
)
211+
chain_spec.append(dump(step))
233212
chain_dict = {
234-
SPEC_CHAIN_TYPE: self.CHAIN_TYPE,
213+
SPEC_CHAIN_TYPE: self.type(),
235214
SPEC_CHAIN: chain_spec,
236215
}
237216

@@ -240,83 +219,8 @@ def save(self, filename: str = None, overwrite: bool = False):
240219

241220
return chain_dict
242221

243-
def __str__(self) -> str:
244-
return "\n".join([str(step.__class__) for step in self.steps])
245-
246-
@staticmethod
247-
def _load_class_from_file(module_name, file_path, class_name):
248-
module_spec = importlib.util.spec_from_file_location(module_name, file_path)
249-
module = importlib.util.module_from_spec(module_spec)
250-
sys.modules[module_name] = module
251-
module_spec.loader.exec_module(module)
252-
return getattr(module, class_name)
253-
254-
@staticmethod
255-
def _load_class_from_module(module_name, class_name):
256-
component_module = importlib.import_module(module_name)
257-
return getattr(component_module, class_name)
258-
259-
@staticmethod
260-
def load_step(config: dict):
261-
spec = deepcopy(config.get(SPEC_SPEC, {}))
262-
spec: dict
263-
class_name = config[SPEC_CLASS]
264-
module_name = config.get(SPEC_PATH)
265-
266-
if not module_name and "." in class_name:
267-
# The class name is given as a.b.c.MyClass
268-
module_name, class_name = class_name.rsplit(".", 1)
269-
270-
# Load the step with LangChain loader if it matches the "_type".
271-
# Note that some LangChain objects are saved with the "_type" but there is no matching loader.
272-
if (
273-
str(module_name).startswith("langchain.")
274-
and SPEC_CHAIN_TYPE in spec
275-
and spec[SPEC_CHAIN_TYPE] in type_to_loader_dict
276-
):
277-
return load_chain_from_config(spec)
278-
279-
# Load the guardrail using spec as kwargs
280-
if hasattr(guardrails, class_name):
281-
# Built-in guardrail, including custom huggingface guardrail
282-
component_class = getattr(guardrails, class_name)
283-
# Copy the path into spec if it is not already there
284-
if SPEC_PATH in config and SPEC_PATH not in spec:
285-
spec[SPEC_PATH] = config[SPEC_PATH]
286-
elif SPEC_PATH in config:
287-
# Custom component
288-
# For custom guardrail, the module name could be a file.
289-
if "://" in module_name:
290-
# TODO: Load module from OCI object storage
291-
#
292-
# component_class = GuardrailSequence._load_class_from_file(
293-
# module_name, temp_file, class_name
294-
# )
295-
raise NotImplementedError(
296-
f"Loading module from {module_name} is not supported."
297-
)
298-
elif os.path.exists(module_name):
299-
component_class = GuardrailSequence._load_class_from_file(
300-
module_name, module_name, class_name
301-
)
302-
else:
303-
component_class = GuardrailSequence._load_class_from_module(
304-
module_name, class_name
305-
)
306-
elif "." in class_name:
307-
# The class name is given as a.b.c.MyClass
308-
module_name, class_name = class_name.rsplit(".", 1)
309-
component_class = GuardrailSequence._load_class_from_module(
310-
module_name, class_name
311-
)
312-
else:
313-
raise ValueError(f"Invalid Guardrail: {class_name}")
314-
315-
spec.pop(SPEC_CHAIN_TYPE, None)
316-
return component_class(**spec)
317-
318222
@classmethod
319-
def load(cls, chain_dict: dict) -> "GuardrailSequence":
223+
def load(cls, chain_dict: dict, **kwargs) -> "GuardrailSequence":
320224
"""Loads the sequence from a dictionary config.
321225
322226
Parameters
@@ -331,10 +235,15 @@ def load(cls, chain_dict: dict) -> "GuardrailSequence":
331235
GuardrailSequence
332236
A GuardrailSequence loaded from the config.
333237
"""
238+
from ads.llm.serialize import load
239+
334240
chain_spec = chain_dict[SPEC_CHAIN]
335241
chain = cls()
336242
for config in chain_spec:
337-
guardrail = cls.load_step(config)
338-
# Chain the guardrail
339-
chain |= guardrail
243+
step = load(config, **kwargs)
244+
# Chain the step
245+
chain |= step
340246
return chain
247+
248+
def __str__(self) -> str:
    """Return the class of every step in the sequence, one per line."""
    step_classes = (str(type(step)) if False else str(step.__class__) for step in self.steps)
    return "\n".join(step_classes)

ads/llm/guardrails/base.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import datetime
99
import functools
1010
import operator
11+
import importlib.util
12+
import sys
13+
1114
from typing import Any, List, Dict, Tuple
1215
from langchain.schema.prompt import PromptValue
1316
from langchain.tools.base import BaseTool, ToolException
@@ -234,6 +237,7 @@ def _run(self, query: Any, run_manager=None) -> Any:
234237
"""
235238
if isinstance(query, GuardrailIO):
236239
guardrail_io = query
240+
query = guardrail_io.data
237241
else:
238242
guardrail_io = None
239243
# In this default implementation, we convert all input to list.
@@ -290,9 +294,6 @@ def _run(self, query: Any, run_manager=None) -> Any:
290294
return {"output": output, "metrics": info.metrics}
291295
return output
292296

293-
def load(self) -> None:
294-
"""Loads the models and configs needed for the guardrail."""
295-
296297
def compute(self, data=None, **kwargs) -> dict:
297298
"""Computes the metrics and returns a dictionary."""
298299
return {}
@@ -408,3 +409,45 @@ def single_metric_moderate(self, metrics: dict, data=None, **kwargs) -> List[str
408409
elif self.threshold is not None:
409410
return self.apply_filter(metrics, data)
410411
return data
412+
413+
414+
class CustomGuardrailBase(Guardrail):
    """Base class for custom guardrails.

    Provides ``save()``/``load()`` so that a user-defined guardrail can be
    serialized into a dictionary and re-created from it. The dictionary
    records the module and class name of the guardrail together with the
    keyword arguments (``spec``) used to re-instantiate it.
    """

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """This class is not LangChain serializable."""
        return False

    @staticmethod
    def load_class_from_file(uri: str, class_name: str):
        """Loads a Python class from a file.

        Parameters
        ----------
        uri : str
            Path of the Python source file containing the class.
            The same string is used as the module name in ``sys.modules``.
        class_name : str
            Name of the class to be fetched from the loaded module.

        Returns
        -------
        type
            The attribute named ``class_name`` of the loaded module.

        Raises
        ------
        ImportError
            If no import spec/loader can be created for ``uri``.
        AttributeError
            If the module has no attribute named ``class_name``.
        """
        # TODO: Support loading from OCI object storage
        module_name = uri
        module_spec = importlib.util.spec_from_file_location(module_name, uri)
        # ``spec_from_file_location`` gives back None when it cannot determine
        # a loader for the location (e.g. not a Python source file).
        # Fail with a clear message instead of an AttributeError on None.
        if module_spec is None or module_spec.loader is None:
            raise ImportError(
                f"Unable to load class {class_name}: "
                f"cannot create an import spec from {uri}."
            )
        module = importlib.util.module_from_spec(module_spec)
        # Register the module before executing it so that code relying on
        # ``sys.modules`` during execution can find it.
        sys.modules[module_name] = module
        module_spec.loader.exec_module(module)
        return getattr(module, class_name)

    @staticmethod
    def type() -> str:
        """A unique string as identifier to the type of the object for serialization."""
        return "ads_custom_guardrail"

    @staticmethod
    def load(config, **kwargs):
        """Loads the object from serialized config.

        Parameters
        ----------
        config : dict
            Serialized guardrail with the keys ``module``, ``class`` and ``spec``,
            as produced by ``save()``. ``module`` may be either the path of an
            existing Python file or an importable (dotted) module name.
        kwargs :
            Accepted for interface compatibility; not used by this loader.

        Returns
        -------
        An instance of the class named by ``config["class"]``,
        created with ``config["spec"]`` as keyword arguments.
        """
        # Local import: only this method needs ``os``.
        import os

        module = config["module"]
        class_name = config["class"]
        if os.path.isfile(module):
            # The module is given as a path to a Python source file.
            guardrail_class = CustomGuardrailBase.load_class_from_file(
                module, class_name
            )
        else:
            # ``save()`` stores ``self.__module__``, which is a module name
            # rather than a file path, so it must be imported by name.
            guardrail_class = getattr(importlib.import_module(module), class_name)
        return guardrail_class(**config["spec"])

    def save(self) -> dict:
        """Serialize the object into a dictionary.

        Returns
        -------
        dict
            A dictionary containing the type identifier (``_type``), the module
            and class name of this object, and ``self.dict()`` as ``spec``.
        """
        return {
            "_type": self.type(),
            "module": self.__module__,
            "class": self.__class__.__name__,
            "spec": self.dict(),
        }

ads/llm/patch.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from langchain.schema.runnable import RunnableParallel
2+
from langchain.load import dumpd, load
3+
4+
5+
class RunnableParallelSerializer:
    """Serialization helpers for LangChain ``RunnableParallel`` objects."""

    @staticmethod
    def type():
        """Return the identifier stored under ``_type``: the ``RunnableParallel`` class name."""
        return RunnableParallel.__name__

    @staticmethod
    def load(config: dict, **kwargs):
        """Re-create a ``RunnableParallel`` from its serialized ``config``.

        Each entry of ``config["kwargs"]["steps"]`` is loaded with
        ``langchain.load.load`` (forwarding ``kwargs``) and passed to
        ``RunnableParallel`` as a keyword argument under its original key.
        """
        loaded_steps = {}
        for name, step_config in config["kwargs"]["steps"].items():
            loaded_steps[name] = load(step_config, **kwargs)
        return RunnableParallel(**loaded_steps)

    @staticmethod
    def save(obj):
        """Serialize ``obj`` with ``dumpd`` and tag the result with ``_type``."""
        serialized = dumpd(obj)
        serialized.update({"_type": RunnableParallelSerializer.type()})
        return serialized

0 commit comments

Comments
 (0)