Skip to content

Commit 5859927

Browse files
committed
Update test_guardrails.py
1 parent 220a6d6 commit 5859927

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/unitary/with_extras/langchain/test_guardrails.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
# Copyright (c) 2023 Oracle and/or its affiliates.
55
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
66

7+
import json
78
import os
9+
import tempfile
810
from typing import Any, List, Dict, Mapping, Optional
911
from unittest import TestCase
1012
from langchain.callbacks.manager import CallbackManagerForLLMRun
@@ -172,3 +174,31 @@ def test_fn(chain: GuardrailSequence):
172174
self.assertEqual(len(output.info), len(chain.steps))
173175

174176
self.assert_before_and_after_serialization(test_fn, chain)
177+
178+
def test_empty_sequence(self):
179+
"""Tests empty sequence."""
180+
seq = GuardrailSequence()
181+
self.assertEqual(seq.steps, [])
182+
183+
def test_save_to_file(self):
184+
"""Tests saving to file."""
185+
message = "Let's talk something else."
186+
toxicity = HuggingFaceEvaluation(
187+
path="toxicity",
188+
load_args=self.LOAD_ARGS,
189+
threshold=0.5,
190+
custom_msg=message,
191+
)
192+
chain = GuardrailSequence.from_sequence(self.FAKE_LLM | toxicity)
193+
try:
194+
temp = tempfile.NamedTemporaryFile(suffix=".json", delete=False)
195+
temp.close()
196+
with self.assertRaises(FileExistsError):
197+
serialized = chain.save(temp.name)
198+
with self.assertRaises(ValueError):
199+
chain.save("abc.html")
200+
serialized = chain.save(temp.name, overwrite=True)
201+
with open(temp.name, "r", encoding="utf-8") as f:
202+
self.assertEqual(json.load(f), serialized)
203+
finally:
204+
os.unlink(temp.name)

0 commit comments

Comments
 (0)