
Commit 15e412c

Merge commit '37e116e05aec5e32306b6ca5956b5515614ec9ca'

2 parents: fa59003 + 37e116e

File tree

4 files changed: +97 -106 lines

.pre-commit-config.yaml
ads/llm/chain.py
tests/unitary/with_extras/langchain/test_guardrails.py
tests/unitary/with_extras/langchain/test_serialization.py

.pre-commit-config.yaml

Lines changed: 51 additions & 51 deletions

@@ -1,52 +1,52 @@
 repos:
-  # Standard hooks
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
-    hooks:
-      - id: check-ast
-        exclude: ^docs/
-      - id: check-docstring-first
-        exclude: ^(docs/|tests/)
-      - id: check-json
-      - id: check-merge-conflict
-      - id: check-yaml
-        exclude: feature_store_*.yaml
-        args: ['--allow-multiple-documents']
-      - id: detect-private-key
-      - id: end-of-file-fixer
-        exclude: '\.ipynb?$'
-      - id: pretty-format-json
-        args: ['--autofix']
-      - id: trailing-whitespace
-        args: [--markdown-linebreak-ext=md]
-        exclude: ^docs/
-  # Black, the code formatter, natively supports pre-commit
-  - repo: https://github.com/psf/black
-    rev: 23.3.0
-    hooks:
-      - id: black
-        exclude: ^docs/
-  # Regex based rst files common mistakes detector
-  - repo: https://github.com/pre-commit/pygrep-hooks
-    rev: v1.10.0
-    hooks:
-      - id: rst-backticks
-        files: ^docs/
-      - id: rst-inline-touching-normal
-        files: ^docs/
-  # Hardcoded secrets and ocids detector
-  - repo: https://github.com/gitleaks/gitleaks
-    rev: v8.17.0
-    hooks:
-      - id: gitleaks
-        exclude: .github/workflows/reusable-actions/set-dummy-conf.yml
-  # Oracle copyright checker
-  - repo: https://github.com/oracle-samples/oci-data-science-ai-samples/
-    rev: cbe0136f7aaffe463b31ddf3f34b0e16b4b124ff
-    hooks:
-      - id: check-copyright
-        name: check-copyright
-        entry: .pre-commit-scripts/check-copyright.py
-        language: script
-        types_or: ['python', 'shell', 'bash']
-        exclude: ^docs/
+  # Standard hooks
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-ast
+        exclude: ^docs/
+      - id: check-docstring-first
+        exclude: ^(docs/|tests/)
+      - id: check-json
+      - id: check-merge-conflict
+      - id: check-yaml
+        exclude: feature_store_*.yaml
+        args: ["--allow-multiple-documents"]
+      - id: detect-private-key
+      - id: end-of-file-fixer
+        exclude: '\.ipynb?$'
+      - id: pretty-format-json
+        args: ["--autofix"]
+      - id: trailing-whitespace
+        args: [--markdown-linebreak-ext=md]
+        exclude: ^docs/
+  # Black, the code formatter, natively supports pre-commit
+  - repo: https://github.com/psf/black
+    rev: 23.3.0
+    hooks:
+      - id: black
+        exclude: ^docs/
+  # Regex based rst files common mistakes detector
+  - repo: https://github.com/pre-commit/pygrep-hooks
+    rev: v1.10.0
+    hooks:
+      - id: rst-backticks
+        files: ^docs/
+      - id: rst-inline-touching-normal
+        files: ^docs/
+  # Hardcoded secrets and ocids detector
+  - repo: https://github.com/gitleaks/gitleaks
+    rev: v8.17.0
+    hooks:
+      - id: gitleaks
+        exclude: .github/workflows/reusable-actions/set-dummy-conf.yml
+  # Oracle copyright checker
+  - repo: https://github.com/oracle-samples/oci-data-science-ai-samples/
+    rev: 1bc5270a443b791c62f634233c0f4966dfcc0dd6
+    hooks:
+      - id: check-copyright
+        name: check-copyright
+        entry: .pre-commit-scripts/check-copyright.py
+        language: script
+        types_or: ["python", "shell", "bash"]
+        exclude: ^docs/
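The only substantive change in this file is the pinned rev of the Oracle copyright checker; the remaining lines differ only in quote style. For context on the copyright hook's settings (entry, language: script, types_or), a script hook is an executable that pre-commit invokes with the staged file names as arguments and that fails the commit by exiting non-zero. The sketch below is an illustrative stand-in, not the actual script from oci-data-science-ai-samples:

#!/usr/bin/env python
# Illustrative stand-in for a `language: script` hook such as check-copyright.
# pre-commit passes the staged file names as command-line arguments and treats
# a non-zero exit code as a failed check. The notice text checked here is a
# placeholder, not the pattern used by the real script.
import sys

NOTICE = "Copyright (c)"


def main(paths):
    missing = []
    for path in paths:
        with open(path, encoding="utf-8", errors="ignore") as handle:
            head = handle.read(2048)  # the notice is expected near the top
        if NOTICE not in head:
            missing.append(path)
    for path in missing:
        print(f"missing copyright notice: {path}")
    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))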

ads/llm/chain.py

Lines changed: 2 additions & 6 deletions

@@ -261,12 +261,8 @@ def load(cls, chain_dict: dict, **kwargs) -> "GuardrailSequence":
         from ads.llm.serialize import load
 
         chain_spec = chain_dict[SPEC_CHAIN]
-        chain = cls()
-        for config in chain_spec:
-            step = load(config, **kwargs)
-            # Chain the step
-            chain |= step
-        return chain
+        steps = [load(config, **kwargs) for config in chain_spec]
+        return cls(*steps)
 
     def __str__(self) -> str:
         return "\n".join([str(step.__class__) for step in self.steps])

tests/unitary/with_extras/langchain/test_guardrails.py

Lines changed: 0 additions & 4 deletions

@@ -175,10 +175,6 @@ def test_fn(chain: GuardrailSequence):
 
         self.assert_before_and_after_serialization(test_fn, chain)
 
-    def test_empty_sequence(self):
-        """Tests empty sequence."""
-        seq = GuardrailSequence()
-        self.assertEqual(seq.steps, [])
 
     def test_save_to_file(self):
         """Tests saving to file."""

tests/unitary/with_extras/langchain/test_serialization.py

Lines changed: 44 additions & 45 deletions

@@ -40,39 +40,6 @@ def setUp(self) -> None:
     GEN_AI_KWARGS = {"service_endpoint": "https://endpoint.oraclecloud.com"}
     ENDPOINT = "https://modeldeployment.customer-oci.com/ocid/predict"
 
-    EXPECTED_LLM_CHAIN_WITH_COHERE = {
-        "memory": None,
-        "verbose": True,
-        "tags": None,
-        "metadata": None,
-        "prompt": {
-            "input_variables": ["subject"],
-            "input_types": {},
-            "output_parser": None,
-            "partial_variables": {},
-            "template": "Tell me a joke about {subject}",
-            "template_format": "f-string",
-            "validate_template": False,
-            "_type": "prompt",
-        },
-        "llm": {
-            "model": None,
-            "max_tokens": 256,
-            "temperature": 0.75,
-            "k": 0,
-            "p": 1,
-            "frequency_penalty": 0.0,
-            "presence_penalty": 0.0,
-            "truncate": None,
-            "_type": "cohere",
-        },
-        "output_key": "text",
-        "output_parser": {"_type": "default"},
-        "return_final_only": True,
-        "llm_kwargs": {},
-        "_type": "llm_chain",
-    }
-
     EXPECTED_LLM_CHAIN_WITH_OCI_MD = {
         "lc": 1,
         "type": "constructor",
@@ -173,7 +140,23 @@ def test_llm_chain_serialization_with_cohere(self):
         template = PromptTemplate.from_template(self.PROMPT_TEMPLATE)
         llm_chain = LLMChain(prompt=template, llm=llm, verbose=True)
         serialized = dump(llm_chain)
-        self.assertEqual(serialized, self.EXPECTED_LLM_CHAIN_WITH_COHERE)
+
+        # Check the serialized chain
+        self.assertTrue(serialized.get("verbose"))
+        self.assertEqual(serialized.get("_type"), "llm_chain")
+
+        # Check the serialized prompt template
+        serialized_prompt = serialized.get("prompt")
+        self.assertIsInstance(serialized_prompt, dict)
+        self.assertEqual(serialized_prompt.get("_type"), "prompt")
+        self.assertEqual(set(serialized_prompt.get("input_variables")), {"subject"})
+        self.assertEqual(serialized_prompt.get("template"), self.PROMPT_TEMPLATE)
+
+        # Check the serialized LLM
+        serialized_llm = serialized.get("llm")
+        self.assertIsInstance(serialized_llm, dict)
+        self.assertEqual(serialized_llm.get("_type"), "cohere")
+
         llm_chain = load(serialized)
         self.assertIsInstance(llm_chain, LLMChain)
         self.assertIsInstance(llm_chain.prompt, PromptTemplate)
@@ -237,21 +220,37 @@ def test_runnable_sequence_serialization(self):
 
         chain = map_input | template | llm
         serialized = dump(chain)
-        # Do not check the ID fields.
-        expected = deepcopy(self.EXPECTED_RUNNABLE_SEQUENCE)
-        expected["id"] = serialized["id"]
-        expected["kwargs"]["first"]["id"] = serialized["kwargs"]["first"]["id"]
-        expected["kwargs"]["first"]["kwargs"]["steps"]["text"]["id"] = serialized[
-            "kwargs"
-        ]["first"]["kwargs"]["steps"]["text"]["id"]
-        expected["kwargs"]["middle"][0]["id"] = serialized["kwargs"]["middle"][0]["id"]
-        self.assertEqual(serialized, expected)
+
+        self.assertEqual(serialized.get("type"), "constructor")
+        self.assertNotIn("_type", serialized)
+
+        kwargs = serialized.get("kwargs")
+        self.assertIsInstance(kwargs, dict)
+
+        element_1 = kwargs.get("first")
+        self.assertEqual(element_1.get("_type"), "RunnableParallel")
+        step = element_1.get("kwargs").get("steps").get("text")
+        self.assertEqual(step.get("id")[-1], "RunnablePassthrough")
+
+        element_2 = kwargs.get("middle")[0]
+        self.assertNotIn("_type", element_2)
+        self.assertEqual(element_2.get("kwargs").get("template"), self.PROMPT_TEMPLATE)
+        self.assertEqual(element_2.get("kwargs").get("input_variables"), ["subject"])
+
+        element_3 = kwargs.get("last")
+        self.assertNotIn("_type", element_3)
+        self.assertEqual(element_3.get("id"), ["ads", "llm", "ModelDeploymentTGI"])
+        self.assertEqual(
+            element_3.get("kwargs"),
+            {"endpoint": "https://modeldeployment.customer-oci.com/ocid/predict"},
+        )
+
         chain = load(serialized)
         self.assertEqual(len(chain.steps), 3)
         self.assertIsInstance(chain.steps[0], RunnableParallel)
         self.assertEqual(
-            chain.steps[0].dict(),
-            {"steps": {"text": {"input_type": None, "func": None, "afunc": None}}},
+            list(chain.steps[0].dict().get("steps").keys()),
+            ["text"],
         )
         self.assertIsInstance(chain.steps[1], PromptTemplate)
         self.assertIsInstance(chain.steps[2], ModelDeploymentTGI)
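The rewritten assertions pin only the fields each test cares about instead of comparing the serialized output against a full expected dictionary, so the tests no longer have to patch auto-generated ID fields before comparing. A minimal, self-contained sketch of the pattern with a made-up payload (not the actual serialized chain):

import unittest


class FieldwiseAssertionSketch(unittest.TestCase):
    """Illustrates spot-checking fields rather than whole-dict equality."""

    def test_selected_fields(self):
        # Hypothetical serialized payload; the real output carries many more
        # keys, including auto-generated ids that vary between runs.
        serialized = {
            "_type": "llm_chain",
            "verbose": True,
            "prompt": {"_type": "prompt", "input_variables": ["subject"]},
        }

        # Brittle alternative: self.assertEqual(serialized, EXPECTED_FULL_DICT)
        # breaks whenever any unrelated or generated key changes.

        # Robust alternative: assert only on the fields under test.
        self.assertEqual(serialized.get("_type"), "llm_chain")
        self.assertTrue(serialized.get("verbose"))
        self.assertEqual(
            set(serialized.get("prompt", {}).get("input_variables", [])), {"subject"}
        )


if __name__ == "__main__":
    unittest.main()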
