diff --git a/tests/query/test_query_client_settings.py b/tests/query/test_query_client_settings.py new file mode 100644 index 00000000..030cd83d --- /dev/null +++ b/tests/query/test_query_client_settings.py @@ -0,0 +1,143 @@ +import pytest + +import ydb + +from datetime import date, datetime, timedelta + + +@pytest.fixture +def settings_on(): + settings = ( + ydb.QueryClientSettings() + .with_native_date_in_result_sets(True) + .with_native_datetime_in_result_sets(True) + .with_native_timestamp_in_result_sets(True) + .with_native_interval_in_result_sets(True) + .with_native_json_in_result_sets(True) + ) + return settings + + +@pytest.fixture +def settings_off(): + settings = ( + ydb.QueryClientSettings() + .with_native_date_in_result_sets(False) + .with_native_datetime_in_result_sets(False) + .with_native_timestamp_in_result_sets(False) + .with_native_interval_in_result_sets(False) + .with_native_json_in_result_sets(False) + ) + return settings + + +params = pytest.mark.parametrize( + "value,ydb_type,casted_result,uncasted_result", + [ + (365, "Date", date(1971, 1, 1), 365), + (3600 * 24 * 365, "Datetime", datetime(1971, 1, 1, 0, 0), 3600 * 24 * 365), + (timedelta(seconds=1), "Interval", timedelta(seconds=1), 1000000), + ( + 1511789040123456, + "Timestamp", + datetime.fromisoformat("2017-11-27 13:24:00.123456"), + 1511789040123456, + ), + ('{"foo": "bar"}', "Json", {"foo": "bar"}, '{"foo": "bar"}'), + ('{"foo": "bar"}', "JsonDocument", {"foo": "bar"}, '{"foo":"bar"}'), + ], +) + + +class TestQueryClientSettings: + @params + def test_driver_turn_on(self, driver_sync, settings_on, value, ydb_type, casted_result, uncasted_result): + driver_sync._driver_config.query_client_settings = settings_on + pool = ydb.QuerySessionPool(driver_sync) + result = pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == casted_result + pool.stop() + + @params + def test_driver_turn_off(self, driver_sync, settings_off, value, ydb_type, casted_result, uncasted_result): + driver_sync._driver_config.query_client_settings = settings_off + pool = ydb.QuerySessionPool(driver_sync) + result = pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == uncasted_result + pool.stop() + + @params + def test_session_pool_turn_on(self, driver_sync, settings_on, value, ydb_type, casted_result, uncasted_result): + pool = ydb.QuerySessionPool(driver_sync, query_client_settings=settings_on) + result = pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == casted_result + pool.stop() + + @params + def test_session_pool_turn_off(self, driver_sync, settings_off, value, ydb_type, casted_result, uncasted_result): + pool = ydb.QuerySessionPool(driver_sync, query_client_settings=settings_off) + result = pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == uncasted_result + pool.stop() + + @pytest.mark.asyncio + @params + async def test_driver_async_turn_on(self, driver, settings_on, value, ydb_type, casted_result, uncasted_result): + driver._driver_config.query_client_settings = settings_on + pool = ydb.aio.QuerySessionPool(driver) + result = await pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == casted_result + await pool.stop() + + @pytest.mark.asyncio + @params + async def test_driver_async_turn_off(self, driver, settings_off, value, ydb_type, casted_result, uncasted_result): + driver._driver_config.query_client_settings = settings_off + pool = ydb.aio.QuerySessionPool(driver) + result = await pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == uncasted_result + await pool.stop() + + @pytest.mark.asyncio + @params + async def test_session_pool_async_turn_on( + self, driver, settings_on, value, ydb_type, casted_result, uncasted_result + ): + pool = ydb.aio.QuerySessionPool(driver, query_client_settings=settings_on) + result = await pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == casted_result + await pool.stop() + + @pytest.mark.asyncio + @params + async def test_session_pool_async_turn_off( + self, driver, settings_off, value, ydb_type, casted_result, uncasted_result + ): + pool = ydb.aio.QuerySessionPool(driver, query_client_settings=settings_off) + result = await pool.execute_with_retries( + f"DECLARE $param as {ydb_type}; SELECT $param as value", + {"$param": (value, getattr(ydb.PrimitiveType, ydb_type))}, + ) + assert result[0].rows[0].value == uncasted_result + await pool.stop() diff --git a/ydb/aio/query/pool.py b/ydb/aio/query/pool.py index 6d116600..db01adce 100644 --- a/ydb/aio/query/pool.py +++ b/ydb/aio/query/pool.py @@ -13,6 +13,7 @@ RetrySettings, retry_operation_async, ) +from ...query.base import QueryClientSettings from ... import convert from ..._grpc.grpcwrapper import common_utils @@ -22,10 +23,17 @@ class QuerySessionPool: """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" - def __init__(self, driver: common_utils.SupportedDriverType, size: int = 100): + def __init__( + self, + driver: common_utils.SupportedDriverType, + size: int = 100, + *, + query_client_settings: Optional[QueryClientSettings] = None, + ): """ :param driver: A driver instance :param size: Size of session pool + :param query_client_settings: ydb.QueryClientSettings object to configure QueryService behavior """ self._driver = driver @@ -35,9 +43,10 @@ def __init__(self, driver: common_utils.SupportedDriverType, size: int = 100): self._current_size = 0 self._waiters = 0 self._loop = asyncio.get_running_loop() + self._query_client_settings = query_client_settings async def _create_new_session(self): - session = QuerySession(self._driver) + session = QuerySession(self._driver, settings=self._query_client_settings) await session.create() logger.debug(f"New session was created for pool. Session id: {session._state.session_id}") return session diff --git a/ydb/driver.py b/ydb/driver.py index 1559b0d0..5a9402f6 100644 --- a/ydb/driver.py +++ b/ydb/driver.py @@ -89,6 +89,7 @@ class DriverConfig(object): "secure_channel", "table_client_settings", "topic_client_settings", + "query_client_settings", "endpoints", "primary_user_agent", "tracer", @@ -112,6 +113,7 @@ def __init__( grpc_keep_alive_timeout=None, table_client_settings=None, topic_client_settings=None, + query_client_settings=None, endpoints=None, primary_user_agent="python-library", tracer=None, @@ -159,6 +161,7 @@ def __init__( self.grpc_keep_alive_timeout = grpc_keep_alive_timeout self.table_client_settings = table_client_settings self.topic_client_settings = topic_client_settings + self.query_client_settings = query_client_settings self.primary_user_agent = primary_user_agent self.tracer = tracer if tracer is not None else tracing.Tracer(None) self.grpc_lb_policy_name = grpc_lb_policy_name diff --git a/ydb/query/pool.py b/ydb/query/pool.py index 4c51a971..f1fcd173 100644 --- a/ydb/query/pool.py +++ b/ydb/query/pool.py @@ -8,6 +8,7 @@ import threading import queue +from .base import QueryClientSettings from .session import ( QuerySession, ) @@ -27,10 +28,17 @@ class QuerySessionPool: """QuerySessionPool is an object to simplify operations with sessions of Query Service.""" - def __init__(self, driver: common_utils.SupportedDriverType, size: int = 100): + def __init__( + self, + driver: common_utils.SupportedDriverType, + size: int = 100, + *, + query_client_settings: Optional[QueryClientSettings] = None, + ): """ :param driver: A driver instance. :param size: Max size of Session Pool. + :param query_client_settings: ydb.QueryClientSettings object to configure QueryService behavior """ self._driver = driver @@ -39,9 +47,10 @@ def __init__(self, driver: common_utils.SupportedDriverType, size: int = 100): self._size = size self._should_stop = threading.Event() self._lock = threading.RLock() + self._query_client_settings = query_client_settings def _create_new_session(self, timeout: Optional[float]): - session = QuerySession(self._driver) + session = QuerySession(self._driver, settings=self._query_client_settings) session.create(settings=BaseRequestSettings().with_timeout(timeout)) logger.debug(f"New session was created for pool. Session id: {session._state.session_id}") return session diff --git a/ydb/query/session.py b/ydb/query/session.py index e13540d3..0165f821 100644 --- a/ydb/query/session.py +++ b/ydb/query/session.py @@ -134,9 +134,20 @@ class BaseQuerySession: def __init__(self, driver: common_utils.SupportedDriverType, settings: Optional[base.QueryClientSettings] = None): self._driver = driver - self._settings = settings if settings is not None else base.QueryClientSettings() + self._settings = self._get_client_settings(driver, settings) self._state = QuerySessionState(settings) + def _get_client_settings( + self, + driver: common_utils.SupportedDriverType, + settings: Optional[base.QueryClientSettings] = None, + ) -> base.QueryClientSettings: + if settings is not None: + return settings + if driver._driver_config.query_client_settings is not None: + return driver._driver_config.query_client_settings + return base.QueryClientSettings() + def _create_call(self, settings: Optional[BaseRequestSettings] = None) -> "BaseQuerySession": return self._driver( _apis.ydb_query.CreateSessionRequest(),