Skip to content

Commit cb91bfc

Browse files
authored
ADS LLM Integration (#473)
2 parents d0edb9f + fa938cd commit cb91bfc

32 files changed

+3143
-1
lines changed

ads/llm/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
# Expose the LangChain plugins at the package level.
# ``langchain`` is imported first so that a missing installation is reported
# with an actionable hint instead of a bare ImportError from a plugin module.
try:
    import langchain
    from ads.llm.langchain.plugins.llm_gen_ai import GenerativeAI
    from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
    from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
    from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
except ImportError as ex:
    # ``ex.name`` is the name of the module that could not be imported.
    # It can be ``None`` or a langchain submodule (e.g. ``langchain.schema.runnable``
    # when the installed langchain is outdated), so compare the top-level package
    # rather than requiring an exact match on "langchain".
    if (ex.name or "").split(".")[0] == "langchain":
        raise ImportError(
            f"{ex.msg}\nPlease install/update langchain with `pip install langchain -U`"
        ) from ex
    # Not related to langchain: propagate the original error unchanged.
    raise

ads/llm/chain.py

Lines changed: 272 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import json
9+
import logging
10+
import os
11+
import pathlib
12+
from typing import Any, List, Optional
13+
14+
import yaml
15+
from langchain.llms.base import LLM
16+
from langchain.schema.runnable import (
17+
Runnable,
18+
RunnableConfig,
19+
RunnableSequence,
20+
)
21+
from ads.llm.guardrails.base import GuardrailIO, Guardrail, RunInfo, BlockedByGuardrail
22+
23+
24+
# Module-level logger for this file.
logger = logging.getLogger(__name__)
25+
# Key under which a serialized sequence stores its type identifier.
SPEC_CHAIN_TYPE = "_type"
26+
# Key under which a serialized sequence stores its list of serialized steps.
SPEC_CHAIN = "chain"
27+
# Name of the environment variable which, when set to "1", turns on printing of the run info.
LOG_ADS_GUARDRAIL_INFO = "LOG_ADS_GUARDRAIL_INFO"
28+
29+
30+
class GuardrailSequence(RunnableSequence):
    """Represents a sequence of guardrails and other LangChain (non-guardrail) components."""

    first: Optional[Runnable] = None
    last: Optional[Runnable] = None

    raise_exception: bool = False
    """The ``raise_exception`` property indicates whether an exception should be raised
    if the content is blocked by one of the guardrails.
    This property is set to ``False`` by default.
    Note that each guardrail also has its own ``raise_exception`` property.
    This property on GuardrailSequence has no effect
    when the ``raise_exception`` is set to False on the individual guardrail.

    When this is ``False``, instead of raising an exception,
    the custom message from the guardrail will be returned as the output.

    When this is ``True``, the ``BlockedByGuardrail`` exception from the guardrail will be raised.
    """

    log_info: bool = False
    """Indicate whether to print the run info at the end of each invocation.
    This option can also be turned on if the environment variable LOG_ADS_GUARDRAIL_INFO is set to "1".
    """

    max_retry: int = 1
    """Maximum number of times the sequence is run when the output is blocked by a guardrail.
    The first run counts as one attempt, so the default value of 1 means no retry.
    """

    @property
    def steps(self) -> List[Runnable[Any, Any]]:
        """Steps in the sequence, in execution order.

        An empty list is returned when the sequence has no first step.
        """
        if not self.first:
            return []
        chain = [self.first, *self.middle]
        if self.last:
            chain.append(self.last)
        return chain

    @staticmethod
    def type() -> str:
        """A unique identifier as type for serialization."""
        return "ads_guardrail_sequence"

    @classmethod
    def from_sequence(cls, sequence: RunnableSequence, **kwargs):
        """Creates a GuardrailSequence from a LangChain runnable sequence.

        Parameters
        ----------
        sequence : RunnableSequence
            The sequence providing the ``first``, ``middle`` and ``last`` steps.

        The kwargs are passed to the constructor along with the steps.
        """
        return cls(
            first=sequence.first, middle=sequence.middle, last=sequence.last, **kwargs
        )

    def _sequence_settings(self) -> dict:
        """Returns the sequence-level options to be carried over to a derived sequence."""
        return {
            "raise_exception": self.raise_exception,
            "log_info": self.log_info,
            "max_retry": self.max_retry,
        }

    def __or__(self, other) -> "GuardrailSequence":
        """Adds another component to the end of this sequence.
        If the sequence is empty, the component will be added as the first step of the sequence.

        The ``raise_exception``, ``log_info`` and ``max_retry`` options of this sequence
        are carried over to the returned sequence.
        """
        settings = self._sequence_settings()
        if not self.first:
            return self.__class__(first=other, **settings)
        if not self.last:
            # Keep any existing middle steps instead of silently dropping them.
            return self.__class__(
                first=self.first, middle=list(self.middle), last=other, **settings
            )
        return self.from_sequence(super().__or__(other), **settings)

    def __ror__(self, other) -> "GuardrailSequence":
        """Chain this sequence to the end of another component.

        The ``raise_exception``, ``log_info`` and ``max_retry`` options of this sequence
        are carried over to the returned sequence.
        """
        return self.from_sequence(
            super().__ror__(other), **self._sequence_settings()
        )

    def invoke(self, input: Any, config: Optional[RunnableConfig] = None) -> GuardrailIO:
        """Invokes the guardrail.

        In LangChain interface, invoke() is designed for calling the chain with a single input,
        while batch() is designed for calling the chain with a list of inputs.
        https://python.langchain.com/docs/expression_language/interface

        Note that ``config`` is accepted for interface compatibility only;
        it is not passed on to ``run()``.
        """
        return self.run(input)

    def _invoke_llm(self, llm: LLM, texts: list, num_generations: int, **kwargs):
        """Calls the LLM with the prompts and returns a list of completions.

        With a single generation, all prompts are sent through ``llm.batch()``.
        With more than one generation, only a single prompt is supported and
        the LLM is invoked once per generation.

        Raises
        ------
        NotImplementedError
            If more than one generation is requested for more than one prompt.
        """
        if num_generations <= 1:
            return llm.batch(texts, **kwargs)
        if len(texts) > 1:
            raise NotImplementedError(
                "Batch completion with more than 1 prompt is not supported."
            )
        # TODO: invoke in parallel
        # TODO: let llm generate n completions.
        return [llm.invoke(texts[0], **kwargs) for _ in range(num_generations)]

    def _run_step(
        self, step: Runnable, obj: GuardrailIO, num_generations: int, **kwargs
    ):
        """Runs a single step and returns the GuardrailIO holding its output.

        Guardrails are invoked with the GuardrailIO object directly.
        Any other step is invoked as a LangChain component on ``obj.data``,
        and its run info is appended to ``obj.info``.
        The kwargs and ``num_generations`` are only used for LLM steps.
        """
        if isinstance(step, Guardrail):
            return step.invoke(obj)

        # Invoke the step as a LangChain component
        spec = {}
        with RunInfo(name=step.__class__.__name__, input=obj.data) as info:
            if isinstance(step, LLM):
                output = self._invoke_llm(step, obj.data, num_generations, **kwargs)
                # Record the parameters used for calling the LLM.
                spec.update(kwargs)
                spec["num_generations"] = num_generations
            else:
                output = step.batch(obj.data)
            info.output = output
            info.parameters = {
                "class": step.__class__.__name__,
                "path": step.__module__,
                "spec": spec,
            }
        obj.info.append(info)
        obj.data = output
        return obj

    def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
        """Runs the guardrail sequence.

        Parameters
        ----------
        input : Any
            Input for the guardrail sequence.
            This will be the input for the first step in the sequence.
        num_generations : int, optional
            The number of completions to be generated by the LLM, by default 1.

        The kwargs will be passed to LLM step(s) in the guardrail sequence.

        Returns
        -------
        GuardrailIO
            Contains the outputs and metrics from each step.
            The final output is stored in GuardrailIO.data property.

        Raises
        ------
        BlockedByGuardrail
            If the content is blocked on the last attempt and ``raise_exception`` is ``True``.
        """
        steps = self.steps
        retry_count = 0
        while True:
            retry_count += 1
            # Each attempt starts over from the original input.
            obj = GuardrailIO(data=[input])
            try:
                for step in steps:
                    obj = self._run_step(step, obj, num_generations, **kwargs)
                break
            except BlockedByGuardrail as ex:
                if retry_count < self.max_retry:
                    continue
                if self.raise_exception:
                    raise
                # Return the custom message from the guardrail as the output.
                obj.data = [ex.message]
                obj.info.append(ex.info)
                break
        if self.log_info or os.environ.get(LOG_ADS_GUARDRAIL_INFO) == "1":
            # LOG_ADS_GUARDRAIL_INFO is set to "1" in score.py by default.
            print(obj.dict())
        # If the output is a singleton list, take it out of the list.
        if isinstance(obj.data, list) and len(obj.data) == 1:
            obj.data = obj.data[0]
        return obj

    def _save_to_file(self, chain_dict, filename, overwrite=False):
        """Saves the serialized sequence into a YAML or JSON file.

        The format is determined by the file extension (``.yaml`` or ``.json``).

        Raises
        ------
        FileExistsError
            If the file exists and ``overwrite`` is not set.
        ValueError
            If the file extension is neither ``.yaml`` nor ``.json``.
        """
        expanded_path = os.path.expanduser(filename)
        if os.path.isfile(expanded_path) and not overwrite:
            raise FileExistsError(
                f"File {expanded_path} already exists. "
                "Set overwrite to True if you would like to overwrite the file."
            )

        file_ext = pathlib.Path(expanded_path).suffix.lower()
        if file_ext not in [".yaml", ".json"]:
            raise ValueError(
                f"{self.__class__.__name__} can only be saved as yaml or json format."
            )
        with open(expanded_path, "w", encoding="utf-8") as f:
            if file_ext == ".yaml":
                yaml.safe_dump(chain_dict, f, default_flow_style=False)
            else:
                json.dump(chain_dict, f)

    def save(self, filename: str = None, overwrite: bool = False) -> dict:
        """Serialize the sequence to a dictionary.
        Optionally, save the sequence into a JSON or YAML file.

        The dictionary will look like the following::

            {
                "_type": "ads_guardrail_sequence",
                "chain": [
                    ...
                ]
            }

        where ``chain`` contains a list of steps.

        Parameters
        ----------
        filename : str
            YAML or JSON filename to store the serialized sequence.
        overwrite : bool
            Whether to overwrite the file if it already exists, by default False.

        Returns
        -------
        dict
            The sequence saved as a dictionary.
        """
        from ads.llm.serialize import dump

        chain_dict = {
            SPEC_CHAIN_TYPE: self.type(),
            SPEC_CHAIN: [dump(step) for step in self.steps],
        }

        if filename:
            self._save_to_file(chain_dict, filename, overwrite)

        return chain_dict

    @classmethod
    def load(cls, chain_dict: dict, **kwargs) -> "GuardrailSequence":
        """Loads the sequence from a dictionary config.

        Parameters
        ----------
        chain_dict : dict
            A dictionary containing the key "chain".
            The value of "chain" should be a list of dictionary.
            Each dictionary corresponds to a step in the chain.

        The kwargs are passed to ``ads.llm.serialize.load`` for loading each step.

        Returns
        -------
        GuardrailSequence
            A GuardrailSequence loaded from the config.
        """
        from ads.llm.serialize import load

        chain = cls()
        for config in chain_dict[SPEC_CHAIN]:
            # Chain the step
            chain |= load(config, **kwargs)
        return chain

    def __str__(self) -> str:
        """Returns the class of each step, one per line."""
        return "\n".join(str(step.__class__) for step in self.steps)

ads/llm/deploy.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
6+
7+
8+
import os
9+
import tempfile
10+
from datetime import datetime
11+
from typing import Any
12+
13+
import yaml
14+
from jinja2 import Environment, PackageLoader
15+
16+
from ads.model.artifact import ADS_VERSION, SCORE_VERSION
17+
from ads.model.generic_model import GenericModel
18+
from ads.llm.serialize import dump, load_from_yaml
19+
20+
21+
class ChainDeployment(GenericModel):
    """A ``GenericModel`` whose artifact is built from a chain.

    When the artifact is prepared, the chain is serialized into ``chain.yaml``
    in the artifact directory and, unless ``score_py_uri`` is specified,
    a ``score.py`` is rendered from the ``score_chain.jinja2`` template.
    """

    def __init__(self, chain, **kwargs):
        """Initializes the deployment with the chain.

        Parameters
        ----------
        chain
            The chain to be serialized with ``ads.llm.serialize.dump()``.

        The kwargs are passed to the ``GenericModel`` constructor.
        """
        self.chain = chain
        super().__init__(**kwargs)

    def prepare(self, **kwargs) -> GenericModel:
        """Prepares the model artifact.

        The kwargs are passed to ``GenericModel.prepare()``.
        If ``score_py_uri`` is not in the kwargs, a temporary ``score.py`` is rendered
        from the template and used as ``score_py_uri``.
        The temporary file is removed before this method returns.
        """
        # Serialize the chain before opening the file,
        # so that a serialization error does not leave an empty chain.yaml behind.
        chain_yaml = yaml.safe_dump(dump(self.chain))
        chain_yaml_uri = os.path.join(self.artifact_dir, "chain.yaml")
        with open(chain_yaml_uri, "w", encoding="utf-8") as f:
            f.write(chain_yaml)

        score_py = None
        try:
            if "score_py_uri" not in kwargs:
                # delete=False keeps the file on disk after it is closed,
                # so that it can be read when preparing the artifact.
                with tempfile.NamedTemporaryFile(
                    mode="w", suffix="score.py", delete=False
                ) as score_py:
                    env = Environment(loader=PackageLoader("ads", "llm/templates"))
                    score_template = env.get_template("score_chain.jinja2")
                    time_suffix = datetime.today().strftime("%Y%m%d_%H%M%S")

                    context = {
                        "SCORE_VERSION": SCORE_VERSION,
                        "ADS_VERSION": ADS_VERSION,
                        "time_created": time_suffix,
                    }
                    score_py.write(score_template.render(context))

                kwargs["score_py_uri"] = score_py.name
            return super().prepare(**kwargs)
        finally:
            # Remove the temporary score.py, including when rendering or prepare() fails.
            if score_py is not None:
                os.unlink(score_py.name)

    @classmethod
    def load_chain(cls, yaml_uri: str, **kwargs) -> Any:
        """Loads a chain from a YAML file using ``ads.llm.serialize.load_from_yaml()``.

        Parameters
        ----------
        yaml_uri : str
            Location of the YAML file.

        The kwargs are passed to ``load_from_yaml()``.
        """
        return load_from_yaml(yaml_uri, **kwargs)

ads/llm/guardrails/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*--
3+
4+
# Copyright (c) 2023 Oracle and/or its affiliates.
5+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

0 commit comments

Comments
 (0)