Skip to content

Commit d39a87f

Browse files
committed
added unit tests
1 parent 8bc81a8 commit d39a87f

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

ads/opctl/operator/lowcode/common/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def load_data(data_spec, storage_options=None, **kwargs):
9696
connect_args['service_name'] = adwsecret['service_name']
9797

9898
except Exception as e:
99-
logger.debug(f"Could not retrieve database credentials from vault : {e}")
99+
raise Exception(f"Could not retrieve database credentials from vault {vault_secret_id}: {e}")
100100

101101
con = oracledb.connect(**connect_args)
102102
if table_name is not None:
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2024 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
import pytest
6+
from ads.opctl.operator.lowcode.common.utils import (
7+
load_data,
8+
)
9+
from ads.opctl.operator.common.operator_config import InputData
10+
from unittest.mock import patch, Mock
11+
import unittest
12+
13+
14+
class TestDataLoad(unittest.TestCase):
15+
def setUp(self):
16+
self.data_spec = Mock(spec=InputData)
17+
self.data_spec.connect_args = {
18+
'dsn': '(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=q9tjyjeyzhxqwla_h8posa0j7hooatry_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=yes)))',
19+
'wallet_password': '@Varsha1'
20+
}
21+
self.data_spec.vault_secret_id = 'ocid1.vaultsecret.oc1.iad.amaaaaaav66vvnialgpfay4ys5shd6y5nu4f2tn2e3qius2s23adzipuyhqq'
22+
self.data_spec.table_name = 'DF_SALARY'
23+
self.data_spec.url = None
24+
self.data_spec.format = None
25+
self.data_spec.columns = None
26+
self.data_spec.limit = None
27+
28+
def testLoadSecretAndDBConnection(self):
29+
data = load_data(self.data_spec)
30+
assert len(data) == 135, f"Expected length 135, but got {len(data)}"
31+
expected_columns = ['CODE', 'PAY_MONTH', 'FIXED_SAL']
32+
assert list(
33+
data.columns) == expected_columns, f"Expected columns {expected_columns}, but got {list(data.columns)}"
34+
35+
def testLoadVaultFailure(self):
36+
msg = "Vault exception message"
37+
38+
def mock_load_secret(*args, **kwargs):
39+
raise Exception(msg)
40+
41+
with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=mock_load_secret):
42+
with pytest.raises(Exception) as e:
43+
load_data(self.data_spec)
44+
45+
expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {msg}"
46+
assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'"
47+
48+
def testDBConnectionFailure(self):
49+
msg = "Mocked DB connection error"
50+
51+
def mock_oracledb_connect(*args, **kwargs):
52+
raise Exception(msg)
53+
54+
with patch('oracledb.connect', side_effect=mock_oracledb_connect):
55+
with pytest.raises(Exception) as e:
56+
load_data(self.data_spec)
57+
58+
assert str(e.value) == msg, f"Expected exception message '{msg}', but got '{str(e)}'"

0 commit comments

Comments
 (0)