Commit 960ed5b

Updated pr.
1 parent cc66c42 commit 960ed5b

File tree

2 files changed: +8 -8 lines changed


tests/unitary/with_extras/langchain/chat_models/test_oci_data_science.py

Lines changed: 4 additions & 4 deletions
@@ -126,10 +126,10 @@ def mocked_requests_post(url: str, **kwargs: Any) -> MockResponse:
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


 @pytest.mark.requires("ads")
@@ -139,10 +139,10 @@ def test_invoke_vllm(*args: Any) -> None:
 def test_invoke_tgi(*args: Any) -> None:
     """Tests invoking TGI endpoint using OpenAI Spec."""
     llm = ChatOCIModelDeploymentTGI(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
     output = llm.invoke(CONST_PROMPT)
     assert isinstance(output, AIMessage)
     assert output.content == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


 @pytest.mark.requires("ads")
@@ -154,6 +154,7 @@ def test_stream_vllm(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
     output = None
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
@@ -167,7 +168,6 @@ def test_stream_vllm(*args: Any) -> None:
     assert output is not None
     if output is not None:
         assert str(output.content).strip() == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}


 async def mocked_async_streaming_response(
@@ -193,11 +193,11 @@ async def test_stream_async(*args: Any) -> None:
     llm = ChatOCIModelDeploymentVLLM(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}
     with mock.patch.object(
         llm,
         "_aiter_sse",
         mock.MagicMock(return_value=mocked_async_streaming_response()),
     ):
         chunks = [str(chunk.content) async for chunk in llm.astream(CONST_PROMPT)]
     assert "".join(chunks).strip() == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}

tests/unitary/with_extras/langchain/llms/test_oci_model_deployment_endpoint.py

Lines changed: 4 additions & 4 deletions
@@ -119,9 +119,9 @@ async def mocked_async_streaming_response(
 def test_invoke_vllm(*args: Any) -> None:
     """Tests invoking vLLM endpoint."""
     llm = OCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}


 @pytest.mark.requires("ads")
@@ -132,14 +132,14 @@ def test_stream_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
     output = ""
     count = 0
     for chunk in llm.stream(CONST_PROMPT):
         output += chunk
         count += 1
     assert count == 4
     assert output.strip() == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}


 @pytest.mark.requires("ads")
@@ -150,9 +150,9 @@ def test_generate_tgi(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, api="/generate", model=CONST_MODEL_NAME
     )
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
     output = llm.invoke(CONST_PROMPT)
     assert output == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}


 @pytest.mark.asyncio
@@ -169,11 +169,11 @@ async def test_stream_async(*args: Any) -> None:
     llm = OCIModelDeploymentTGI(
         endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME, streaming=True
     )
+    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
     with mock.patch.object(
         llm,
         "_aiter_sse",
         mock.MagicMock(return_value=mocked_async_streaming_response()),
     ):
         chunks = [chunk async for chunk in llm.astream(CONST_PROMPT)]
     assert "".join(chunks).strip() == CONST_COMPLETION
-    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT}
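The pattern is the same in both files: the `route` header assertion now runs immediately after the client is constructed, before any mocked request is made. Below is a minimal standalone sketch of that ordering; the import path, endpoint URL, model name, and route value are assumptions used only for illustration and should mirror whatever the actual test modules import.

# Minimal sketch of the assertion ordering used in this commit: verify the
# default "route" header right after constructing the client, before invoke().
# Import path and constant values below are assumptions for illustration only.
from ads.llm import ChatOCIModelDeploymentVLLM  # assumed import path

CONST_ENDPOINT = "https://modeldeployment.oci.example.com/predict"  # hypothetical
CONST_MODEL_NAME = "odsc-llm"  # hypothetical
DEFAULT_INFERENCE_ENDPOINT_CHAT = "/v1/chat/completions"  # assumed route value


def test_headers_set_on_construction() -> None:
    """The route header should already be set before invoke() is ever called."""
    llm = ChatOCIModelDeploymentVLLM(endpoint=CONST_ENDPOINT, model=CONST_MODEL_NAME)
    assert llm.headers == {"route": DEFAULT_INFERENCE_ENDPOINT_CHAT}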
