Commit 1290a4a
Parent: 4f21324

Update tests.

2 files changed: +60 lines, -7 lines
ads/llm/guardrails/base.py (1 addition, 0 deletions)

@@ -210,6 +210,7 @@ def default_name(cls, values):
 
     @classmethod
     def is_lc_serializable(cls) -> bool:
+        """This class is LangChain serializable."""
        return True
 
     def _preprocess(self, input: Any) -> str:
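For context, LangChain's dumpd/load machinery serializes only classes whose is_lc_serializable() classmethod returns True, so this docstring documents the behavior that the new round-trip tests below rely on. A minimal sketch of that round-trip for a standalone guardrail, assuming the default trusted namespaces suffice (dumpd, HuggingFaceEvaluation, and ads.llm.load.load all appear in this commit; the constructor arguments are illustrative):

from langchain.load import dumpd

from ads.llm.guardrails import HuggingFaceEvaluation
from ads.llm.load import load

# Serialize the guardrail to a JSON-compatible dict; dumpd embeds the
# constructor state only because is_lc_serializable() returns True.
toxicity = HuggingFaceEvaluation(path="toxicity", threshold=0.2)
serialized = dumpd(toxicity)

# Rebuild an equivalent guardrail from the dict. Classes defined outside
# the loader's trusted namespaces need valid_namespaces=[...], as the
# tests below pass for FakeLLM.
restored = load(serialized)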

tests/unitary/with_extras/langchain/test_guardrails.py (59 additions, 7 deletions)

@@ -4,13 +4,18 @@
 # Copyright (c) 2023 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
 
-
+import os
 from typing import Any, List, Mapping, Optional
 from unittest import TestCase
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms.base import LLM
+from langchain.prompts import PromptTemplate
+from langchain.schema.runnable import RunnableMap, RunnablePassthrough
+from langchain.load import dumpd
 from ads.llm.guardrails import HuggingFaceEvaluation
-from ads.llm.guardrails.base import BlockedByGuardrail
+from ads.llm.guardrails.base import BlockedByGuardrail, GuardrailIO
+from ads.llm.chain import GuardrailSequence
+from ads.llm.load import load
 
 
 class FakeLLM(LLM):
@@ -36,26 +41,46 @@ def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
         return {}
 
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """This class is LangChain serializable."""
+        return True
 
-class ToxicityGuardrailTests(TestCase):
-    """Contains tests for the toxicity guardrail."""
+
+class GuardrailTestsBase(TestCase):
+    """Base class for guardrail tests."""
 
     TOXIC_CONTENT = "Women is not capable of this job."
+    LOAD_ARGS = {"cache_dir": os.path.expanduser("~/.cache/huggingface/evaluate")}
     FAKE_LLM = FakeLLM()
 
+
+class ToxicityGuardrailTests(GuardrailTestsBase):
+    """Contains tests for the toxicity guardrail."""
+
     def test_toxicity_without_threshold(self):
         """When using guardrail alone with is no threshold, it does not do anything."""
-        toxicity = HuggingFaceEvaluation(path="toxicity")
+        toxicity = HuggingFaceEvaluation(path="toxicity", load_args=self.LOAD_ARGS)
         chain = self.FAKE_LLM | toxicity
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertEqual(output, self.TOXIC_CONTENT)
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        output = chain.invoke(self.TOXIC_CONTENT)
+        self.assertEqual(output, self.TOXIC_CONTENT)
 
     def test_toxicity_with_threshold(self):
         """Once we set a threshold, an exception will be raise for toxic output."""
-        toxicity = HuggingFaceEvaluation(path="toxicity", threshold=0.2)
+        toxicity = HuggingFaceEvaluation(
+            path="toxicity", threshold=0.2, load_args=self.LOAD_ARGS
+        )
         chain = self.FAKE_LLM | toxicity
         with self.assertRaises(BlockedByGuardrail):
             chain.invoke(self.TOXIC_CONTENT)
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        with self.assertRaises(BlockedByGuardrail):
+            chain.invoke(self.TOXIC_CONTENT)
 
     def test_toxicity_without_exception(self):
         """Guardrail can return the custom message instead of raising an exception."""
@@ -64,16 +89,43 @@ def test_toxicity_without_exception(self):
             threshold=0.2,
             raise_exception=False,
             custom_msg="Sorry, but let's discuss something else.",
+            load_args=self.LOAD_ARGS,
         )
         chain = self.FAKE_LLM | toxicity
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertEqual(output, toxicity.custom_msg)
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        output = chain.invoke(self.TOXIC_CONTENT)
+        self.assertEqual(output, toxicity.custom_msg)
 
     def test_toxicity_return_metrics(self):
         """Return the toxicity metrics"""
-        toxicity = HuggingFaceEvaluation(path="toxicity", return_metrics=True)
+        toxicity = HuggingFaceEvaluation(
+            path="toxicity", return_metrics=True, load_args=self.LOAD_ARGS
+        )
         chain = self.FAKE_LLM | toxicity
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertIsInstance(output, dict)
         self.assertEqual(output["output"], self.TOXIC_CONTENT)
         self.assertGreater(output["metrics"]["toxicity"][0], 0.2)
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        output = chain.invoke(self.TOXIC_CONTENT)
+        self.assertIsInstance(output, dict)
+        self.assertEqual(output["output"], self.TOXIC_CONTENT)
+        self.assertGreater(output["metrics"]["toxicity"][0], 0.2)
+
+
+class GuardrailSequenceTests(GuardrailTestsBase):
+    """Contains tests for GuardrailSequence."""
+
+    def test_guardrail_sequence_with_template_and_toxicity(self):
+        template = PromptTemplate.from_template("Tell me a joke about {subject}")
+        map_input = RunnableMap(subject=RunnablePassthrough())
+        toxicity = HuggingFaceEvaluation(path="toxicity", load_args=self.LOAD_ARGS)
+        chain = GuardrailSequence.from_sequence(
+            map_input | template | self.FAKE_LLM | toxicity
+        )
+        output = chain.run("cats")
+        self.assertIsInstance(output, GuardrailIO)
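Every updated toxicity test follows the same round-trip pattern: invoke the chain, serialize it with dumpd, rebuild it with load(serialized, valid_namespaces=["tests"]), and assert that the reloaded chain produces the same result. The extra "tests" namespace is needed because FakeLLM is defined in the test package, which the loader does not deserialize by default; the new GuardrailSequenceTests case additionally verifies that GuardrailSequence.run returns a GuardrailIO rather than a bare string. A hypothetical helper, not part of this commit, that could factor out the repeated round-trip:

from langchain.load import dumpd

from ads.llm.load import load

def roundtrip(chain):
    """Serialize a chain with dumpd, then rebuild it with the ADS loader.

    The "tests" namespace is whitelisted only because FakeLLM lives in
    the test package; chains built purely from ads and langchain classes
    should not need it.
    """
    return load(dumpd(chain), valid_namespaces=["tests"])

Each test could then shrink its serialize/reload/re-invoke block to a single assertion, e.g. self.assertEqual(roundtrip(chain).invoke(self.TOXIC_CONTENT), output).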
