Skip to content

Implement ScanQuery #50

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,85 @@ def test_several_keys(self, connection, metadata):
assert desc.partitioning_settings.max_partitions_count == 5


class TestScanQuery(TablesTest):
    """Backend tests for the ``ydb_scan_query`` execution option."""

    __backend__ = True

    @classmethod
    def define_tables(cls, metadata: sa.MetaData):
        """Declare the single-column ``test`` table keyed by ``id``."""
        Table("test", metadata, Column("id", Integer, primary_key=True))

    @classmethod
    def insert_data(cls, connection: sa.Connection):
        """Upsert ids 0..49999 as 50 batches of 1000 rows each."""
        test_table = cls.tables.test
        for batch_start in range(0, 50 * 1000, 1000):
            batch = [{"id": row_id} for row_id in range(batch_start, batch_start + 1000)]
            connection.execute(ydb_sa.upsert(test_table).values(batch))

    def test_characteristic(self):
        """The option is absent by default, visible once set, and gone on the next connect."""
        engine = self.bind.execution_options()

        with engine.connect() as conn:
            before = conn.get_execution_options()

        with engine.connect() as conn:
            conn.execution_options(ydb_scan_query=True)
            during = conn.get_execution_options()

        with engine.connect() as conn:
            after = conn.get_execution_options()

        assert "ydb_scan_query" not in before
        assert during["ydb_scan_query"]
        assert "ydb_scan_query" not in after

    def test_fetchmany(self, connection_no_trans: sa.Connection):
        """``fetchmany`` on a scan-query cursor returns the first 1000 even ids."""
        test_table = self.tables.test
        even_ids = sa.select(test_table).where(test_table.c.id % 2 == 0)

        connection_no_trans.execution_options(ydb_scan_query=True)
        cursor = connection_no_trans.execute(even_ids)

        assert cursor.cursor.use_scan_query
        rows = cursor.fetchmany(1000)  # fetches only the first 5k rows
        assert rows == [(row_id,) for row_id in range(0, 2000, 2)]

    def test_fetchall(self, connection_no_trans: sa.Connection):
        """``fetchall`` on a scan-query cursor returns every even id in the table."""
        test_table = self.tables.test
        even_ids = sa.select(test_table).where(test_table.c.id % 2 == 0)

        connection_no_trans.execution_options(ydb_scan_query=True)
        cursor = connection_no_trans.execute(even_ids)

        assert cursor.cursor.use_scan_query
        rows = cursor.fetchall()
        assert rows == [(row_id,) for row_id in range(0, 50000, 2)]

    def test_begin_does_nothing(self, connection_no_trans: sa.Connection):
        """With scan queries enabled, ``begin()`` leaves the cursor without a tx context."""
        test_table = self.tables.test
        connection_no_trans.execution_options(ydb_scan_query=True)

        with connection_no_trans.begin():
            cursor = connection_no_trans.execute(sa.select(test_table))

        assert cursor.cursor.use_scan_query
        assert cursor.cursor.tx_context is None

    def test_engine_option(self):
        """An engine-level option applies to every connection the engine hands out."""
        test_table = self.tables.test
        engine = self.bind.execution_options(ydb_scan_query=True)

        # Two consecutive checkouts: the option must be in effect both times.
        for _ in range(2):
            with engine.begin() as connection:
                cursor = connection.execute(sa.select(test_table))
                assert cursor.cursor.use_scan_query


class TestTransaction(TablesTest):
__backend__ = True

Expand Down
13 changes: 11 additions & 2 deletions ydb_sqlalchemy/dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,12 @@ def __init__(
self.interactive_transaction: bool = False # AUTOCOMMIT
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
self.tx_context: Optional[ydb.TxContext] = None
self.use_scan_query: bool = False

def cursor(self):
    """Create a cursor bound to this connection's driver, pool and transaction state.

    The arguments are passed positionally in the order the cursor constructor
    declares them: driver, session pool, tx mode, tx context, scan-query flag,
    table path prefix.
    """
    # The stale pre-change ``return`` (without driver / use_scan_query) is
    # removed: it made the call below unreachable and shifted every argument.
    return self._cursor_class(
        self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix
    )

def describe(self, table_path: str) -> ydb.TableDescription:
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
Expand Down Expand Up @@ -115,9 +118,15 @@ def get_isolation_level(self) -> str:
else:
raise NotSupportedError(f"{self.tx_mode.name} is not supported")

def set_ydb_scan_query(self, value: bool) -> None:
    """Store ``value`` as this connection's scan-query flag."""
    self.use_scan_query = value

def get_ydb_scan_query(self) -> bool:
    """Return this connection's current scan-query flag."""
    flag = self.use_scan_query
    return flag

def begin(self):
self.tx_context = None
if self.interactive_transaction:
if self.interactive_transaction and not self.use_scan_query:
session = self._maybe_await(self.session_pool.acquire)
self.tx_context = session.transaction(self.tx_mode)
self._maybe_await(self.tx_context.begin)
Expand Down
59 changes: 57 additions & 2 deletions ydb_sqlalchemy/dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,17 @@
import itertools
import logging
import posixpath
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from collections.abc import AsyncIterator
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Sequence,
Union,
)

import ydb
import ydb.aio
Expand Down Expand Up @@ -77,14 +87,18 @@ def wrapper(*args, **kwargs):
class Cursor:
def __init__(
self,
driver: Union[ydb.Driver, ydb.aio.Driver],
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
tx_mode: ydb.AbstractTransactionModeBuilder,
tx_context: Optional[ydb.BaseTxContext] = None,
use_scan_query: bool = False,
table_path_prefix: str = "",
):
self.driver = driver
self.session_pool = session_pool
self.tx_mode = tx_mode
self.tx_context = tx_context
self.use_scan_query = use_scan_query
self.description = None
self.arraysize = 1
self.rows = None
Expand Down Expand Up @@ -120,6 +134,8 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
logger.info("execute sql: %s, params: %s", query, parameters)
if operation.is_ddl:
chunks = self._execute_ddl(query)
elif self.use_scan_query:
chunks = self._execute_scan_query(query, parameters)
else:
chunks = self._execute_dml(query, parameters)

Expand Down Expand Up @@ -164,6 +180,21 @@ def _make_data_query(
name = hashlib.sha256(yql_with_params.encode("utf-8")).hexdigest()
return ydb.DataQuery(yql_text, parameters_types, name=name)

@_handle_ydb_errors
def _execute_scan_query(
    self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
) -> Generator[ydb.convert.ResultSet, None, None]:
    """Run ``query`` as a YDB scan query and return a generator of result sets.

    A text query that comes with parameters is prepared first; a text query
    without parameters is sent unchanged with no parameter types.

    :param query: YQL text or an already built ``ydb.DataQuery``.
    :param parameters: optional mapping of parameter names to values.
    :return: generator yielding one result set per streamed chunk.
    """
    prepared_query = query
    if isinstance(query, str) and parameters:
        prepared_query = self._retry_operation_in_pool(self._prepare, query)

    # Branch on the *prepared* query. Testing the raw ``query`` here would throw
    # away the result of the prepare step above and send a parameterised text
    # query with ``parameters_types=None``.
    if isinstance(prepared_query, str):
        scan_query = ydb.ScanQuery(prepared_query, None)
    else:
        scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types)

    return self._execute_scan_query_in_driver(scan_query, parameters)

@_handle_ydb_errors
def _execute_dml(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
Expand Down Expand Up @@ -219,6 +250,15 @@ def _execute_in_session(
) -> ydb.convert.ResultSets:
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

def _execute_scan_query_in_driver(
    self,
    scan_query: ydb.ScanQuery,
    parameters: Optional[Mapping[str, Any]],
) -> Generator[ydb.convert.ResultSet, None, None]:
    """Lazily yield the ``result_set`` of each chunk the driver's scan query produces."""
    stream = self.driver.table_client.scan_query(scan_query, parameters)
    yield from (part.result_set for part in stream)

def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
return callee(self.tx_context, *args, **kwargs)

Expand Down Expand Up @@ -264,7 +304,7 @@ def executescript(self, script):
return self.execute(script)

def fetchone(self):
    """Return the next row, or ``None`` when there are no rows (left) to fetch."""
    rows = self.rows
    # ``next`` requires an iterator. The stale ``self.rows or []`` fallback
    # raised TypeError when ``rows`` was None, because a list is not an iterator.
    if not rows:
        return None
    return next(rows, None)

def fetchmany(self, size=None):
    """Return up to ``size`` rows as a list, falling back to ``arraysize`` when ``size`` is falsy."""
    source = self.rows
    limit = size if size else self.arraysize
    return list(itertools.islice(source, limit))
Expand Down Expand Up @@ -328,6 +368,21 @@ async def _execute_in_session(
) -> ydb.convert.ResultSets:
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

def _execute_scan_query_in_driver(
    self,
    scan_query: ydb.ScanQuery,
    parameters: Optional[Mapping[str, Any]],
) -> Generator[ydb.convert.ResultSet, None, None]:
    """Drive the async scan-query stream synchronously, yielding each chunk's ``result_set``.

    Every step of the async iterator is resolved through ``self._await``;
    iteration ends when the stream raises ``StopAsyncIteration``.
    """
    stream: AsyncIterator[ydb.ScanQueryResult] = self._await(
        self.driver.table_client.scan_query(scan_query, parameters)
    )
    try:
        while True:
            page = self._await(stream.__anext__())
            yield page.result_set
    except StopAsyncIteration:
        return

def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
return self._await(callee(self.tx_context, *args, **kwargs))

Expand Down
30 changes: 29 additions & 1 deletion ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

import sqlalchemy as sa
import ydb
from sqlalchemy.engine import reflection
from sqlalchemy import util
from sqlalchemy.engine import characteristics, reflection
from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect
from sqlalchemy.exc import CompileError, NoSuchTableError
from sqlalchemy.sql import functions, literal_column
Expand Down Expand Up @@ -557,6 +558,17 @@ def _get_column_info(t):
return COLUMN_TYPES[t], nullable


class YdbScanQueryCharacteristic(characteristics.ConnectionCharacteristic):
    """Connection characteristic backing the ``ydb_scan_query`` execution option.

    Each hook delegates to the dialect's matching ``*_ydb_scan_query`` method.
    """

    def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> None:
        """Restore the default scan-query setting on ``dbapi_connection``."""
        dialect.reset_ydb_scan_query(dbapi_connection)

    def set_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection, value: bool) -> None:
        """Apply ``value`` as the scan-query setting on ``dbapi_connection``."""
        dialect.set_ydb_scan_query(dbapi_connection, value)

    def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> Any:
        """Return the current scan-query setting of ``dbapi_connection``."""
        # The value must be returned; without ``return`` this hook always gave None.
        return dialect.get_ydb_scan_query(dbapi_connection)


class YqlDialect(StrCompileDialect):
name = "yql"
driver = "ydb"
Expand Down Expand Up @@ -600,6 +612,13 @@ class YqlDialect(StrCompileDialect):
sa.types.DateTime: types.YqlDateTime,
}

# Per-connection characteristics recognised by this dialect: the standard
# isolation level plus the YDB-specific "ydb_scan_query" flag.
connection_characteristics = util.immutabledict(
{
"isolation_level": characteristics.IsolationLevelCharacteristic(),
"ydb_scan_query": YdbScanQueryCharacteristic(),
}
)

construct_arguments = [
(
sa.schema.Table,
Expand Down Expand Up @@ -723,6 +742,15 @@ def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str:
def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str:
    """Return the isolation level reported by the DBAPI connection."""
    level = dbapi_connection.get_isolation_level()
    return level

def set_ydb_scan_query(self, dbapi_connection: dbapi.Connection, value: bool) -> None:
    """Forward the scan-query flag ``value`` to the DBAPI connection."""
    dbapi_connection.set_ydb_scan_query(value)

def reset_ydb_scan_query(self, dbapi_connection: dbapi.Connection):
    """Switch scan queries off on the DBAPI connection (the default state)."""
    # Goes through set_ydb_scan_query rather than touching the connection directly.
    self.set_ydb_scan_query(dbapi_connection, False)

def get_ydb_scan_query(self, dbapi_connection: dbapi.Connection) -> bool:
    """Return the scan-query flag currently set on the DBAPI connection.

    The flag is a boolean (set through ``set_ydb_scan_query(..., value: bool)``),
    so the return annotation is ``bool`` rather than the previous ``str``.
    """
    return dbapi_connection.get_ydb_scan_query()

def connect(self, *cargs, **cparams):
    """Open a DBAPI connection, passing all arguments through to the loaded DBAPI module."""
    dbapi_module = self.loaded_dbapi
    return dbapi_module.connect(*cargs, **cparams)

Expand Down
Loading