diff --git a/tests/test_disable_discovery.py b/tests/test_disable_discovery.py new file mode 100644 index 00000000..17e49c72 --- /dev/null +++ b/tests/test_disable_discovery.py @@ -0,0 +1,313 @@ +import pytest +import unittest.mock +import ydb +import asyncio +from ydb import _apis + + +TEST_ERROR = "Test error" +TEST_QUERY = "SELECT 1 + 2 AS sum" + + +@pytest.fixture +def mock_connection(): + """Mock a YDB connection to avoid actual connections.""" + with unittest.mock.patch("ydb.connection.Connection.ready_factory") as mock_factory: + # Setup the mock to return a connection-like object + mock_connection = unittest.mock.MagicMock() + # Use the endpoint fixture value via the function parameter + mock_connection.endpoint = "localhost:2136" # Will be overridden in tests + mock_connection.node_id = "mock_node_id" + mock_factory.return_value = mock_connection + yield mock_factory + + +@pytest.fixture +def mock_aio_connection(): + """Mock a YDB async connection to avoid actual connections.""" + with unittest.mock.patch("ydb.aio.connection.Connection.__init__") as mock_init: + # Setup the mock to return None (as __init__ does) + mock_init.return_value = None + + # Mock connection_ready method + with unittest.mock.patch("ydb.aio.connection.Connection.connection_ready") as mock_ready: + # Create event loop if there isn't one currently + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + future = asyncio.Future() + future.set_result(None) + mock_ready.return_value = future + yield mock_init + + +def create_mock_discovery_resolver(path): + """Create a mock discovery resolver that raises exception if called.""" + + def _mock_fixture(): + with unittest.mock.patch(path) as mock_resolve: + # Configure mock to throw an exception if called + mock_resolve.side_effect = Exception("Discovery should not be executed when discovery is disabled") + yield mock_resolve + + return _mock_fixture + + +# Mock discovery resolvers to verify no discovery requests are made +mock_discovery_resolver = pytest.fixture( + create_mock_discovery_resolver("ydb.resolver.DiscoveryEndpointsResolver.context_resolve") +) +mock_aio_discovery_resolver = pytest.fixture( + create_mock_discovery_resolver("ydb.aio.resolver.DiscoveryEndpointsResolver.resolve") +) + + +# Basic unit tests for DriverConfig +def test_driver_config_has_disable_discovery_option(endpoint, database): + """Test that DriverConfig has the disable_discovery option.""" + config = ydb.DriverConfig(endpoint=endpoint, database=database, disable_discovery=True) + assert hasattr(config, "disable_discovery") + assert config.disable_discovery is True + + +# Driver config fixtures +def create_driver_config(endpoint, database, disable_discovery): + """Create a driver config with the given discovery setting.""" + return ydb.DriverConfig( + endpoint=endpoint, + database=database, + disable_discovery=disable_discovery, + ) + + +@pytest.fixture +def driver_config_disabled_discovery(endpoint, database): + """A driver config with discovery disabled""" + return create_driver_config(endpoint, database, True) + + +@pytest.fixture +def driver_config_enabled_discovery(endpoint, database): + """A driver config with discovery enabled (default)""" + return create_driver_config(endpoint, database, False) + + +# Common test assertions +def assert_discovery_disabled(driver): + """Assert that discovery is disabled in the driver.""" + assert "Discovery is disabled" in driver.discovery_debug_details() + + +def create_future_with_error(): + """Create a future with a test error.""" + future = asyncio.Future() + future.set_exception(ydb.issues.Error(TEST_ERROR)) + return future + + +def create_completed_future(): + """Create a completed future.""" + future = asyncio.Future() + future.set_result(None) + return future + + +# Mock tests for synchronous driver +def test_sync_driver_discovery_disabled_mock( + driver_config_disabled_discovery, mock_connection, mock_discovery_resolver +): + """Test that when disable_discovery=True, the discovery thread is not started and resolver is not called (mock).""" + with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: + driver = ydb.Driver(driver_config=driver_config_disabled_discovery) + + try: + # Check that the discovery thread was not created + mock_discovery_class.assert_not_called() + + # Check that discovery is disabled in debug details + assert_discovery_disabled(driver) + + # Execute a dummy call to verify no discovery requests are made + try: + driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation") + except ydb.issues.Error: + pass # Expected exception, we just want to ensure no discovery occurs + + # Verify the mock wasn't called + assert ( + not mock_discovery_resolver.called + ), "Discovery resolver should not be called when discovery is disabled" + finally: + # Clean up + driver.stop() + + +def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection): + """Test that when disable_discovery=False, the discovery thread is started (mock).""" + with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class: + mock_discovery_instance = unittest.mock.MagicMock() + mock_discovery_class.return_value = mock_discovery_instance + + driver = ydb.Driver(driver_config=driver_config_enabled_discovery) + + try: + # Check that the discovery thread was created and started + mock_discovery_class.assert_called_once() + assert mock_discovery_instance.start.called + finally: + # Clean up + driver.stop() + + +# Helper for setting up async driver test mocks +def setup_async_driver_mocks(): + """Set up common mocks for async driver tests.""" + mocks = {} + + # Create mock for Discovery class + discovery_patcher = unittest.mock.patch("ydb.aio.pool.Discovery") + mocks["mock_discovery_class"] = discovery_patcher.start() + + # Mock the event loop + loop_patcher = unittest.mock.patch("asyncio.get_event_loop") + mock_loop = loop_patcher.start() + mock_loop_instance = unittest.mock.MagicMock() + mock_loop.return_value = mock_loop_instance + mock_loop_instance.create_task.return_value = unittest.mock.MagicMock() + mocks["mock_loop"] = mock_loop + + # Mock the connection pool stop method + stop_patcher = unittest.mock.patch("ydb.aio.pool.ConnectionPool.stop") + mock_stop = stop_patcher.start() + mock_stop.return_value = create_completed_future() + mocks["mock_stop"] = mock_stop + + # Add cleanup for all patchers + mocks["patchers"] = [discovery_patcher, loop_patcher, stop_patcher] + + return mocks + + +def teardown_async_mocks(mocks): + """Clean up all mock patchers.""" + for patcher in mocks["patchers"]: + patcher.stop() + + +# Mock tests for asynchronous driver +@pytest.mark.asyncio +async def test_aio_driver_discovery_disabled_mock( + driver_config_disabled_discovery, mock_aio_connection, mock_aio_discovery_resolver +): + """Test that when disable_discovery=True, the discovery is not created and resolver is not called (mock).""" + mocks = setup_async_driver_mocks() + + try: + # Mock the pool's call method to prevent unhandled exceptions + with unittest.mock.patch("ydb.aio.pool.ConnectionPool.__call__") as mock_call: + mock_call.return_value = create_future_with_error() + + driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery) + + try: + # Check that the discovery class was not instantiated + mocks["mock_discovery_class"].assert_not_called() + + # Check that discovery is disabled in debug details + assert_discovery_disabled(driver) + + # Execute a dummy call to verify no discovery requests are made + try: + try: + await driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation") + except ydb.issues.Error: + pass # Expected exception, we just want to ensure no discovery occurs + except Exception as e: + if "discovery is disabled" in str(e).lower(): + raise # If the error is related to discovery being disabled, re-raise it + pass # Other exceptions are expected as we're using mocks + + # Verify the mock wasn't called + assert ( + not mock_aio_discovery_resolver.called + ), "Discovery resolver should not be called when discovery is disabled" + finally: + # The stop method is already mocked, so we don't need to await it + pass + finally: + teardown_async_mocks(mocks) + + +@pytest.mark.asyncio +async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection): + """Test that when disable_discovery=False, the discovery is created (mock).""" + mocks = setup_async_driver_mocks() + + try: + mock_discovery_instance = unittest.mock.MagicMock() + mocks["mock_discovery_class"].return_value = mock_discovery_instance + + driver = ydb.aio.Driver(driver_config=driver_config_enabled_discovery) + + try: + # Check that the discovery class was instantiated + mocks["mock_discovery_class"].assert_called_once() + assert driver is not None # Use the driver variable to avoid F841 + finally: + # The stop method is already mocked, so we don't need to await it + pass + finally: + teardown_async_mocks(mocks) + + +# Common integration test logic +def perform_integration_test_checks(driver, is_async=False): + """Common assertions for integration tests.""" + assert_discovery_disabled(driver) + + +# Integration tests with real YDB +def test_integration_disable_discovery(driver_config_disabled_discovery): + """Integration test that tests the disable_discovery feature with a real YDB container.""" + # Create driver with discovery disabled + driver = ydb.Driver(driver_config=driver_config_disabled_discovery) + try: + driver.wait(timeout=15) + perform_integration_test_checks(driver) + + # Try to execute a simple query to ensure it works with discovery disabled + with ydb.SessionPool(driver) as pool: + + def query_callback(session): + result_sets = session.transaction().execute(TEST_QUERY, commit_tx=True) + assert len(result_sets) == 1 + assert result_sets[0].rows[0].sum == 3 + + pool.retry_operation_sync(query_callback) + finally: + driver.stop(timeout=10) + + +@pytest.mark.asyncio +async def test_integration_aio_disable_discovery(driver_config_disabled_discovery): + """Integration test that tests the disable_discovery feature with a real YDB container (async).""" + # Create driver with discovery disabled + driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery) + try: + await driver.wait(timeout=15) + perform_integration_test_checks(driver, is_async=True) + + # Try to execute a simple query to ensure it works with discovery disabled + session_pool = ydb.aio.SessionPool(driver, size=10) + + async def query_callback(session): + result_sets = await session.transaction().execute(TEST_QUERY, commit_tx=True) + assert len(result_sets) == 1 + assert result_sets[0].rows[0].sum == 3 + + try: + await session_pool.retry_operation(query_callback) + finally: + await session_pool.stop() + finally: + await driver.stop(timeout=10) diff --git a/ydb/aio/pool.py b/ydb/aio/pool.py index c8fbb904..99a3cfdb 100644 --- a/ydb/aio/pool.py +++ b/ydb/aio/pool.py @@ -199,12 +199,27 @@ def __init__(self, driver_config): self._store = ConnectionsCache(driver_config.use_all_nodes) self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config) self._stopped = False - self._discovery = Discovery(self._store, self._driver_config) - self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run()) + if driver_config.disable_discovery: + # If discovery is disabled, just add the initial endpoint to the store + async def init_connection(): + ready_connection = Connection(self._driver_config.endpoint, self._driver_config) + await ready_connection.connection_ready( + ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10) + ) + self._store.add(ready_connection) + + # Create and schedule the task to initialize the connection + self._discovery = None + self._discovery_task = asyncio.get_event_loop().create_task(init_connection()) + else: + # Start discovery as usual + self._discovery = Discovery(self._store, self._driver_config) + self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run()) async def stop(self, timeout=10): - self._discovery.stop() + if self._discovery: + self._discovery.stop() await self._grpc_init.close() try: await asyncio.wait_for(self._discovery_task, timeout=timeout) @@ -215,7 +230,8 @@ async def stop(self, timeout=10): def _on_disconnected(self, connection): async def __wrapper__(): await connection.close() - self._discovery.notify_disconnected() + if self._discovery: + self._discovery.notify_disconnected() return __wrapper__ @@ -223,7 +239,9 @@ async def wait(self, timeout=7, fail_fast=False): await self._store.get(fast_fail=fail_fast, wait_timeout=timeout) def discovery_debug_details(self): - return self._discovery.discovery_debug_details() + if self._discovery: + return self._discovery.discovery_debug_details() + return "Discovery is disabled, using only the initial endpoint" async def __aenter__(self): return self @@ -248,7 +266,8 @@ async def __call__( try: connection = await self._store.get(preferred_endpoint, fast_fail=fast_fail, wait_timeout=wait_timeout) except BaseException: - self._discovery.notify_disconnected() + if self._discovery: + self._discovery.notify_disconnected() raise return await connection( diff --git a/ydb/driver.py b/ydb/driver.py index 3998aeee..09d531d0 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -96,6 +96,7 @@ class DriverConfig(object): "grpc_lb_policy_name", "discovery_request_timeout", "compression", + "disable_discovery", ) def __init__( @@ -120,6 +121,7 @@ def __init__( grpc_lb_policy_name="round_robin", discovery_request_timeout=10, compression=None, + disable_discovery=False, ): """ A driver config to initialize a driver instance @@ -140,6 +142,7 @@ def __init__( If tracing aio ScopeManager must be ContextVarsScopeManager :param grpc_lb_policy_name: A load balancing policy to be used for discovery channel construction. Default value is `round_round` :param discovery_request_timeout: A default timeout to complete the discovery. The default value is 10 seconds. + :param disable_discovery: If True, endpoint discovery is disabled and only the start endpoint is used for all requests. """ self.endpoint = endpoint @@ -167,6 +170,7 @@ def __init__( self.grpc_lb_policy_name = grpc_lb_policy_name self.discovery_request_timeout = discovery_request_timeout self.compression = compression + self.disable_discovery = disable_discovery def set_database(self, database): self.database = database diff --git a/ydb/pool.py b/ydb/pool.py index 1e75950e..476ea674 100644 --- a/ydb/pool.py +++ b/ydb/pool.py @@ -350,8 +350,21 @@ def __init__(self, driver_config): self._store = ConnectionsCache(driver_config.use_all_nodes, driver_config.tracer) self.tracer = driver_config.tracer self._grpc_init = connection_impl.Connection(self._driver_config.endpoint, self._driver_config) - self._discovery_thread = Discovery(self._store, self._driver_config) - self._discovery_thread.start() + + if driver_config.disable_discovery: + # If discovery is disabled, just add the initial endpoint to the store + ready_connection = connection_impl.Connection.ready_factory( + self._driver_config.endpoint, + self._driver_config, + ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10), + ) + self._store.add(ready_connection) + self._discovery_thread = None + else: + # Start discovery thread as usual + self._discovery_thread = Discovery(self._store, self._driver_config) + self._discovery_thread.start() + self._stopped = False self._stop_guard = threading.Lock() @@ -367,9 +380,11 @@ def stop(self, timeout=10): return self._stopped = True - self._discovery_thread.stop() + if self._discovery_thread: + self._discovery_thread.stop() self._grpc_init.close() - self._discovery_thread.join(timeout) + if self._discovery_thread: + self._discovery_thread.join(timeout) def async_wait(self, fail_fast=False): """ @@ -404,7 +419,13 @@ def _on_disconnected(self, connection): self._discovery_thread.notify_disconnected() def discovery_debug_details(self): - return self._discovery_thread.discovery_debug_details() + """ + Returns debug string about last errors + :return: str + """ + if self._discovery_thread: + return self._discovery_thread.discovery_debug_details() + return "Discovery is disabled, using only the initial endpoint" @tracing.with_trace() def __call__(