Commit cc66c42 (1 parent: 238b6d9)

Commit message: Updated pr.

File tree

2 files changed: +14 lines, -0 lines


tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py

7 additions, 0 deletions
@@ -10,6 +10,9 @@
 from typing import Any, AsyncGenerator, Dict, Generator
 from unittest import mock

+from ads.llm.langchain.plugins.chat_models.oci_data_science import (
+    DEFAULT_INFERENCE_ENDPOINT_CHAT,
+)
 import pytest


@@ -126,6 +129,7 @@ def test_invoke_vllm(*args: Any) -> None:
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


 @pytest.mark.requires("ads")
@@ -138,6 +142,7 @@ def test_invoke_tgi(*args: Any) -> None:
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


 @pytest.mark.requires("ads")
@@ -162,6 +167,7 @@ def test_stream_vllm(*args: Any) -> None:
     assert output is not None
     if output is not None:
         assert str(output.content).strip() == CONST_COMPLETION
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


 async def mocked_async_streaming_response(
@@ -194,3 +200,4 @@ async def test_stream_async(*args: Any) -> None:
     ):
         chunks = [str(chunk.content) async for chunk in llm.astream(CONST_PROMPT)]
         assert "".join(chunks).strip() == CONST_COMPLETION
+        assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
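The assertions added above pin down the default routing header the chat models attach to every request. Below is a minimal self-contained sketch of the pattern being tested; the stub class, the endpoint URL, and the route value "/v1/chat/completions" are illustrative assumptions, since the diff only confirms the {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT} header shape:

# Hypothetical stand-in for the ADS chat model class; only the header shape
# {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT} is confirmed by the diff above.
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"  # assumed value


class ChatClientStub:
    """Toy client that attaches the default chat route as a request header."""

    def __init__(self, endpoint: str) -> None:
        self.endpoint = endpoint  # placeholder model deployment URL
        self.headers = {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


def test_default_chat_route_header() -> None:
    llm = ChatClientStub(endpoint="https://example.com/model-deployment")
    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}

Asserting the header in each invoke/stream test guards against a refactor silently dropping the route, which would send requests to the wrong path on the deployment.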

tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py

7 additions, 0 deletions
@@ -10,6 +10,9 @@
 from typing import Any, AsyncGenerator, Dict, Generator
 from unittest import mock

+from ads.llm.langchain.plugins.llms.oci_data_science_model_deployment_endpoint import (
+    DEFAULT_INFERENCE_ENDPOINT,
+)
 import pytest

 if sys.version_info < (3, 9):
@@ -118,6 +121,7 @@ def test_invoke_vllm(*args: Any) -> None:
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}


 @pytest.mark.requires("ads")
@@ -135,6 +139,7 @@ def test_stream_tgi(*args: Any) -> None:
         count += 1
     assert count == 4
     assert output.strip() == CONST_COMPLETION
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}


 @pytest.mark.requires("ads")
@@ -147,6 +152,7 @@ def test_generate_tgi(*args: Any) -> None:
     )
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}


 @pytest.mark.asyncio
@@ -170,3 +176,4 @@ async def test_stream_async(*args: Any) -> None:
     ):
         chunks = [chunk async for chunk in llm.astream(CONST_PROMPT)]
         assert "".join(chunks).strip() == CONST_COMPLETION
+        assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
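The completion-style models get the same treatment with DEFAULT_INFERENCE_ENDPOINT in place of the chat constant. A brief sketch of the mirrored pattern, again with assumed values (the route "/v1/completions", the stub class, and the model name are illustrative; the diff only confirms the header shape):

# Hypothetical stand-in for the completion models; only the header shape
# {"route": DEFAULT_INFERENCE_ENDPOINT} is confirmed by the diff above.
DEFAULT_INFERENCE_ENDPOINT = "/v1/completions"  # assumed value


class CompletionClientStub:
    """Toy client that attaches the default completions route as a header."""

    def __init__(self, endpoint: str, model: str) -> None:
        self.endpoint = endpoint  # placeholder model deployment URL
        self.model = model
        self.headers = {"route": DEFAULT_INFERENCE_ENDPOINT}


def test_default_completion_route_header() -> None:
    llm = CompletionClientStub(
        endpoint="https://example.com/model-deployment",
        model="odsc-llm",  # placeholder; the real tests use CONST_MODEL_NAME
    )
    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}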
