Skip to content

Commit cab0e4f

Browse files
committed
ydb-platform#325: add disable_discovery option to DriverConfig
1 parent 44f7f6d commit cab0e4f

File tree

4 files changed

+360
-11
lines changed

4 files changed

+360
-11
lines changed

tests/test_disable_discovery.py

Lines changed: 307 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,307 @@
1+
import pytest
2+
import unittest.mock
3+
import ydb
4+
import asyncio
5+
from ydb import _apis
6+
7+
8+
# Common test constants and mock config
9+
DISCOVERY_DISABLED_ERROR_MSG = "Discovery should not be executed when discovery is disabled"
10+
TEST_ERROR = "Test error"
11+
TEST_QUERY = "SELECT 1 + 2 AS sum"
12+
13+
14+
@pytest.fixture
15+
def mock_connection():
16+
"""Mock a YDB connection to avoid actual connections."""
17+
with unittest.mock.patch('ydb.connection.Connection.ready_factory') as mock_factory:
18+
# Setup the mock to return a connection-like object
19+
mock_connection = unittest.mock.MagicMock()
20+
# Use the endpoint fixture value via the function parameter
21+
mock_connection.endpoint = "localhost:2136" # Will be overridden in tests
22+
mock_connection.node_id = "mock_node_id"
23+
mock_factory.return_value = mock_connection
24+
yield mock_factory
25+
26+
27+
@pytest.fixture
28+
def mock_aio_connection():
29+
"""Mock a YDB async connection to avoid actual connections."""
30+
with unittest.mock.patch('ydb.aio.connection.Connection.__init__') as mock_init:
31+
# Setup the mock to return None (as __init__ does)
32+
mock_init.return_value = None
33+
34+
# Mock connection_ready method
35+
with unittest.mock.patch('ydb.aio.connection.Connection.connection_ready') as mock_ready:
36+
# Create event loop if there isn't one currently
37+
loop = asyncio.new_event_loop()
38+
asyncio.set_event_loop(loop)
39+
40+
future = asyncio.Future()
41+
future.set_result(None)
42+
mock_ready.return_value = future
43+
yield mock_init
44+
45+
46+
def create_mock_discovery_resolver(path):
47+
"""Create a mock discovery resolver that raises exception if called."""
48+
def _mock_fixture():
49+
with unittest.mock.patch(path) as mock_resolve:
50+
# Configure mock to throw an exception if called
51+
mock_resolve.side_effect = Exception(DISCOVERY_DISABLED_ERROR_MSG)
52+
yield mock_resolve
53+
return _mock_fixture
54+
55+
56+
# Mock discovery resolvers to verify no discovery requests are made
57+
mock_discovery_resolver = pytest.fixture(create_mock_discovery_resolver('ydb.resolver.DiscoveryEndpointsResolver.context_resolve'))
58+
mock_aio_discovery_resolver = pytest.fixture(create_mock_discovery_resolver('ydb.aio.resolver.DiscoveryEndpointsResolver.resolve'))
59+
60+
61+
# We'll use the fixtures from conftest.py instead of these mock fixtures
62+
# However, we'll keep them for tests that don't need the real YDB container
63+
64+
65+
# Basic unit tests for DriverConfig
66+
def test_driver_config_has_disable_discovery_option(endpoint, database):
67+
"""Test that DriverConfig has the disable_discovery option."""
68+
config = ydb.DriverConfig(
69+
endpoint=endpoint,
70+
database=database,
71+
disable_discovery=True
72+
)
73+
assert hasattr(config, "disable_discovery")
74+
assert config.disable_discovery is True
75+
76+
77+
# Driver config fixtures
78+
def create_driver_config(endpoint, database, disable_discovery):
79+
"""Create a driver config with the given discovery setting."""
80+
return ydb.DriverConfig(
81+
endpoint=endpoint,
82+
database=database,
83+
disable_discovery=disable_discovery,
84+
)
85+
86+
87+
@pytest.fixture
88+
def driver_config_disabled_discovery(endpoint, database):
89+
"""A driver config with discovery disabled"""
90+
return create_driver_config(endpoint, database, True)
91+
92+
93+
@pytest.fixture
94+
def driver_config_enabled_discovery(endpoint, database):
95+
"""A driver config with discovery enabled (default)"""
96+
return create_driver_config(endpoint, database, False)
97+
98+
99+
# Common test assertions
100+
def assert_discovery_disabled(driver):
101+
"""Assert that discovery is disabled in the driver."""
102+
assert "Discovery is disabled" in driver.discovery_debug_details()
103+
104+
105+
def create_future_with_error():
106+
"""Create a future with a test error."""
107+
future = asyncio.Future()
108+
future.set_exception(ydb.issues.Error(TEST_ERROR))
109+
return future
110+
111+
112+
def create_completed_future():
113+
"""Create a completed future."""
114+
future = asyncio.Future()
115+
future.set_result(None)
116+
return future
117+
118+
119+
# Mock tests for synchronous driver
120+
def test_sync_driver_discovery_disabled_mock(driver_config_disabled_discovery, mock_connection, mock_discovery_resolver):
121+
"""Test that when disable_discovery=True, the discovery thread is not started and resolver is not called (mock)."""
122+
with unittest.mock.patch('ydb.pool.Discovery') as mock_discovery_class:
123+
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
124+
125+
try:
126+
# Check that the discovery thread was not created
127+
mock_discovery_class.assert_not_called()
128+
129+
# Check that discovery is disabled in debug details
130+
assert_discovery_disabled(driver)
131+
132+
# Execute a dummy call to verify no discovery requests are made
133+
try:
134+
driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation")
135+
except ydb.issues.Error:
136+
pass # Expected exception, we just want to ensure no discovery occurs
137+
138+
# Verify the mock wasn't called
139+
assert not mock_discovery_resolver.called, "Discovery resolver should not be called when discovery is disabled"
140+
finally:
141+
# Clean up
142+
driver.stop()
143+
144+
145+
def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection):
146+
"""Test that when disable_discovery=False, the discovery thread is started (mock)."""
147+
with unittest.mock.patch('ydb.pool.Discovery') as mock_discovery_class:
148+
mock_discovery_instance = unittest.mock.MagicMock()
149+
mock_discovery_class.return_value = mock_discovery_instance
150+
151+
driver = ydb.Driver(driver_config=driver_config_enabled_discovery)
152+
153+
try:
154+
# Check that the discovery thread was created and started
155+
mock_discovery_class.assert_called_once()
156+
assert mock_discovery_instance.start.called
157+
finally:
158+
# Clean up
159+
driver.stop()
160+
161+
162+
# Helper for setting up async driver test mocks
163+
def setup_async_driver_mocks():
164+
"""Set up common mocks for async driver tests."""
165+
mocks = {}
166+
167+
# Create mock for Discovery class
168+
discovery_patcher = unittest.mock.patch('ydb.aio.pool.Discovery')
169+
mocks['mock_discovery_class'] = discovery_patcher.start()
170+
171+
# Mock the event loop
172+
loop_patcher = unittest.mock.patch('asyncio.get_event_loop')
173+
mock_loop = loop_patcher.start()
174+
mock_loop_instance = unittest.mock.MagicMock()
175+
mock_loop.return_value = mock_loop_instance
176+
mock_loop_instance.create_task.return_value = unittest.mock.MagicMock()
177+
mocks['mock_loop'] = mock_loop
178+
179+
# Mock the connection pool stop method
180+
stop_patcher = unittest.mock.patch('ydb.aio.pool.ConnectionPool.stop')
181+
mock_stop = stop_patcher.start()
182+
mock_stop.return_value = create_completed_future()
183+
mocks['mock_stop'] = mock_stop
184+
185+
# Add cleanup for all patchers
186+
mocks['patchers'] = [discovery_patcher, loop_patcher, stop_patcher]
187+
188+
return mocks
189+
190+
191+
def teardown_async_mocks(mocks):
192+
"""Clean up all mock patchers."""
193+
for patcher in mocks['patchers']:
194+
patcher.stop()
195+
196+
197+
# Mock tests for asynchronous driver
198+
@pytest.mark.asyncio
199+
async def test_aio_driver_discovery_disabled_mock(driver_config_disabled_discovery, mock_aio_connection, mock_aio_discovery_resolver):
200+
"""Test that when disable_discovery=True, the discovery is not created and resolver is not called (mock)."""
201+
mocks = setup_async_driver_mocks()
202+
203+
try:
204+
# Mock the pool's call method to prevent unhandled exceptions
205+
with unittest.mock.patch('ydb.aio.pool.ConnectionPool.__call__') as mock_call:
206+
mock_call.return_value = create_future_with_error()
207+
208+
driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
209+
210+
try:
211+
# Check that the discovery class was not instantiated
212+
mocks['mock_discovery_class'].assert_not_called()
213+
214+
# Check that discovery is disabled in debug details
215+
assert_discovery_disabled(driver)
216+
217+
# Execute a dummy call to verify no discovery requests are made
218+
try:
219+
try:
220+
await driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation")
221+
except ydb.issues.Error:
222+
pass # Expected exception, we just want to ensure no discovery occurs
223+
except Exception as e:
224+
if "discovery is disabled" in str(e).lower():
225+
raise # If the error is related to discovery being disabled, re-raise it
226+
pass # Other exceptions are expected as we're using mocks
227+
228+
# Verify the mock wasn't called
229+
assert not mock_aio_discovery_resolver.called, "Discovery resolver should not be called when discovery is disabled"
230+
finally:
231+
# The stop method is already mocked, so we don't need to await it
232+
pass
233+
finally:
234+
teardown_async_mocks(mocks)
235+
236+
237+
@pytest.mark.asyncio
238+
async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection):
239+
"""Test that when disable_discovery=False, the discovery is created (mock)."""
240+
mocks = setup_async_driver_mocks()
241+
242+
try:
243+
mock_discovery_instance = unittest.mock.MagicMock()
244+
mocks['mock_discovery_class'].return_value = mock_discovery_instance
245+
246+
driver = ydb.aio.Driver(driver_config=driver_config_enabled_discovery)
247+
248+
try:
249+
# Check that the discovery class was instantiated
250+
mocks['mock_discovery_class'].assert_called_once()
251+
finally:
252+
# The stop method is already mocked, so we don't need to await it
253+
pass
254+
finally:
255+
teardown_async_mocks(mocks)
256+
257+
258+
# Common integration test logic
259+
def perform_integration_test_checks(driver, is_async=False):
260+
"""Common assertions for integration tests."""
261+
assert_discovery_disabled(driver)
262+
263+
264+
# Integration tests with real YDB
265+
def test_integration_disable_discovery(driver_config_disabled_discovery):
266+
"""Integration test that tests the disable_discovery feature with a real YDB container."""
267+
# Create driver with discovery disabled
268+
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
269+
try:
270+
driver.wait(timeout=15)
271+
perform_integration_test_checks(driver)
272+
273+
# Try to execute a simple query to ensure it works with discovery disabled
274+
with ydb.SessionPool(driver) as pool:
275+
def query_callback(session):
276+
result_sets = session.transaction().execute(TEST_QUERY, commit_tx=True)
277+
assert len(result_sets) == 1
278+
assert result_sets[0].rows[0].sum == 3
279+
280+
pool.retry_operation_sync(query_callback)
281+
finally:
282+
driver.stop(timeout=10)
283+
284+
285+
@pytest.mark.asyncio
286+
async def test_integration_aio_disable_discovery(driver_config_disabled_discovery):
287+
"""Integration test that tests the disable_discovery feature with a real YDB container (async)."""
288+
# Create driver with discovery disabled
289+
driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
290+
try:
291+
await driver.wait(timeout=15)
292+
perform_integration_test_checks(driver, is_async=True)
293+
294+
# Try to execute a simple query to ensure it works with discovery disabled
295+
session_pool = ydb.aio.SessionPool(driver, size=10)
296+
297+
async def query_callback(session):
298+
result_sets = await session.transaction().execute(TEST_QUERY, commit_tx=True)
299+
assert len(result_sets) == 1
300+
assert result_sets[0].rows[0].sum == 3
301+
302+
try:
303+
await session_pool.retry_operation(query_callback)
304+
finally:
305+
await session_pool.stop()
306+
finally:
307+
await driver.stop(timeout=10)

ydb/aio/pool.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,25 @@ def __init__(self, driver_config):
199199
self._store = ConnectionsCache(driver_config.use_all_nodes)
200200
self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config)
201201
self._stopped = False
202-
self._discovery = Discovery(self._store, self._driver_config)
203202

204-
self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())
203+
if driver_config.disable_discovery:
204+
# If discovery is disabled, just add the initial endpoint to the store
205+
async def init_connection():
206+
ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
207+
await ready_connection.connection_ready(ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10))
208+
self._store.add(ready_connection)
209+
210+
# Create and schedule the task to initialize the connection
211+
self._discovery = None
212+
self._discovery_task = asyncio.get_event_loop().create_task(init_connection())
213+
else:
214+
# Start discovery as usual
215+
self._discovery = Discovery(self._store, self._driver_config)
216+
self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())
205217

206218
async def stop(self, timeout=10):
207-
self._discovery.stop()
219+
if self._discovery:
220+
self._discovery.stop()
208221
await self._grpc_init.close()
209222
try:
210223
await asyncio.wait_for(self._discovery_task, timeout=timeout)
@@ -215,15 +228,18 @@ async def stop(self, timeout=10):
215228
def _on_disconnected(self, connection):
216229
async def __wrapper__():
217230
await connection.close()
218-
self._discovery.notify_disconnected()
231+
if self._discovery:
232+
self._discovery.notify_disconnected()
219233

220234
return __wrapper__
221235

222236
async def wait(self, timeout=7, fail_fast=False):
223237
await self._store.get(fast_fail=fail_fast, wait_timeout=timeout)
224238

225239
def discovery_debug_details(self):
226-
return self._discovery.discovery_debug_details()
240+
if self._discovery:
241+
return self._discovery.discovery_debug_details()
242+
return "Discovery is disabled, using only the initial endpoint"
227243

228244
async def __aenter__(self):
229245
return self
@@ -248,7 +264,8 @@ async def __call__(
248264
try:
249265
connection = await self._store.get(preferred_endpoint, fast_fail=fast_fail, wait_timeout=wait_timeout)
250266
except BaseException:
251-
self._discovery.notify_disconnected()
267+
if self._discovery:
268+
self._discovery.notify_disconnected()
252269
raise
253270

254271
return await connection(

0 commit comments

Comments
 (0)