|
| 1 | +import unittest |
| 2 | +from unittest.mock import patch, MagicMock, ANY |
| 3 | +import dlt # dlt.resource and dlt.source are used |
| 4 | +from pendulum import datetime as pendulum_datetime |
| 5 | + |
| 6 | +# Assuming __init__.py functions are importable |
| 7 | +from ingestr.src.stripe_analytics import stripe_source, incremental_stripe_source |
| 8 | + |
| 9 | +# Mock stripe module at the top level of the test file |
| 10 | +# This will be used by the @patch decorator on the class |
| 11 | +stripe_mock = MagicMock() |
| 12 | + |
| 13 | +@patch('ingestr.src.stripe_analytics.stripe', new=stripe_mock) |
| 14 | +class TestStripeAnalyticsSources(unittest.TestCase): |
| 15 | + |
| 16 | + def setUp(self): |
| 17 | + # Reset mocks for each test to prevent interference |
| 18 | + stripe_mock.reset_mock() |
| 19 | + |
| 20 | + def _test_endpoint_full_refresh(self, endpoint_name, stripe_object_name=None): |
| 21 | + if stripe_object_name is None: |
| 22 | + stripe_object_name = endpoint_name |
| 23 | + |
| 24 | + mock_list_method = getattr(stripe_mock, stripe_object_name).list |
| 25 | + mock_list_method.return_value = {"data": [{"id": f"id_{endpoint_name.lower()}_123"}], "has_more": False} |
| 26 | + |
| 27 | + resources = list(stripe_source(endpoints=(endpoint_name,), stripe_secret_key="sk_test_123")) |
| 28 | + |
| 29 | + self.assertEqual(len(resources), 1) |
| 30 | + resource = resources[0] |
| 31 | + self.assertEqual(resource.name, endpoint_name) |
| 32 | + # Note: dlt.resource objects don't directly expose write_disposition in a public way after creation. |
| 33 | + # The write_disposition is passed to dlt.resource decorator. We are testing that the correct |
| 34 | + # dlt_source function is called, which internally uses the correct write_disposition. |
| 35 | + # For the purpose of these tests, we confirm the correct source function was used. |
| 36 | + # If direct assertion is needed, it might require inspecting dlt internals or how it's applied. |
| 37 | + # Here, we are implicitly testing it by calling stripe_source which should use 'replace'. |
| 38 | + mock_list_method.assert_called_once_with(limit=100, created=None, starting_after=None) |
| 39 | + |
| 40 | + def _test_endpoint_incremental(self, endpoint_name, stripe_object_name=None, initial_start_date_fixture=pendulum_datetime(2020,1,1)): |
| 41 | + if stripe_object_name is None: |
| 42 | + stripe_object_name = endpoint_name |
| 43 | + |
| 44 | + mock_list_method = getattr(stripe_mock, stripe_object_name).list |
| 45 | + # Incremental loads often filter by 'created' |
| 46 | + mock_list_method.return_value = {"data": [{"id": f"id_{endpoint_name.lower()}_123", "created": 1620000000}], "has_more": False} |
| 47 | + |
| 48 | + # For incremental, we also need to mock dlt.sources.incremental |
| 49 | + # The incremental decorator is applied to the 'created' argument of the 'incremental_resource' inner function |
| 50 | + # So we need to patch it where it's actually used by dlt when the resource is iterated. |
| 51 | + # A simpler way is to trust that dlt.sources.incremental works as specified and that |
| 52 | + # our pagination function receives the correct 'created' arguments. |
| 53 | + # The pagination function itself is mocked via stripe.ObjectName.list. |
| 54 | + |
| 55 | + # We test if the initial_value for 'created' is passed correctly. |
| 56 | + # The actual filtering by 'created' happens inside the mocked 'pagination' (via stripe.ObjectName.list) |
| 57 | + |
| 58 | + resources = list(incremental_stripe_source(endpoints=(endpoint_name,), stripe_secret_key="sk_test_123", initial_start_date=initial_start_date_fixture)) |
| 59 | + |
| 60 | + self.assertEqual(len(resources), 1) |
| 61 | + resource = resources[0] |
| 62 | + self.assertEqual(resource.name, endpoint_name) |
| 63 | + # Implicitly testing write_disposition="append" by calling incremental_stripe_source. |
| 64 | + |
| 65 | + # Check if the first call to pagination (mocked by stripe.ObjectName.list) |
| 66 | + # received a 'created' argument reflecting the initial_start_date. |
| 67 | + expected_start_timestamp = int(initial_start_date_fixture.timestamp()) if initial_start_date_fixture else -1 |
| 68 | + |
| 69 | + # The actual call to stripe.XXX.list happens inside the 'pagination' helper, |
| 70 | + # which receives the 'created' value from the dlt.incremental decorator. |
| 71 | + # We are checking the arguments to the mocked stripe.XXX.list method. |
| 72 | + # The 'created' dict for date range is passed to stripe.XXX.list |
| 73 | + if initial_start_date_fixture == -1: # Special case for default initial_value |
| 74 | + mock_list_method.assert_called_once_with(limit=100, created={'gte': -1}, starting_after=None) |
| 75 | + else: |
| 76 | + mock_list_method.assert_called_once_with(limit=100, created={'gte': expected_start_timestamp}, starting_after=None) |
| 77 | + |
| 78 | + |
| 79 | + # --- Tests for Newly Added Endpoints --- |
| 80 | + def test_application_fee_endpoint_incremental(self): |
| 81 | + self._test_endpoint_incremental("ApplicationFee") |
| 82 | + |
| 83 | + def test_dispute_endpoint_full_refresh(self): |
| 84 | + self._test_endpoint_full_refresh("Dispute") |
| 85 | + |
| 86 | + def test_subscription_item_endpoint_full_refresh(self): |
| 87 | + self._test_endpoint_full_refresh("SubscriptionItem") |
| 88 | + |
| 89 | + def test_checkout_session_endpoint_full_refresh(self): |
| 90 | + # Checkout.Session is accessed via stripe.checkout.Session |
| 91 | + stripe_mock.checkout.Session.list.return_value = {"data": [{"id": "cs_123"}], "has_more": False} |
| 92 | + |
| 93 | + resources = list(stripe_source(endpoints=("CheckoutSession",), stripe_secret_key="sk_test_123")) |
| 94 | + |
| 95 | + self.assertEqual(len(resources), 1) |
| 96 | + resource = resources[0] |
| 97 | + self.assertEqual(resource.name, "CheckoutSession") |
| 98 | + stripe_mock.checkout.Session.list.assert_called_once_with(limit=100, created=None, starting_after=None) |
| 99 | + |
| 100 | + def test_credit_note_endpoint_incremental(self): |
| 101 | + self._test_endpoint_incremental("CreditNote") |
| 102 | + |
| 103 | + def test_customer_balance_transaction_endpoint_incremental(self): |
| 104 | + self._test_endpoint_incremental("CustomerBalanceTransaction") |
| 105 | + |
| 106 | + def test_setup_attempt_endpoint_incremental(self): |
| 107 | + # SetupAttempt list method requires setup_intent, but our generic pagination doesn't support that directly. |
| 108 | + # The source code currently calls stripe.SetupAttempt.list(...) without setup_intent if called via incremental_stripe_source. |
| 109 | + # This test will reflect that behavior. If specific params are needed, the source or test needs adjustment. |
| 110 | + self._test_endpoint_incremental("SetupAttempt", initial_start_date_fixture=-1) # Default initial value is -1 |
| 111 | + |
| 112 | + def test_shipping_rate_endpoint_full_refresh(self): |
| 113 | + self._test_endpoint_full_refresh("ShippingRate") |
| 114 | + |
| 115 | + # --- Tests for Pre-existing Endpoints --- |
| 116 | + def test_charge_endpoint_full_refresh(self): # Charge is in INCREMENTAL_ENDPOINTS now |
| 117 | + self._test_endpoint_incremental("Charge") |
| 118 | + |
| 119 | + def test_event_endpoint_incremental(self): |
| 120 | + self._test_endpoint_incremental("Event") |
| 121 | + |
| 122 | + def test_customer_endpoint_full_refresh(self): |
| 123 | + self._test_endpoint_full_refresh("Customer") |
| 124 | + |
| 125 | + def test_subscription_endpoint_full_refresh(self): |
| 126 | + self._test_endpoint_full_refresh("Subscription") |
| 127 | + |
| 128 | + def test_invoice_endpoint_incremental(self): |
| 129 | + self._test_endpoint_incremental("Invoice") |
| 130 | + |
| 131 | + # Test for ApplicationFeeRefund (Type C - Nested) |
| 132 | + @patch('ingestr.src.stripe_analytics.pagination') |
| 133 | + def test_application_fee_refund_endpoint_full_refresh(self, mock_pagination): |
| 134 | + # Mock the parent resource (ApplicationFee) pagination |
| 135 | + mock_pagination.side_effect = [ |
| 136 | + iter([{"id": "fee_1", "created": 1620000000}, {"id": "fee_2", "created": 1620000001}]), # For ApplicationFee |
| 137 | + # The pagination for refunds themselves is handled by stripe.ApplicationFee.list_refunds |
| 138 | + ] |
| 139 | + |
| 140 | + # Mock the nested list_refunds call |
| 141 | + stripe_mock.ApplicationFee.list_refunds.side_effect = [ |
| 142 | + {"data": [{"id": "fr_fee1_1", "fee": "fee_1", "created": 1620000000}], "has_more": False}, # Refunds for fee_1 |
| 143 | + {"data": [{"id": "fr_fee2_1", "fee": "fee_2", "created": 1620000001}], "has_more": False}, # Refunds for fee_2 |
| 144 | + ] |
| 145 | + |
| 146 | + resources = list(stripe_source(endpoints=("ApplicationFeeRefund",), stripe_secret_key="sk_test_123")) |
| 147 | + |
| 148 | + self.assertEqual(len(resources), 1) |
| 149 | + resource = resources[0] |
| 150 | + self.assertEqual(resource.name, "ApplicationFeeRefund") |
| 151 | + |
| 152 | + # Verify pagination was called for ApplicationFee |
| 153 | + mock_pagination.assert_any_call("ApplicationFee", None, None) |
| 154 | + |
| 155 | + # Verify list_refunds was called for each application fee |
| 156 | + self.assertEqual(stripe_mock.ApplicationFee.list_refunds.call_count, 2) |
| 157 | + stripe_mock.ApplicationFee.list_refunds.assert_any_call("fee_1", limit=100, starting_after=None) |
| 158 | + stripe_mock.ApplicationFee.list_refunds.assert_any_call("fee_2", limit=100, starting_after=None) |
| 159 | + |
| 160 | + # Collect data from the resource to ensure refunds are yielded |
| 161 | + refund_data = list(resource) |
| 162 | + self.assertEqual(len(refund_data), 2) |
| 163 | + self.assertEqual(refund_data[0]["id"], "fr_fee1_1") |
| 164 | + self.assertEqual(refund_data[1]["id"], "fr_fee2_1") |
| 165 | + |
| 166 | + |
| 167 | +if __name__ == '__main__': |
| 168 | + unittest.main() |
0 commit comments