Skip to content

Commit 4edc179

Browse files
authored
#325: add disable_discovery option to DriverConfig (#666)
* #325: add disable_discovery option to DriverConfig * minor * black * flake8 * black 2
1 parent 44f7f6d commit 4edc179

File tree

4 files changed

+368
-11
lines changed

4 files changed

+368
-11
lines changed

tests/test_disable_discovery.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
import pytest
2+
import unittest.mock
3+
import ydb
4+
import asyncio
5+
from ydb import _apis
6+
7+
8+
TEST_ERROR = "Test error"
9+
TEST_QUERY = "SELECT 1 + 2 AS sum"
10+
11+
12+
@pytest.fixture
13+
def mock_connection():
14+
"""Mock a YDB connection to avoid actual connections."""
15+
with unittest.mock.patch("ydb.connection.Connection.ready_factory") as mock_factory:
16+
# Setup the mock to return a connection-like object
17+
mock_connection = unittest.mock.MagicMock()
18+
# Use the endpoint fixture value via the function parameter
19+
mock_connection.endpoint = "localhost:2136" # Will be overridden in tests
20+
mock_connection.node_id = "mock_node_id"
21+
mock_factory.return_value = mock_connection
22+
yield mock_factory
23+
24+
25+
@pytest.fixture
26+
def mock_aio_connection():
27+
"""Mock a YDB async connection to avoid actual connections."""
28+
with unittest.mock.patch("ydb.aio.connection.Connection.__init__") as mock_init:
29+
# Setup the mock to return None (as __init__ does)
30+
mock_init.return_value = None
31+
32+
# Mock connection_ready method
33+
with unittest.mock.patch("ydb.aio.connection.Connection.connection_ready") as mock_ready:
34+
# Create event loop if there isn't one currently
35+
loop = asyncio.new_event_loop()
36+
asyncio.set_event_loop(loop)
37+
38+
future = asyncio.Future()
39+
future.set_result(None)
40+
mock_ready.return_value = future
41+
yield mock_init
42+
43+
44+
def create_mock_discovery_resolver(path):
45+
"""Create a mock discovery resolver that raises exception if called."""
46+
47+
def _mock_fixture():
48+
with unittest.mock.patch(path) as mock_resolve:
49+
# Configure mock to throw an exception if called
50+
mock_resolve.side_effect = Exception("Discovery should not be executed when discovery is disabled")
51+
yield mock_resolve
52+
53+
return _mock_fixture
54+
55+
56+
# Mock discovery resolvers to verify no discovery requests are made
57+
mock_discovery_resolver = pytest.fixture(
58+
create_mock_discovery_resolver("ydb.resolver.DiscoveryEndpointsResolver.context_resolve")
59+
)
60+
mock_aio_discovery_resolver = pytest.fixture(
61+
create_mock_discovery_resolver("ydb.aio.resolver.DiscoveryEndpointsResolver.resolve")
62+
)
63+
64+
65+
# Basic unit tests for DriverConfig
66+
def test_driver_config_has_disable_discovery_option(endpoint, database):
67+
"""Test that DriverConfig has the disable_discovery option."""
68+
config = ydb.DriverConfig(endpoint=endpoint, database=database, disable_discovery=True)
69+
assert hasattr(config, "disable_discovery")
70+
assert config.disable_discovery is True
71+
72+
73+
# Driver config fixtures
74+
def create_driver_config(endpoint, database, disable_discovery):
75+
"""Create a driver config with the given discovery setting."""
76+
return ydb.DriverConfig(
77+
endpoint=endpoint,
78+
database=database,
79+
disable_discovery=disable_discovery,
80+
)
81+
82+
83+
@pytest.fixture
84+
def driver_config_disabled_discovery(endpoint, database):
85+
"""A driver config with discovery disabled"""
86+
return create_driver_config(endpoint, database, True)
87+
88+
89+
@pytest.fixture
90+
def driver_config_enabled_discovery(endpoint, database):
91+
"""A driver config with discovery enabled (default)"""
92+
return create_driver_config(endpoint, database, False)
93+
94+
95+
# Common test assertions
96+
def assert_discovery_disabled(driver):
97+
"""Assert that discovery is disabled in the driver."""
98+
assert "Discovery is disabled" in driver.discovery_debug_details()
99+
100+
101+
def create_future_with_error():
102+
"""Create a future with a test error."""
103+
future = asyncio.Future()
104+
future.set_exception(ydb.issues.Error(TEST_ERROR))
105+
return future
106+
107+
108+
def create_completed_future():
109+
"""Create a completed future."""
110+
future = asyncio.Future()
111+
future.set_result(None)
112+
return future
113+
114+
115+
# Mock tests for synchronous driver
116+
def test_sync_driver_discovery_disabled_mock(
117+
driver_config_disabled_discovery, mock_connection, mock_discovery_resolver
118+
):
119+
"""Test that when disable_discovery=True, the discovery thread is not started and resolver is not called (mock)."""
120+
with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
121+
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
122+
123+
try:
124+
# Check that the discovery thread was not created
125+
mock_discovery_class.assert_not_called()
126+
127+
# Check that discovery is disabled in debug details
128+
assert_discovery_disabled(driver)
129+
130+
# Execute a dummy call to verify no discovery requests are made
131+
try:
132+
driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation")
133+
except ydb.issues.Error:
134+
pass # Expected exception, we just want to ensure no discovery occurs
135+
136+
# Verify the mock wasn't called
137+
assert (
138+
not mock_discovery_resolver.called
139+
), "Discovery resolver should not be called when discovery is disabled"
140+
finally:
141+
# Clean up
142+
driver.stop()
143+
144+
145+
def test_sync_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_connection):
146+
"""Test that when disable_discovery=False, the discovery thread is started (mock)."""
147+
with unittest.mock.patch("ydb.pool.Discovery") as mock_discovery_class:
148+
mock_discovery_instance = unittest.mock.MagicMock()
149+
mock_discovery_class.return_value = mock_discovery_instance
150+
151+
driver = ydb.Driver(driver_config=driver_config_enabled_discovery)
152+
153+
try:
154+
# Check that the discovery thread was created and started
155+
mock_discovery_class.assert_called_once()
156+
assert mock_discovery_instance.start.called
157+
finally:
158+
# Clean up
159+
driver.stop()
160+
161+
162+
# Helper for setting up async driver test mocks
163+
def setup_async_driver_mocks():
164+
"""Set up common mocks for async driver tests."""
165+
mocks = {}
166+
167+
# Create mock for Discovery class
168+
discovery_patcher = unittest.mock.patch("ydb.aio.pool.Discovery")
169+
mocks["mock_discovery_class"] = discovery_patcher.start()
170+
171+
# Mock the event loop
172+
loop_patcher = unittest.mock.patch("asyncio.get_event_loop")
173+
mock_loop = loop_patcher.start()
174+
mock_loop_instance = unittest.mock.MagicMock()
175+
mock_loop.return_value = mock_loop_instance
176+
mock_loop_instance.create_task.return_value = unittest.mock.MagicMock()
177+
mocks["mock_loop"] = mock_loop
178+
179+
# Mock the connection pool stop method
180+
stop_patcher = unittest.mock.patch("ydb.aio.pool.ConnectionPool.stop")
181+
mock_stop = stop_patcher.start()
182+
mock_stop.return_value = create_completed_future()
183+
mocks["mock_stop"] = mock_stop
184+
185+
# Add cleanup for all patchers
186+
mocks["patchers"] = [discovery_patcher, loop_patcher, stop_patcher]
187+
188+
return mocks
189+
190+
191+
def teardown_async_mocks(mocks):
192+
"""Clean up all mock patchers."""
193+
for patcher in mocks["patchers"]:
194+
patcher.stop()
195+
196+
197+
# Mock tests for asynchronous driver
198+
@pytest.mark.asyncio
199+
async def test_aio_driver_discovery_disabled_mock(
200+
driver_config_disabled_discovery, mock_aio_connection, mock_aio_discovery_resolver
201+
):
202+
"""Test that when disable_discovery=True, the discovery is not created and resolver is not called (mock)."""
203+
mocks = setup_async_driver_mocks()
204+
205+
try:
206+
# Mock the pool's call method to prevent unhandled exceptions
207+
with unittest.mock.patch("ydb.aio.pool.ConnectionPool.__call__") as mock_call:
208+
mock_call.return_value = create_future_with_error()
209+
210+
driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
211+
212+
try:
213+
# Check that the discovery class was not instantiated
214+
mocks["mock_discovery_class"].assert_not_called()
215+
216+
# Check that discovery is disabled in debug details
217+
assert_discovery_disabled(driver)
218+
219+
# Execute a dummy call to verify no discovery requests are made
220+
try:
221+
try:
222+
await driver(ydb.issues.Error(TEST_ERROR), _apis.OperationService.Stub, "GetOperation")
223+
except ydb.issues.Error:
224+
pass # Expected exception, we just want to ensure no discovery occurs
225+
except Exception as e:
226+
if "discovery is disabled" in str(e).lower():
227+
raise # If the error is related to discovery being disabled, re-raise it
228+
pass # Other exceptions are expected as we're using mocks
229+
230+
# Verify the mock wasn't called
231+
assert (
232+
not mock_aio_discovery_resolver.called
233+
), "Discovery resolver should not be called when discovery is disabled"
234+
finally:
235+
# The stop method is already mocked, so we don't need to await it
236+
pass
237+
finally:
238+
teardown_async_mocks(mocks)
239+
240+
241+
@pytest.mark.asyncio
242+
async def test_aio_driver_discovery_enabled_mock(driver_config_enabled_discovery, mock_aio_connection):
243+
"""Test that when disable_discovery=False, the discovery is created (mock)."""
244+
mocks = setup_async_driver_mocks()
245+
246+
try:
247+
mock_discovery_instance = unittest.mock.MagicMock()
248+
mocks["mock_discovery_class"].return_value = mock_discovery_instance
249+
250+
driver = ydb.aio.Driver(driver_config=driver_config_enabled_discovery)
251+
252+
try:
253+
# Check that the discovery class was instantiated
254+
mocks["mock_discovery_class"].assert_called_once()
255+
assert driver is not None # Use the driver variable to avoid F841
256+
finally:
257+
# The stop method is already mocked, so we don't need to await it
258+
pass
259+
finally:
260+
teardown_async_mocks(mocks)
261+
262+
263+
# Common integration test logic
264+
def perform_integration_test_checks(driver, is_async=False):
265+
"""Common assertions for integration tests."""
266+
assert_discovery_disabled(driver)
267+
268+
269+
# Integration tests with real YDB
270+
def test_integration_disable_discovery(driver_config_disabled_discovery):
271+
"""Integration test that tests the disable_discovery feature with a real YDB container."""
272+
# Create driver with discovery disabled
273+
driver = ydb.Driver(driver_config=driver_config_disabled_discovery)
274+
try:
275+
driver.wait(timeout=15)
276+
perform_integration_test_checks(driver)
277+
278+
# Try to execute a simple query to ensure it works with discovery disabled
279+
with ydb.SessionPool(driver) as pool:
280+
281+
def query_callback(session):
282+
result_sets = session.transaction().execute(TEST_QUERY, commit_tx=True)
283+
assert len(result_sets) == 1
284+
assert result_sets[0].rows[0].sum == 3
285+
286+
pool.retry_operation_sync(query_callback)
287+
finally:
288+
driver.stop(timeout=10)
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_integration_aio_disable_discovery(driver_config_disabled_discovery):
293+
"""Integration test that tests the disable_discovery feature with a real YDB container (async)."""
294+
# Create driver with discovery disabled
295+
driver = ydb.aio.Driver(driver_config=driver_config_disabled_discovery)
296+
try:
297+
await driver.wait(timeout=15)
298+
perform_integration_test_checks(driver, is_async=True)
299+
300+
# Try to execute a simple query to ensure it works with discovery disabled
301+
session_pool = ydb.aio.SessionPool(driver, size=10)
302+
303+
async def query_callback(session):
304+
result_sets = await session.transaction().execute(TEST_QUERY, commit_tx=True)
305+
assert len(result_sets) == 1
306+
assert result_sets[0].rows[0].sum == 3
307+
308+
try:
309+
await session_pool.retry_operation(query_callback)
310+
finally:
311+
await session_pool.stop()
312+
finally:
313+
await driver.stop(timeout=10)

ydb/aio/pool.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,27 @@ def __init__(self, driver_config):
199199
self._store = ConnectionsCache(driver_config.use_all_nodes)
200200
self._grpc_init = Connection(self._driver_config.endpoint, self._driver_config)
201201
self._stopped = False
202-
self._discovery = Discovery(self._store, self._driver_config)
203202

204-
self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())
203+
if driver_config.disable_discovery:
204+
# If discovery is disabled, just add the initial endpoint to the store
205+
async def init_connection():
206+
ready_connection = Connection(self._driver_config.endpoint, self._driver_config)
207+
await ready_connection.connection_ready(
208+
ready_timeout=getattr(self._driver_config, "discovery_request_timeout", 10)
209+
)
210+
self._store.add(ready_connection)
211+
212+
# Create and schedule the task to initialize the connection
213+
self._discovery = None
214+
self._discovery_task = asyncio.get_event_loop().create_task(init_connection())
215+
else:
216+
# Start discovery as usual
217+
self._discovery = Discovery(self._store, self._driver_config)
218+
self._discovery_task = asyncio.get_event_loop().create_task(self._discovery.run())
205219

206220
async def stop(self, timeout=10):
207-
self._discovery.stop()
221+
if self._discovery:
222+
self._discovery.stop()
208223
await self._grpc_init.close()
209224
try:
210225
await asyncio.wait_for(self._discovery_task, timeout=timeout)
@@ -215,15 +230,18 @@ async def stop(self, timeout=10):
215230
def _on_disconnected(self, connection):
216231
async def __wrapper__():
217232
await connection.close()
218-
self._discovery.notify_disconnected()
233+
if self._discovery:
234+
self._discovery.notify_disconnected()
219235

220236
return __wrapper__
221237

222238
async def wait(self, timeout=7, fail_fast=False):
223239
await self._store.get(fast_fail=fail_fast, wait_timeout=timeout)
224240

225241
def discovery_debug_details(self):
226-
return self._discovery.discovery_debug_details()
242+
if self._discovery:
243+
return self._discovery.discovery_debug_details()
244+
return "Discovery is disabled, using only the initial endpoint"
227245

228246
async def __aenter__(self):
229247
return self
@@ -248,7 +266,8 @@ async def __call__(
248266
try:
249267
connection = await self._store.get(preferred_endpoint, fast_fail=fast_fail, wait_timeout=wait_timeout)
250268
except BaseException:
251-
self._discovery.notify_disconnected()
269+
if self._discovery:
270+
self._discovery.notify_disconnected()
252271
raise
253272

254273
return await connection(

0 commit comments

Comments
 (0)