 from langchain.llms.base import LLM
 from langchain.prompts import PromptTemplate
 from langchain.schema.runnable import RunnableMap, RunnablePassthrough
-from langchain.load import dumpd
 from ads.llm.guardrails import HuggingFaceEvaluation
 from ads.llm.guardrails.base import BlockedByGuardrail, GuardrailIO
 from ads.llm.chain import GuardrailSequence
-from ads.llm.load import load
+from ads.llm.serialize import load, dump
 
 
 class FakeLLM(LLM):
@@ -64,7 +63,7 @@ def test_toxicity_without_threshold(self):
         chain = self.FAKE_LLM | toxicity
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertEqual(output, self.TOXIC_CONTENT)
-        serialized = dumpd(chain)
+        serialized = dump(chain)
         chain = load(serialized, valid_namespaces=["tests"])
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertEqual(output, self.TOXIC_CONTENT)
@@ -77,7 +76,7 @@ def test_toxicity_with_threshold(self):
         chain = self.FAKE_LLM | toxicity
         with self.assertRaises(BlockedByGuardrail):
             chain.invoke(self.TOXIC_CONTENT)
-        serialized = dumpd(chain)
+        serialized = dump(chain)
         chain = load(serialized, valid_namespaces=["tests"])
         with self.assertRaises(BlockedByGuardrail):
             chain.invoke(self.TOXIC_CONTENT)
@@ -94,7 +93,7 @@ def test_toxicity_without_exception(self):
         chain = self.FAKE_LLM | toxicity
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertEqual(output, toxicity.custom_msg)
-        serialized = dumpd(chain)
+        serialized = dump(chain)
         chain = load(serialized, valid_namespaces=["tests"])
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertEqual(output, toxicity.custom_msg)
@@ -109,7 +108,7 @@ def test_toxicity_return_metrics(self):
         self.assertIsInstance(output, dict)
         self.assertEqual(output["output"], self.TOXIC_CONTENT)
         self.assertGreater(output["metrics"]["toxicity"][0], 0.2)
-        serialized = dumpd(chain)
+        serialized = dump(chain)
         chain = load(serialized, valid_namespaces=["tests"])
         output = chain.invoke(self.TOXIC_CONTENT)
         self.assertIsInstance(output, dict)
@@ -123,9 +122,11 @@ class GuardrailSequenceTests(GuardrailTestsBase):
     def test_guardrail_sequence_with_template_and_toxicity(self):
         template = PromptTemplate.from_template("Tell me a joke about {subject}")
         map_input = RunnableMap(subject=RunnablePassthrough())
-        toxicity = HuggingFaceEvaluation(path="toxicity", load_args=self.LOAD_ARGS)
+        toxicity = HuggingFaceEvaluation(
+            path="toxicity", load_args=self.LOAD_ARGS, select=min
+        )
         chain = GuardrailSequence.from_sequence(
             map_input | template | self.FAKE_LLM | toxicity
         )
-        output = chain.run("cats")
+        output = chain.run("cats", num_generations=5)
         self.assertIsInstance(output, GuardrailIO)
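
For context, a minimal sketch of the dump/load round trip these tests exercise, using only the names visible in this diff; `fake_llm` and `load_args` are placeholders standing in for the FAKE_LLM and LOAD_ARGS test fixtures, and actually evaluating the guardrail assumes the Hugging Face "toxicity" metric can be loaded:

    from ads.llm.guardrails import HuggingFaceEvaluation
    from ads.llm.serialize import dump, load

    # Pipe an LLM into the toxicity guardrail, as the tests above do with FAKE_LLM.
    toxicity = HuggingFaceEvaluation(path="toxicity", load_args=load_args)
    chain = fake_llm | toxicity

    # Serialize the chain to a plain dict and rebuild an equivalent chain from it.
    serialized = dump(chain)
    chain = load(serialized, valid_namespaces=["tests"])
    output = chain.invoke("some input text")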