@@ -10,7 +10,4 @@
 from typing import Any, AsyncGenerator, Dict, Generator
 from unittest import mock
 
-from ads.llm.langchain.plugins.chat_models.oci_data_science import (
-    DEFAULT_INFERENCE_ENDPOINT_CHAT,
-)
 import pytest
@@ -29,6 +26,7 @@
 CONST_ENDPOINT = "https://oci.endpoint/ocid/predict"
 CONST_PROMPT = "This is a prompt."
 CONST_COMPLETION = "This is a completion."
+CONST_COMPLETION_ROUTE = "/v1/chat/completions"
 CONST_COMPLETION_RESPONSE = {
     "id": "chat-123456789",
     "object": "chat.completion",
@@ -126,7 +124,7 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_COMPLETION_ROUTE}
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
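These tests rely on patching requests.post with the mocked_requests_post helper named in the hunk header. A hedged, self-contained sketch of that pattern follows; the MockResponse class and the helper body here are illustrative assumptions, not the repository's exact code.

# Hedged sketch of the requests.post mocking pattern (not the repo's exact helper).
from typing import Any
from unittest import mock

import requests


class MockResponse:
    """Minimal stand-in for requests.Response (assumed shape)."""

    def __init__(self, json_data: dict, status_code: int = 200) -> None:
        self.json_data = json_data
        self.status_code = status_code

    def json(self) -> dict:
        return self.json_data


def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
    # Return a canned OpenAI-style chat.completion payload regardless of input.
    return MockResponse(
        {
            "id": "chat-123456789",
            "object": "chat.completion",
            "choices": [
                {
                    "index": 0,
                    "message": {"role": "assistant", "content": "This is a completion."},
                }
            ],
        }
    )


with mock.patch("requests.post", side_effect=mocked_requests_post):
    response = requests.post("https://oci.endpoint/ocid/predict", json={"prompt": "This is a prompt."})
    assert response.json()["object"] == "chat.completion"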
@@ -139,7 +137,7 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_COMPLETION_ROUTE}
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
@@ -154,7 +152,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_COMPLETION_ROUTE}
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
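The streaming test above accumulates chunks from llm.stream(). A hedged sketch of the chunk-aggregation idiom it presumably exercises, assuming langchain_core is installed; the loop body is illustrative, not copied from the test.

# Hedged sketch of aggregating streamed chunks into one message.
from langchain_core.messages import AIMessageChunk

chunks = [AIMessageChunk(content="This is "), AIMessageChunk(content="a completion.")]
output = None
count = 0
for chunk in chunks:
    output = chunk if output is None else output + chunk
    count += 1
assert count == 2
assert output.content == "This is a completion."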
@@ -193,7 +191,7 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
+    assert llm.headers == {"route": CONST_COMPLETION_ROUTE}
     with mock.patch.object(
         llm,
         "_aiter_sse",