Commit 4ff91a4

Rename test_plugins.py to test_llm_plugins.py
1 parent f2d42d4 · commit 4ff91a4

File tree

2 files changed: 65 additions, 71 deletions
tests/unitary/with_extras/langchain/test_llm_plugins.py

Lines changed: 65 additions & 0 deletions

@@ -0,0 +1,65 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import pytest
import unittest
from unittest.mock import patch

from ads.llm import ModelDeploymentTGI
from oci.signer import Signer


class LangChainPluginsTest(unittest.TestCase):
    mock_endpoint = "https://mock_endpoint/predict"

    def mocked_requests_post(endpoint, headers, json, auth, **kwargs):
        class MockResponse:
            def __init__(self, json_data, status_code):
                self.json_data = json_data
                self.status_code = status_code

            def json(self):
                return self.json_data

            def raise_for_status(self):
                pass

        assert endpoint.startswith("https://")
        assert json
        assert headers
        prompt = json.get("inputs")
        assert prompt and isinstance(prompt, str)
        completion = "ads" if "who" in prompt else "Unknown"
        assert auth
        assert isinstance(auth, Signer)

        return MockResponse(
            json_data={"generated_text": completion},
            status_code=200,
        )

    def test_oci_model_deployment_model_param(self):
        llm = ModelDeploymentTGI(endpoint=self.mock_endpoint, temperature=0.9)
        model_params_keys = [
            "best_of",
            "max_new_tokens",
            "temperature",
            "top_k",
            "top_p",
            "do_sample",
            "return_full_text",
            "watermark",
        ]
        assert llm.endpoint == self.mock_endpoint
        assert all(key in llm._default_params for key in model_params_keys)
        assert llm.temperature == 0.9

    @patch("requests.post", mocked_requests_post)
    def test_oci_model_deployment_call(self):
        llm = ModelDeploymentTGI(endpoint=self.mock_endpoint)
        response = llm("who am i")
        completion = "ads"
        assert response == completion

tests/unitary/with_extras/langchain/test_plugins.py

Lines changed: 0 additions & 71 deletions
This file was deleted.
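For readers unfamiliar with the class under test, the following is a minimal usage sketch of ModelDeploymentTGI, not part of the commit itself. The endpoint URL is a placeholder, and only the constructor arguments and call style that the test above actually asserts are used.

# Minimal usage sketch (illustration only, not from the commit).
# The endpoint below is a placeholder for an OCI model deployment URL.
from ads.llm import ModelDeploymentTGI

llm = ModelDeploymentTGI(
    endpoint="https://<model-deployment-host>/predict",  # placeholder endpoint
    temperature=0.2,
)

# As in test_oci_model_deployment_call, invoking the object sends the prompt
# to the deployment and returns the generated text from the response.
print(llm("who am i"))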

0 commit comments
