Skip to content

Commit 9e62306

Browse files
committed
Update serialization.
1 parent 1290a4a commit 9e62306

File tree

3 files changed

+46
-11
lines changed

3 files changed

+46
-11
lines changed

ads/llm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
from ads.llm.langchain.plugins.llm_md import ModelDeploymentTGI
1010
from ads.llm.langchain.plugins.llm_md import ModelDeploymentVLLM
1111
from ads.llm.langchain.plugins.embeddings import GenerativeAIEmbeddings
12-
from ads.llm.load import load
12+
from ads.llm.serialize import load

ads/llm/load.py renamed to ads/llm/serialize.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import json
78
from typing import Any, Dict, List, Optional
89
from langchain.load.load import load as lc_load
10+
from langchain.load.serializable import Serializable
911

1012

1113
def load(
@@ -14,10 +16,10 @@ def load(
1416
secrets_map: Optional[Dict[str, str]] = None,
1517
valid_namespaces: Optional[List[str]] = None,
1618
) -> Any:
17-
"""Revive a LangChain class from a JSON object. Use this if you already
19+
"""Revive an ADS/LangChain class from a JSON object. Use this if you already
1820
have a parsed JSON object, eg. from `json.load` or `json.loads`.
1921
20-
This is a drop in replacement for load() in langchain.load.load to support loading ADS LangChain compatible component.
22+
This is a drop in replacement for load() in langchain.load.load to support loading compatible class from ADS.
2123
2224
Args:
2325
obj: The object to load.
@@ -33,3 +35,35 @@ def load(
3335
if "ads" not in valid_namespaces:
3436
valid_namespaces.append("ads")
3537
return lc_load(obj, secrets_map=secrets_map, valid_namespaces=valid_namespaces)
38+
39+
40+
def default(obj: Any) -> Any:
41+
"""Calls the to_json() method to serialize the object.
42+
43+
Parameters
44+
----------
45+
obj : Any
46+
The object to be serialized.
47+
48+
Returns
49+
-------
50+
Any
51+
The serialized representation of the object.
52+
53+
Raises
54+
------
55+
TypeError
56+
If the object is not LangChain serializable.
57+
"""
58+
if isinstance(obj, Serializable) and obj.is_lc_serializable():
59+
return obj.to_json()
60+
raise TypeError(f"Serialization of {type(obj)} is not supported.")
61+
62+
63+
def dump(obj: Any) -> Dict[str, Any]:
64+
"""Return a json dict representation of an object.
65+
66+
This is a drop in replacement of the dumpd() in langchain.load.dump,
67+
except that this method will raise TypeError when the object is not serializable.
68+
"""
69+
return json.loads(json.dumps(obj, default=default))

tests/unitary/with_extras/langchain/test_guardrails.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
from langchain.llms.base import LLM
1212
from langchain.prompts import PromptTemplate
1313
from langchain.schema.runnable import RunnableMap, RunnablePassthrough
14-
from langchain.load import dumpd
1514
from ads.llm.guardrails import HuggingFaceEvaluation
1615
from ads.llm.guardrails.base import BlockedByGuardrail, GuardrailIO
1716
from ads.llm.chain import GuardrailSequence
18-
from ads.llm.load import load
17+
from ads.llm.serialize import load, dump
1918

2019

2120
class FakeLLM(LLM):
@@ -64,7 +63,7 @@ def test_toxicity_without_threshold(self):
6463
chain = self.FAKE_LLM | toxicity
6564
output = chain.invoke(self.TOXIC_CONTENT)
6665
self.assertEqual(output, self.TOXIC_CONTENT)
67-
serialized = dumpd(chain)
66+
serialized = dump(chain)
6867
chain = load(serialized, valid_namespaces=["tests"])
6968
output = chain.invoke(self.TOXIC_CONTENT)
7069
self.assertEqual(output, self.TOXIC_CONTENT)
@@ -77,7 +76,7 @@ def test_toxicity_with_threshold(self):
7776
chain = self.FAKE_LLM | toxicity
7877
with self.assertRaises(BlockedByGuardrail):
7978
chain.invoke(self.TOXIC_CONTENT)
80-
serialized = dumpd(chain)
79+
serialized = dump(chain)
8180
chain = load(serialized, valid_namespaces=["tests"])
8281
with self.assertRaises(BlockedByGuardrail):
8382
chain.invoke(self.TOXIC_CONTENT)
@@ -94,7 +93,7 @@ def test_toxicity_without_exception(self):
9493
chain = self.FAKE_LLM | toxicity
9594
output = chain.invoke(self.TOXIC_CONTENT)
9695
self.assertEqual(output, toxicity.custom_msg)
97-
serialized = dumpd(chain)
96+
serialized = dump(chain)
9897
chain = load(serialized, valid_namespaces=["tests"])
9998
output = chain.invoke(self.TOXIC_CONTENT)
10099
self.assertEqual(output, toxicity.custom_msg)
@@ -109,7 +108,7 @@ def test_toxicity_return_metrics(self):
109108
self.assertIsInstance(output, dict)
110109
self.assertEqual(output["output"], self.TOXIC_CONTENT)
111110
self.assertGreater(output["metrics"]["toxicity"][0], 0.2)
112-
serialized = dumpd(chain)
111+
serialized = dump(chain)
113112
chain = load(serialized, valid_namespaces=["tests"])
114113
output = chain.invoke(self.TOXIC_CONTENT)
115114
self.assertIsInstance(output, dict)
@@ -123,9 +122,11 @@ class GuardrailSequenceTests(GuardrailTestsBase):
123122
def test_guardrail_sequence_with_template_and_toxicity(self):
124123
template = PromptTemplate.from_template("Tell me a joke about {subject}")
125124
map_input = RunnableMap(subject=RunnablePassthrough())
126-
toxicity = HuggingFaceEvaluation(path="toxicity", load_args=self.LOAD_ARGS)
125+
toxicity = HuggingFaceEvaluation(
126+
path="toxicity", load_args=self.LOAD_ARGS, select=min
127+
)
127128
chain = GuardrailSequence.from_sequence(
128129
map_input | template | self.FAKE_LLM | toxicity
129130
)
130-
output = chain.run("cats")
131+
output = chain.run("cats", num_generations=5)
131132
self.assertIsInstance(output, GuardrailIO)

0 commit comments

Comments
 (0)