Skip to content

Commit cf9ca93

Browse files
committed
Updated pr.
1 parent ed0e207 commit cf9ca93

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

tests/unitary/with_extras/langchain/embeddings/test_oci_model_deployment_endpoint.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,49 @@
66

77
"""Test OCI Data Science Model Deployment Endpoint."""
88

9-
import responses
10-
from pytest_mock import MockerFixture
9+
from unittest.mock import MagicMock, patch
1110
from ads.llm import OCIModelDeploymentEndpointEmbeddings
1211

1312

14-
@responses.activate
15-
def test_embedding_call(mocker: MockerFixture) -> None:
13+
@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry")
14+
def test_embed_documents(mock_embed_with_retry) -> None:
1615
"""Test valid call to oci model deployment endpoint."""
17-
endpoint = "https://MD_OCID/predict"
18-
documents = ["Hello", "World"]
1916
expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
20-
responses.add(
21-
responses.POST,
22-
endpoint,
23-
json={
17+
result = MagicMock()
18+
result.json = MagicMock(
19+
return_value={
2420
"data": [{"embedding": expected_output}],
25-
},
26-
status=200,
21+
}
2722
)
28-
mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None))
23+
mock_embed_with_retry.return_value = result
24+
endpoint = "https://MD_OCID/predict"
25+
documents = ["Hello", "World"]
2926

3027
embeddings = OCIModelDeploymentEndpointEmbeddings(
3128
endpoint=endpoint,
3229
)
3330

3431
output = embeddings.embed_documents(documents)
3532
assert output == expected_output
33+
34+
35+
@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry")
36+
def test_embed_query(mock_embed_with_retry) -> None:
37+
"""Test valid call to oci model deployment endpoint."""
38+
expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
39+
result = MagicMock()
40+
result.json = MagicMock(
41+
return_value={
42+
"data": [{"embedding": expected_output}],
43+
}
44+
)
45+
mock_embed_with_retry.return_value = result
46+
endpoint = "https://MD_OCID/predict"
47+
query = "Hello world"
48+
49+
embeddings = OCIModelDeploymentEndpointEmbeddings(
50+
endpoint=endpoint,
51+
)
52+
53+
output = embeddings.embed_query(query)
54+
assert output == expected_output[0]

0 commit comments

Comments
 (0)