|
| 1 | +#!/usr/bin/env python |
# -*- coding: utf-8 -*-
| 3 | + |
| 4 | +# Copyright (c) 2023 Oracle and/or its affiliates. |
| 5 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 6 | + |
| 7 | + |
| 8 | +import json |
| 9 | +import logging |
| 10 | +import os |
| 11 | +import pathlib |
| 12 | +from typing import Any, List, Optional |
| 13 | + |
| 14 | +import yaml |
| 15 | +from langchain.llms.base import LLM |
| 16 | +from langchain.schema.runnable import ( |
| 17 | + Runnable, |
| 18 | + RunnableConfig, |
| 19 | + RunnableSequence, |
| 20 | +) |
| 21 | +from ads.llm.guardrails.base import GuardrailIO, Guardrail, RunInfo, BlockedByGuardrail |
| 22 | + |
| 23 | + |
# Module-level logger, named after this module per the logging convention.
logger = logging.getLogger(__name__)
# Dictionary key under which the serialized sequence stores its type identifier.
SPEC_CHAIN_TYPE = "_type"
# Dictionary key under which the serialized sequence stores its list of steps.
SPEC_CHAIN = "chain"
# Environment variable name; when set to "1", run info is printed after each invocation.
LOG_ADS_GUARDRAIL_INFO = "LOG_ADS_GUARDRAIL_INFO"
| 28 | + |
| 29 | + |
class GuardrailSequence(RunnableSequence):
    """Represents a sequence of guardrails and other LangChain (non-guardrail) components.

    Steps are stored in the inherited ``first``/``middle``/``last`` slots of
    ``RunnableSequence``; guardrail steps are invoked with a ``GuardrailIO``
    object, while non-guardrail steps are invoked through the standard
    LangChain ``batch()`` interface.
    """

    first: Optional[Runnable] = None
    last: Optional[Runnable] = None

    raise_exception: bool = False
    """The ``raise_exception`` property indicate whether an exception should be raised
    if the content is blocked by one of the guardrails.
    This property is set to ``False`` by default.
    Note that each guardrail also has its own ``raise_exception`` property.
    This property on GuardrailSequence has no effect
    when the ``raise_exception`` is set to False on the individual guardrail.

    When this is ``False``, instead of raising an exception,
    the custom message from the guardrail will be returned as the output.

    When this is ``True``, the ``BlockedByGuardrail`` exception from the guardrail will be raised.
    """

    log_info: bool = False
    """Indicate whether to print the run info at the end of each invocation.
    This option can also be turned on if the environment variable LOG_ADS_GUARDRAIL_INFO is set to "1".
    """

    max_retry: int = 1
    """Maximum number of retry for running the Guardrail sequence again if the output is blocked by a guardrail."""

    @property
    def steps(self) -> List[Runnable[Any, Any]]:
        """Steps in the sequence, flattened from ``first``/``middle``/``last``.

        Returns an empty list when the sequence has no first step.
        """
        # An empty sequence is represented by ``first`` being unset.
        if not self.first:
            return []
        chain = [self.first] + self.middle
        if self.last:
            chain += [self.last]
        return chain

    @staticmethod
    def type() -> str:
        """A unique identifier as type for serialization."""
        return "ads_guardrail_sequence"

    @classmethod
    def from_sequence(cls, sequence: RunnableSequence, **kwargs):
        """Creates a GuardrailSequence from a LangChain runnable sequence."""
        return cls(
            first=sequence.first, middle=sequence.middle, last=sequence.last, **kwargs
        )

    def __or__(self, other) -> "GuardrailSequence":
        """Adds another component to the end of this sequence.

        If the sequence is empty, the component will be added as the first step of the sequence.
        """
        if not self.first:
            return GuardrailSequence(first=other)
        if not self.last:
            return GuardrailSequence(first=self.first, last=other)
        return self.from_sequence(super().__or__(other))

    def __ror__(self, other) -> "GuardrailSequence":
        """Chain this sequence to the end of another component."""
        return self.from_sequence(super().__ror__(other))

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> GuardrailIO:
        """Invokes the guardrail.

        In LangChain interface, invoke() is designed for calling the chain with a single input,
        while batch() is designed for calling the chain with a list of inputs.
        https://python.langchain.com/docs/expression_language/interface

        Note: ``config`` is accepted for interface compatibility but is not
        used by the guardrail sequence.
        """
        return self.run(input)

    def _invoke_llm(self, llm: LLM, texts: list, num_generations: int, **kwargs):
        """Invokes an LLM step, generating ``num_generations`` completions per prompt.

        Raises
        ------
        NotImplementedError
            If ``num_generations > 1`` is requested for more than one prompt.
        """
        if num_generations > 1:
            if len(texts) > 1:
                raise NotImplementedError(
                    "Batch completion with more than 1 prompt is not supported."
                )
            # TODO: invoke in parallel
            # TODO: let llm generate n completions.
            output = [llm.invoke(texts[0], **kwargs) for _ in range(num_generations)]
        else:
            output = llm.batch(texts, **kwargs)
        return output

    def _run_step(
        self, step: Runnable, obj: GuardrailIO, num_generations: int, **kwargs
    ):
        """Runs a single step and records its RunInfo onto ``obj``.

        Guardrail steps consume and return the GuardrailIO directly; other
        LangChain components are invoked via ``batch()`` (or the LLM helper)
        and their output replaces ``obj.data``.
        """
        if isinstance(step, Guardrail):
            # Guardrails understand the GuardrailIO protocol natively.
            return step.invoke(obj)
        # Invoke the step as a LangChain component
        spec = {}
        with RunInfo(name=step.__class__.__name__, input=obj.data) as info:
            if isinstance(step, LLM):
                output = self._invoke_llm(step, obj.data, num_generations, **kwargs)
                # Record the LLM call parameters for observability.
                spec.update(kwargs)
                spec["num_generations"] = num_generations
            else:
                output = step.batch(obj.data)
            info.output = output
            info.parameters = {
                "class": step.__class__.__name__,
                "path": step.__module__,
                "spec": spec,
            }
        obj.info.append(info)
        obj.data = output
        return obj

    def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
        """Runs the guardrail sequence.

        Parameters
        ----------
        input : Any
            Input for the guardrail sequence.
            This will be the input for the first step in the sequence.
        num_generations : int, optional
            The number of completions to be generated by the LLM, by default 1.

        The kwargs will be passed to LLM step(s) in the guardrail sequence.

        Returns
        -------
        GuardrailIO
            Contains the outputs and metrics from each step.
            The final output is stored in GuardrailIO.data property.
        """
        retry_count = 0
        while True:
            retry_count += 1
            # Data flows through the sequence as a list (batch of size 1).
            obj = GuardrailIO(data=[input])
            try:
                for step in self.steps:
                    obj = self._run_step(step, obj, num_generations, **kwargs)
                break
            except BlockedByGuardrail as ex:
                # Retry the whole sequence until max_retry is exhausted.
                if retry_count < self.max_retry:
                    continue
                if self.raise_exception:
                    # Bare raise preserves the original traceback.
                    raise
                # Surface the guardrail's custom message as the output instead.
                obj.data = [ex.message]
                obj.info.append(ex.info)
                break
        if self.log_info or os.environ.get(LOG_ADS_GUARDRAIL_INFO) == "1":
            # LOG_ADS_GUARDRAIL_INFO is set to "1" in score.py by default.
            print(obj.dict())
        # If the output is a singleton list, take it out of the list.
        if isinstance(obj.data, list) and len(obj.data) == 1:
            obj.data = obj.data[0]
        return obj

    def _save_to_file(self, chain_dict, filename, overwrite=False):
        """Writes the serialized sequence to a YAML or JSON file.

        Raises
        ------
        FileExistsError
            If the file exists and ``overwrite`` is False.
        ValueError
            If the filename extension is not ``.yaml`` or ``.json``.
        """
        expanded_path = os.path.expanduser(filename)
        if os.path.isfile(expanded_path) and not overwrite:
            raise FileExistsError(
                f"File {expanded_path} already exists. "
                "Set overwrite to True if you would like to overwrite the file."
            )

        file_ext = pathlib.Path(expanded_path).suffix.lower()
        if file_ext not in [".yaml", ".json"]:
            raise ValueError(
                f"{self.__class__.__name__} can only be saved as yaml or json format."
            )
        with open(expanded_path, "w", encoding="utf-8") as f:
            if file_ext == ".yaml":
                yaml.safe_dump(chain_dict, f, default_flow_style=False)
            elif file_ext == ".json":
                json.dump(chain_dict, f)

    def save(self, filename: str = None, overwrite: bool = False) -> dict:
        """Serialize the sequence to a dictionary.
        Optionally, save the sequence into a JSON or YAML file.

        The dictionary will look like the following::

            {
                "_type": "ads_guardrail_sequence",
                "chain": [
                    ...
                ]
            }

        where ``chain`` contains a list of steps.

        Parameters
        ----------
        filename : str
            YAML or JSON filename to store the serialized sequence.
        overwrite : bool
            Whether to overwrite an existing file, by default False.

        Returns
        -------
        dict
            The sequence saved as a dictionary.
        """
        # Imported locally to avoid a circular import at module load time.
        from ads.llm.serialize import dump

        chain_spec = [dump(step) for step in self.steps]
        chain_dict = {
            SPEC_CHAIN_TYPE: self.type(),
            SPEC_CHAIN: chain_spec,
        }

        if filename:
            self._save_to_file(chain_dict, filename, overwrite)

        return chain_dict

    @classmethod
    def load(cls, chain_dict: dict, **kwargs) -> "GuardrailSequence":
        """Loads the sequence from a dictionary config.

        Parameters
        ----------
        chain_dict : dict
            A dictionary containing the key "chain".
            The value of "chain" should be a list of dictionary.
            Each dictionary corresponds to a step in the chain.

        Returns
        -------
        GuardrailSequence
            A GuardrailSequence loaded from the config.
        """
        # Imported locally to avoid a circular import at module load time.
        from ads.llm.serialize import load

        chain_spec = chain_dict[SPEC_CHAIN]
        chain = cls()
        for config in chain_spec:
            step = load(config, **kwargs)
            # Chain the step
            chain |= step
        return chain

    def __str__(self) -> str:
        return "\n".join([str(step.__class__) for step in self.steps])