|
6 | 6 |
|
7 | 7 | """Test OCI Data Science Model Deployment Endpoint."""
|
8 | 8 |
|
9 |
| -import responses |
10 |
| -from pytest_mock import MockerFixture |
| 9 | +from unittest.mock import MagicMock, patch |
11 | 10 | from ads.llm import OCIModelDeploymentEndpointEmbeddings
|
12 | 11 |
|
13 | 12 |
|
14 |
| -@responses.activate |
15 |
| -def test_embedding_call(mocker: MockerFixture) -> None: |
| 13 | +@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") |
| 14 | +def test_embed_documents(mock_embed_with_retry) -> None: |
16 | 15 | """Test valid call to oci model deployment endpoint."""
|
17 |
| - endpoint = "https://MD_OCID/predict" |
18 |
| - documents = ["Hello", "World"] |
19 | 16 | expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
|
20 |
| - responses.add( |
21 |
| - responses.POST, |
22 |
| - endpoint, |
23 |
| - json={ |
| 17 | + result = MagicMock() |
| 18 | + result.json = MagicMock( |
| 19 | + return_value={ |
24 | 20 | "data": [{"embedding": expected_output}],
|
25 |
| - }, |
26 |
| - status=200, |
| 21 | + } |
27 | 22 | )
|
28 |
| - mocker.patch("ads.common.auth.default_signer", return_value=dict(signer=None)) |
| 23 | + mock_embed_with_retry.return_value = result |
| 24 | + endpoint = "https://MD_OCID/predict" |
| 25 | + documents = ["Hello", "World"] |
29 | 26 |
|
30 | 27 | embeddings = OCIModelDeploymentEndpointEmbeddings(
|
31 | 28 | endpoint=endpoint,
|
32 | 29 | )
|
33 | 30 |
|
34 | 31 | output = embeddings.embed_documents(documents)
|
35 | 32 | assert output == expected_output
|
| 33 | + |
| 34 | + |
| 35 | +@patch("ads.llm.OCIModelDeploymentEndpointEmbeddings._embed_with_retry") |
| 36 | +def test_embed_query(mock_embed_with_retry) -> None: |
| 37 | + """Test valid call to oci model deployment endpoint.""" |
| 38 | + expected_output = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]] |
| 39 | + result = MagicMock() |
| 40 | + result.json = MagicMock( |
| 41 | + return_value={ |
| 42 | + "data": [{"embedding": expected_output}], |
| 43 | + } |
| 44 | + ) |
| 45 | + mock_embed_with_retry.return_value = result |
| 46 | + endpoint = "https://MD_OCID/predict" |
| 47 | + query = "Hello world" |
| 48 | + |
| 49 | + embeddings = OCIModelDeploymentEndpointEmbeddings( |
| 50 | + endpoint=endpoint, |
| 51 | + ) |
| 52 | + |
| 53 | + output = embeddings.embed_query(query) |
| 54 | + assert output == expected_output[0] |
0 commit comments