Skip to content

Commit 1cbb681

Browse files
committed
added unit tests
1 parent d39a87f commit 1cbb681

File tree

1 file changed

+71
-26
lines changed

1 file changed

+71
-26
lines changed
Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#!/usr/bin/env python
2+
from typing import Union
23

34
# Copyright (c) 2024 Oracle and/or its affiliates.
45
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
@@ -7,52 +8,96 @@
78
load_data,
89
)
910
from ads.opctl.operator.common.operator_config import InputData
10-
from unittest.mock import patch, Mock
11+
from unittest.mock import patch, Mock, MagicMock
1112
import unittest
13+
import pandas as pd
14+
15+
mock_secret = {
16+
'user_name': 'mock_user',
17+
'password': 'mock_password',
18+
'service_name': 'mock_service_name'
19+
}
20+
21+
mock_connect_args = {
22+
'user': 'mock_user',
23+
'password': 'mock_password',
24+
'service_name': 'mock_service_name',
25+
'dsn': 'mock_dsn'
26+
}
27+
28+
# Mock data for testing
29+
mock_data = pd.DataFrame({
30+
'id': [1, 2, 3],
31+
'name': ['Alice', 'Bob', 'Charlie']
32+
})
33+
34+
mock_db_connection = MagicMock()
35+
36+
load_secret_err_msg = "Vault exception message"
37+
db_connect_err_msg = "Mocked DB connection error"
38+
39+
40+
def mock_oracledb_connect_failure(*args, **kwargs):
41+
raise Exception(db_connect_err_msg)
42+
43+
44+
def mock_oracledb_connect(**kwargs):
45+
assert kwargs == mock_connect_args, f"Expected connect_args {mock_connect_args}, but got {kwargs}"
46+
return mock_db_connection
47+
48+
49+
class MockADBSecretKeeper:
50+
@staticmethod
51+
def __enter__(*args, **kwargs):
52+
return mock_secret
53+
54+
@staticmethod
55+
def __exit__(*args, **kwargs):
56+
pass
57+
58+
@staticmethod
59+
def load_secret(vault_secret_id, wallet_dir):
60+
return MockADBSecretKeeper()
61+
62+
@staticmethod
63+
def load_secret_fail(*args, **kwargs):
64+
raise Exception(load_secret_err_msg)
1265

1366

1467
class TestDataLoad(unittest.TestCase):
1568
def setUp(self):
1669
self.data_spec = Mock(spec=InputData)
1770
self.data_spec.connect_args = {
18-
'dsn': '(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=q9tjyjeyzhxqwla_h8posa0j7hooatry_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=yes)))',
19-
'wallet_password': '@Varsha1'
71+
'dsn': 'mock_dsn'
2072
}
21-
self.data_spec.vault_secret_id = 'ocid1.vaultsecret.oc1.iad.amaaaaaav66vvnialgpfay4ys5shd6y5nu4f2tn2e3qius2s23adzipuyhqq'
22-
self.data_spec.table_name = 'DF_SALARY'
73+
self.data_spec.vault_secret_id = 'mock_secret_id'
74+
self.data_spec.table_name = 'mock_table_name'
2375
self.data_spec.url = None
2476
self.data_spec.format = None
2577
self.data_spec.columns = None
2678
self.data_spec.limit = None
2779

2880
def testLoadSecretAndDBConnection(self):
29-
data = load_data(self.data_spec)
30-
assert len(data) == 135, f"Expected length 135, but got {len(data)}"
31-
expected_columns = ['CODE', 'PAY_MONTH', 'FIXED_SAL']
32-
assert list(
33-
data.columns) == expected_columns, f"Expected columns {expected_columns}, but got {list(data.columns)}"
81+
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret):
82+
with patch('oracledb.connect', side_effect=mock_oracledb_connect):
83+
with patch('pandas.read_sql', return_value=mock_data) as mock_read_sql:
84+
data = load_data(self.data_spec)
85+
mock_read_sql.assert_called_once_with(f"SELECT * FROM {self.data_spec.table_name}",
86+
mock_db_connection)
87+
pd.testing.assert_frame_equal(data, mock_data)
3488

3589
def testLoadVaultFailure(self):
36-
msg = "Vault exception message"
37-
38-
def mock_load_secret(*args, **kwargs):
39-
raise Exception(msg)
40-
41-
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=mock_load_secret):
90+
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret_fail):
4291
with pytest.raises(Exception) as e:
4392
load_data(self.data_spec)
4493

45-
expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {msg}"
94+
expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {load_secret_err_msg}"
4695
assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'"
4796

4897
def testDBConnectionFailure(self):
49-
msg = "Mocked DB connection error"
50-
51-
def mock_oracledb_connect(*args, **kwargs):
52-
raise Exception(msg)
53-
54-
with patch('oracledb.connect', side_effect=mock_oracledb_connect):
55-
with pytest.raises(Exception) as e:
56-
load_data(self.data_spec)
98+
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=MockADBSecretKeeper.load_secret):
99+
with patch('oracledb.connect', side_effect=mock_oracledb_connect_failure):
100+
with pytest.raises(Exception) as e:
101+
load_data(self.data_spec)
57102

58-
assert str(e.value) == msg, f"Expected exception message '{msg}', but got '{str(e)}'"
103+
assert str(e.value) == db_connect_err_msg , f"Expected exception message '{db_connect_err_msg }', but got '{str(e)}'"

0 commit comments

Comments
 (0)