# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

-
+import os
from typing import Any, List, Mapping, Optional
from unittest import TestCase
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
+from langchain.prompts import PromptTemplate
+from langchain.schema.runnable import RunnableMap, RunnablePassthrough
+from langchain.load import dumpd
from ads.llm.guardrails import HuggingFaceEvaluation
-from ads.llm.guardrails.base import BlockedByGuardrail
+from ads.llm.guardrails.base import BlockedByGuardrail, GuardrailIO
+from ads.llm.chain import GuardrailSequence
+from ads.llm.load import load


class FakeLLM(LLM):
@@ -36,26 +41,46 @@ def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {}

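+    # Mark the fake LLM as LangChain-serializable so that chains containing it
+    # can be dumped with dumpd() and loaded back in the tests below.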
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        """This class is LangChain serializable."""
+        return True

-class ToxicityGuardrailTests(TestCase):
-    """Contains tests for the toxicity guardrail."""
+
+class GuardrailTestsBase(TestCase):
+    """Base class for guardrail tests."""

    TOXIC_CONTENT = "Women is not capable of this job."
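+    # Keyword arguments used when loading the HuggingFace evaluation module;
+    # cache_dir points at the local evaluate cache to avoid re-downloading.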
+    LOAD_ARGS = {"cache_dir": os.path.expanduser("~/.cache/huggingface/evaluate")}
    FAKE_LLM = FakeLLM()

+
+class ToxicityGuardrailTests(GuardrailTestsBase):
+    """Contains tests for the toxicity guardrail."""
+
    def test_toxicity_without_threshold(self):
        """When using the guardrail alone with no threshold, it does not do anything."""
-        toxicity = HuggingFaceEvaluation(path="toxicity")
+        toxicity = HuggingFaceEvaluation(path="toxicity", load_args=self.LOAD_ARGS)
        chain = self.FAKE_LLM | toxicity
        output = chain.invoke(self.TOXIC_CONTENT)
        self.assertEqual(output, self.TOXIC_CONTENT)
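+        # Round-trip the chain through LangChain serialization and confirm the
+        # reloaded chain still passes the content through unchanged.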
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        output = chain.invoke(self.TOXIC_CONTENT)
+        self.assertEqual(output, self.TOXIC_CONTENT)

    def test_toxicity_with_threshold(self):
        """Once we set a threshold, an exception will be raised for toxic output."""
-        toxicity = HuggingFaceEvaluation(path="toxicity", threshold=0.2)
+        toxicity = HuggingFaceEvaluation(
+            path="toxicity", threshold=0.2, load_args=self.LOAD_ARGS
+        )
        chain = self.FAKE_LLM | toxicity
        with self.assertRaises(BlockedByGuardrail):
            chain.invoke(self.TOXIC_CONTENT)
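+        # The serialized and reloaded chain should also raise for toxic output.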
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        with self.assertRaises(BlockedByGuardrail):
+            chain.invoke(self.TOXIC_CONTENT)

    def test_toxicity_without_exception(self):
        """Guardrail can return the custom message instead of raising an exception."""
@@ -64,16 +89,43 @@ def test_toxicity_without_exception(self):
            threshold=0.2,
            raise_exception=False,
            custom_msg="Sorry, but let's discuss something else.",
+            load_args=self.LOAD_ARGS,
        )
        chain = self.FAKE_LLM | toxicity
        output = chain.invoke(self.TOXIC_CONTENT)
        self.assertEqual(output, toxicity.custom_msg)
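+        # After a serialization round trip, the custom message is still returned.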
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        output = chain.invoke(self.TOXIC_CONTENT)
+        self.assertEqual(output, toxicity.custom_msg)

    def test_toxicity_return_metrics(self):
        """Return the toxicity metrics"""
-        toxicity = HuggingFaceEvaluation(path="toxicity", return_metrics=True)
+        toxicity = HuggingFaceEvaluation(
+            path="toxicity", return_metrics=True, load_args=self.LOAD_ARGS
+        )
        chain = self.FAKE_LLM | toxicity
        output = chain.invoke(self.TOXIC_CONTENT)
        self.assertIsInstance(output, dict)
        self.assertEqual(output["output"], self.TOXIC_CONTENT)
        self.assertGreater(output["metrics"]["toxicity"][0], 0.2)
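+        # The reloaded chain should return the same output-plus-metrics structure.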
+        serialized = dumpd(chain)
+        chain = load(serialized, valid_namespaces=["tests"])
+        output = chain.invoke(self.TOXIC_CONTENT)
+        self.assertIsInstance(output, dict)
+        self.assertEqual(output["output"], self.TOXIC_CONTENT)
+        self.assertGreater(output["metrics"]["toxicity"][0], 0.2)
+
+
+class GuardrailSequenceTests(GuardrailTestsBase):
+    """Contains tests for GuardrailSequence."""
+
+    def test_guardrail_sequence_with_template_and_toxicity(self):
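+        # Compose a sequence: map the raw input into {subject}, render the prompt,
+        # run the fake LLM, then screen the completion with the toxicity guardrail.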
+        template = PromptTemplate.from_template("Tell me a joke about {subject}")
+        map_input = RunnableMap(subject=RunnablePassthrough())
+        toxicity = HuggingFaceEvaluation(path="toxicity", load_args=self.LOAD_ARGS)
+        chain = GuardrailSequence.from_sequence(
+            map_input | template | self.FAKE_LLM | toxicity
+        )
+        output = chain.run("cats")
+        self.assertIsInstance(output, GuardrailIO)