|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Copyright (c) 2024 Oracle and/or its affiliates. |
| 4 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 5 | +import pytest |
| 6 | +from ads.opctl.operator.lowcode.common.utils import ( |
| 7 | + load_data, |
| 8 | +) |
| 9 | +from ads.opctl.operator.common.operator_config import InputData |
| 10 | +from unittest.mock import patch, Mock |
| 11 | +import unittest |
| 12 | + |
| 13 | + |
| 14 | +class TestDataLoad(unittest.TestCase): |
| 15 | + def setUp(self): |
| 16 | + self.data_spec = Mock(spec=InputData) |
| 17 | + self.data_spec.connect_args = { |
| 18 | + 'dsn': '(description= (retry_count=20)(retry_delay=3)(address=(protocol=tcps)(port=1522)(host=adb.us-ashburn-1.oraclecloud.com))(connect_data=(service_name=q9tjyjeyzhxqwla_h8posa0j7hooatry_high.adb.oraclecloud.com))(security=(ssl_server_dn_match=yes)))', |
| 19 | + 'wallet_password': '@Varsha1' |
| 20 | + } |
| 21 | + self.data_spec.vault_secret_id = 'ocid1.vaultsecret.oc1.iad.amaaaaaav66vvnialgpfay4ys5shd6y5nu4f2tn2e3qius2s23adzipuyhqq' |
| 22 | + self.data_spec.table_name = 'DF_SALARY' |
| 23 | + self.data_spec.url = None |
| 24 | + self.data_spec.format = None |
| 25 | + self.data_spec.columns = None |
| 26 | + self.data_spec.limit = None |
| 27 | + |
| 28 | + def testLoadSecretAndDBConnection(self): |
| 29 | + data = load_data(self.data_spec) |
| 30 | + assert len(data) == 135, f"Expected length 135, but got {len(data)}" |
| 31 | + expected_columns = ['CODE', 'PAY_MONTH', 'FIXED_SAL'] |
| 32 | + assert list( |
| 33 | + data.columns) == expected_columns, f"Expected columns {expected_columns}, but got {list(data.columns)}" |
| 34 | + |
| 35 | + def testLoadVaultFailure(self): |
| 36 | + msg = "Vault exception message" |
| 37 | + |
| 38 | + def mock_load_secret(*args, **kwargs): |
| 39 | + raise Exception(msg) |
| 40 | + |
| 41 | + with patch('ads.secrets.ADBSecretKeeper.load_secret', side_effect=mock_load_secret): |
| 42 | + with pytest.raises(Exception) as e: |
| 43 | + load_data(self.data_spec) |
| 44 | + |
| 45 | + expected_msg = f"Could not retrieve database credentials from vault {self.data_spec.vault_secret_id}: {msg}" |
| 46 | + assert str(e.value) == expected_msg, f"Expected exception message '{expected_msg}', but got '{str(e)}'" |
| 47 | + |
| 48 | + def testDBConnectionFailure(self): |
| 49 | + msg = "Mocked DB connection error" |
| 50 | + |
| 51 | + def mock_oracledb_connect(*args, **kwargs): |
| 52 | + raise Exception(msg) |
| 53 | + |
| 54 | + with patch('oracledb.connect', side_effect=mock_oracledb_connect): |
| 55 | + with pytest.raises(Exception) as e: |
| 56 | + load_data(self.data_spec) |
| 57 | + |
| 58 | + assert str(e.value) == msg, f"Expected exception message '{msg}', but got '{str(e)}'" |
0 commit comments