|
1 | 1 | #!/usr/bin/env python
|
| 2 | +from typing import Union |
2 | 3 |
|
3 | 4 | # Copyright (c) 2024 Oracle and/or its affiliates.
|
4 | 5 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
|
7 | 8 | load_data,
|
8 | 9 | )
|
9 | 10 | from ads.opctl.operator.common.operator_config import InputData
|
10 |
| -from unittest.mock import patch, Mock |
| 11 | +from unittest.mock import patch, Mock, MagicMock |
11 | 12 | import unittest
|
| 13 | +import pandas as pd |
| 14 | + |
| 15 | +mock_secret = { |
| 16 | + 'user_name': 'mock_user', |
| 17 | + 'password': 'mock_password', |
| 18 | + 'service_name': 'mock_service_name' |
| 19 | +} |
| 20 | + |
| 21 | +mock_connect_args = { |
| 22 | + 'user': 'mock_user', |
| 23 | + 'password': 'mock_password', |
| 24 | + 'service_name': 'mock_service_name', |
| 25 | + 'dsn': 'mock_dsn' |
| 26 | +} |
| 27 | + |
| 28 | +# Mock data for testing |
| 29 | +mock_data = pd.DataFrame({ |
| 30 | + 'id': [1, 2, 3], |
| 31 | + 'name': ['Alice', 'Bob', 'Charlie'] |
| 32 | +}) |
| 33 | + |
| 34 | +mock_db_connection = MagicMock() |
| 35 | + |
| 36 | +load_secret_err_msg = "Vault exception message" |
| 37 | +db_connect_err_msg = "Mocked DB connection error" |
| 38 | + |
| 39 | + |
| 40 | +def mock_oracledb_connect_failure(*args, **kwargs): |
| 41 | + raise Exception(db_connect_err_msg) |
| 42 | + |
| 43 | + |
| 44 | +def mock_oracledb_connect(**kwargs): |
| 45 | + assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}" |
| 46 | + return mock_db_connection |
| 47 | + |
| 48 | + |
| 49 | +class MockADBSecretKeeper: |
| 50 | + @staticmethod |
| 51 | + def __enter__(*args, **kwargs): |
| 52 | + return mock_secret |
| 53 | + |
| 54 | + @staticmethod |
| 55 | + def __exit__(*args, **kwargs): |
| 56 | + pass |
| 57 | + |
| 58 | + @staticmethod |
| 59 | + def load_secret(vault_secret_id, wallet_dir): |
| 60 | + return MockADBSecretKeeper() |
| 61 | + |
| 62 | + @staticmethod |
| 63 | + def load_secret_fail(*args, **kwargs): |
| 64 | + raise Exception(load_secret_err_msg) |
12 | 65 |
|
13 | 66 |
|
14 | 67 | class TestDataLoad(unittest.TestCase):
|
15 | 68 | def setUp(self):
|
16 | 69 | self.data_spec = Mock(spec=InputData)
|
17 | 70 | self.data_spec.connect_args = {
|
18 |
| - 'dsn': '(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=q9tjyjeyzhxqwla_h8posa0j7hooatry_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=yes)))', |
19 |
| - 'wallet_password': '@Varsha1' |
| 71 | + 'dsn': 'mock_dsn' |
20 | 72 | }
|
21 |
| - self.data_spec.vault_secret_id = 'ocid1.vaultsecret.oc1.iad.amaaaaaav66vvnialgpfay4ys5shd6y5nu4f2tn2e3qius2s23adzipuyhqq' |
22 |
| - self.data_spec.table_name = 'DF_SALARY' |
| 73 | + self.data_spec.vault_secret_id = 'mock_secret_id' |
| 74 | + self.data_spec.table_name = 'mock_table_name' |
23 | 75 | self.data_spec.url = None
|
24 | 76 | self.data_spec.format = None
|
25 | 77 | self.data_spec.columns = None
|
26 | 78 | self.data_spec.limit = None
|
27 | 79 |
|
28 | 80 | def testLoadSecretAndDBConnection(self):
|
29 |
| - data = load_data(self.data_spec) |
30 |
| - assert len(data) == 135, f"Expected length 135, but got {len(data)}" |
31 |
| - expected_columns = ['CODE', 'PAY_MONTH', 'FIXED_SAL'] |
32 |
| - assert list( |
33 |
| - data.columns) == expected_columns, f"Expected columns {expected_columns}, but got {list(data.columns)}" |
| 81 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): |
| 82 | + with patch('oracledb.connect', side_effect=mock_oracledb_connect): |
| 83 | + with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql: |
| 84 | + data = load_data(self.data_spec) |
| 85 | + mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}", |
| 86 | + mock_db_connection) |
| 87 | + pd.testing.assert_frame_equal(data, mock_data) |
34 | 88 |
|
35 | 89 | def testLoadVaultFailure(self):
|
36 |
| - msg = "Vault exception message" |
37 |
| - |
38 |
| - def mock_load_secret(*args, **kwargs): |
39 |
| - raise Exception(msg) |
40 |
| - |
41 |
| - with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=mock_load_secret): |
| 90 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail): |
42 | 91 | with pytest.raises(Exception) as e:
|
43 | 92 | load_data(self.data_spec)
|
44 | 93 |
|
45 |
| - expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {msg}" |
| 94 | + expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}" |
46 | 95 | assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'"
|
47 | 96 |
|
48 | 97 | def testDBConnectionFailure(self):
|
49 |
| - msg = "Mocked DB connection error" |
50 |
| - |
51 |
| - def mock_oracledb_connect(*args, **kwargs): |
52 |
| - raise Exception(msg) |
53 |
| - |
54 |
| - with patch('oracledb.connect', side_effect=mock_oracledb_connect): |
55 |
| - with pytest.raises(Exception) as e: |
56 |
| - load_data(self.data_spec) |
| 98 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret): |
| 99 | + with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure): |
| 100 | + with pytest.raises(Exception) as e: |
| 101 | + load_data(self.data_spec) |
57 | 102 |
|
58 |
| - assert str(e.value) == msg, f"Expected exception message '{msg}', but got '{str(e)}'" |
| 103 | + assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'" |
0 commit comments