
Commit 25c26f2

Updated pr.
1 parent 960ed5b commit 25c26f2


2 files changed (+10 lines, −14 lines)


tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py

Lines changed: 5 additions & 7 deletions
@@ -10,9 +10,6 @@
 from typing import Any, AsyncGenerator, Dict, Generator
 from unittest import mock

-from ads.llm.langchain.plugins.chat_models.oci_data_science import (
-    DEFAULT_INFERENCE_ENDPOINT_CHAT,
-)
 import pytest


@@ -29,6 +26,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_ENDPOINT = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -126,7 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -139,7 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -154,7 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -193,7 +191,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     with mock.patch.object(
         llm,
         "_aiter_sse",

tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py

Lines changed: 5 additions & 7 deletions
@@ -10,9 +10,6 @@
 from typing import Any, AsyncGenerator, Dict, Generator
 from unittest import mock

-from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
-    DEFAULT_INFERENCE_ENDPOINT,
-)
 import pytest

 if sys.version_info < (3, 9):
@@ -27,6 +24,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_ENDPOINT = "/v1/completions"
 CONST_COMPLETION_RESPONSE = {
     "choices": [
         {
@@ -119,7 +117,7 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION

@@ -132,7 +130,7 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -150,7 +148,7 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION

@@ -169,7 +167,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
+    assert llm.headers == {"route": CONST_ENDPOINT}
     with mock.patch.object(
         llm,
         "_aiter_sse",

0 commit comments
