Skip to content

Commit 0bb8974

Browse files
committed
Merge branch 'feature/guardrails' of https://github.com/oracle/accelerated-data-science into add_langchain_deployment
2 parents 2a591f7 + 8b2c166 commit 0bb8974

File tree

5 files changed

+65
-21
lines changed

5 files changed

+65
-21
lines changed

ads/llm/chain.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def run(self, input: Any, num_generations: int = 1, **kwargs) -> GuardrailIO:
178178
if self.log_info or os.environ.get(LOG_ADS_GUARDRAIL_INFO) == "1":
179179
# LOG_ADS_GUARDRAIL_INFO is set to "1" in score.py by default.
180180
print(obj.dict())
181+
# If the output is a singleton list, take it out of the list.
182+
if isinstance(obj.data, list) and len(obj.data) == 1:
183+
obj.data = obj.data[0]
181184
return obj
182185

183186
def _save_to_file(self, chain_dict, filename, overwrite=False):

ads/llm/guardrails/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def __repr__(self) -> str:
7878
steps.append(f"{run_info.name} - {run_info.metrics}")
7979
if run_info:
8080
steps.append(str(run_info.output))
81-
return "\n".join(steps)
81+
return "\n".join(steps) + "\n\n" + str(self)
8282

8383

8484
class BlockedByGuardrail(ToolException):

ads/llm/langchain/plugins/base.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _print_response(self, completion, response):
4646

4747
@classmethod
4848
def get_lc_namespace(cls) -> List[str]:
49-
"""Get the namespace of the langchain object."""
49+
"""Get the namespace of the LangChain object."""
5050
return ["ads", "llm"]
5151

5252
@classmethod
@@ -56,10 +56,12 @@ def is_lc_serializable(cls) -> bool:
5656

5757

5858
class GenerativeAiClientModel(BaseModel):
59+
"""Base model for generative AI embedding model and LLM."""
60+
5961
client: Any #: :meta private:
6062
"""OCI GenerativeAiClient."""
6163

62-
compartment_id: str
64+
compartment_id: str = None
6365
"""Compartment ID of the caller."""
6466

6567
endpoint_kwargs: Dict[str, Any] = {}
@@ -90,7 +92,9 @@ def validate_environment( # pylint: disable=no-self-argument
9092
client_kwargs.update(values["client_kwargs"])
9193
values["client"] = GenerativeAiClient(**auth, **client_kwargs)
9294
# Set default compartment ID
93-
if "compartment_id" not in values and COMPARTMENT_OCID:
94-
values["compartment_id"] = COMPARTMENT_OCID
95-
95+
if not values.get("compartment_id"):
96+
if COMPARTMENT_OCID:
97+
values["compartment_id"] = COMPARTMENT_OCID
98+
else:
99+
raise ValueError("Please specify compartment_id.")
96100
return values

ads/llm/serialize.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from langchain import llms
1515
from langchain.llms import loading
1616
from langchain.chains.loading import load_chain_from_config
17-
from langchain.load.load import load as __lc_load
17+
from langchain.load.load import Reviver, load as __lc_load
1818
from langchain.load.serializable import Serializable
1919

2020
from ads.common.auth import default_signer
@@ -76,14 +76,32 @@ def load(
7676
Returns:
7777
Revived LangChain objects.
7878
"""
79+
# Add ADS as valid namespace
7980
if not valid_namespaces:
8081
valid_namespaces = []
8182
if "ads" not in valid_namespaces:
8283
valid_namespaces.append("ads")
8384

85+
reviver = Reviver(secrets_map, valid_namespaces)
86+
87+
def _load(obj: Any) -> Any:
88+
if isinstance(obj, dict):
89+
if "_type" in obj and obj["_type"] in custom_deserialization:
90+
if valid_namespaces:
91+
kwargs["valid_namespaces"] = valid_namespaces
92+
if secrets_map:
93+
kwargs["secret_map"] = secrets_map
94+
return custom_deserialization[obj["_type"]](obj, **kwargs)
95+
# Need to revive leaf nodes before reviving this node
96+
loaded_obj = {k: _load(v) for k, v in obj.items()}
97+
return reviver(loaded_obj)
98+
if isinstance(obj, list):
99+
return [_load(o) for o in obj]
100+
return obj
101+
84102
if isinstance(obj, dict) and "_type" in obj:
85103
obj_type = obj["_type"]
86-
# Check if the object requires a custom function to load.
104+
# Check if the object has custom load function.
87105
if obj_type in custom_deserialization:
88106
if valid_namespaces:
89107
kwargs["valid_namespaces"] = valid_namespaces
@@ -93,7 +111,7 @@ def load(
93111
# Legacy chain
94112
return load_chain_from_config(obj, **kwargs)
95113

96-
return __lc_load(obj, secrets_map=secrets_map, valid_namespaces=valid_namespaces)
114+
return _load(obj)
97115

98116

99117
def load_from_yaml(
@@ -144,11 +162,30 @@ def default(obj: Any) -> Any:
144162
TypeError
145163
If the object is not LangChain serializable.
146164
"""
165+
for super_class, save_fn in custom_serialization.items():
166+
if isinstance(obj, super_class):
167+
return save_fn(obj)
147168
if isinstance(obj, Serializable) and obj.is_lc_serializable():
148169
return obj.to_json()
149170
raise TypeError(f"Serialization of {type(obj)} is not supported.")
150171

151172

173+
def __save(obj):
174+
"""Calls the legacy save method to save the object to temp json
175+
then load it into a dictionary.
176+
"""
177+
try:
178+
temp_file = tempfile.NamedTemporaryFile(
179+
mode="w", encoding="utf-8", suffix=".json", delete=False
180+
)
181+
temp_file.close()
182+
obj.save(temp_file.name)
183+
with open(temp_file.name, "r", encoding="utf-8") as f:
184+
return json.load(f)
185+
finally:
186+
os.unlink(temp_file.name)
187+
188+
152189
def dump(obj: Any) -> Dict[str, Any]:
153190
"""Return a json dict representation of an object.
154191
@@ -167,14 +204,14 @@ def dump(obj: Any) -> Dict[str, Any]:
167204
):
168205
# The object is not is_lc_serializable.
169206
# However, it supports the legacy save() method.
170-
try:
171-
temp_file = tempfile.NamedTemporaryFile(
172-
mode="w", encoding="utf-8", suffix=".json", delete=False
173-
)
174-
temp_file.close()
175-
obj.save(temp_file.name)
176-
with open(temp_file.name, "r", encoding="utf-8") as f:
177-
return json.load(f)
178-
finally:
179-
os.unlink(temp_file.name)
180-
return json.loads(json.dumps(obj, default=default))
207+
return __save(obj)
208+
# The object is is_lc_serializable.
209+
# However, some properties may not be serializable
210+
# Here we try to dump the object and fallback to the save() method
211+
# if there is an error.
212+
try:
213+
return json.loads(json.dumps(obj, default=default))
214+
except TypeError as ex:
215+
if isinstance(obj, Serializable) and hasattr(obj, "save"):
216+
return __save(obj)
217+
raise ex

ads/llm/templates/score_chain.jinja2

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,5 +165,5 @@ def predict(data, model=load_model(), input_schema_path=os.path.join(os.path.dir
165165

166166
"""
167167
features = pre_inference(data, input_schema_path)
168-
yhat = post_inference(model.run(features))
168+
yhat = post_inference(model.invoke(features))
169169
return {'prediction': yhat}

0 commit comments

Comments
 (0)