Skip to content

Commit 8149b96

Browse files
authored
Merge pull request #50 Implement ScanQuery from LuckySting/main
Thanks for the PR
2 parents 48cc1c2 + dc7bdea commit 8149b96

File tree

4 files changed

+176
-5
lines changed

4 files changed

+176
-5
lines changed

test/test_core.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,85 @@ def test_several_keys(self, connection, metadata):
442442
assert desc.partitioning_settings.max_partitions_count == 5
443443

444444

445+
class TestScanQuery(TablesTest):
446+
__backend__ = True
447+
448+
@classmethod
449+
def define_tables(cls, metadata: sa.MetaData):
450+
Table(
451+
"test",
452+
metadata,
453+
Column("id", Integer, primary_key=True),
454+
)
455+
456+
@classmethod
457+
def insert_data(cls, connection: sa.Connection):
458+
table = cls.tables.test
459+
for i in range(50):
460+
connection.execute(ydb_sa.upsert(table).values([{"id": i * 1000 + j} for j in range(1000)]))
461+
462+
def test_characteristic(self):
463+
engine = self.bind.execution_options()
464+
465+
with engine.connect() as connection:
466+
default_options = connection.get_execution_options()
467+
468+
with engine.connect() as connection:
469+
connection.execution_options(ydb_scan_query=True)
470+
options_after_set = connection.get_execution_options()
471+
472+
with engine.connect() as connection:
473+
options_after_reset = connection.get_execution_options()
474+
475+
assert "ydb_scan_query" not in default_options
476+
assert options_after_set["ydb_scan_query"]
477+
assert "ydb_scan_query" not in options_after_reset
478+
479+
def test_fetchmany(self, connection_no_trans: sa.Connection):
480+
table = self.tables.test
481+
stmt = sa.select(table).where(table.c.id % 2 == 0)
482+
483+
connection_no_trans.execution_options(ydb_scan_query=True)
484+
cursor = connection_no_trans.execute(stmt)
485+
486+
assert cursor.cursor.use_scan_query
487+
result = cursor.fetchmany(1000) # fetches only the first 5k rows
488+
assert result == [(i,) for i in range(2000) if i % 2 == 0]
489+
490+
def test_fetchall(self, connection_no_trans: sa.Connection):
491+
table = self.tables.test
492+
stmt = sa.select(table).where(table.c.id % 2 == 0)
493+
494+
connection_no_trans.execution_options(ydb_scan_query=True)
495+
cursor = connection_no_trans.execute(stmt)
496+
497+
assert cursor.cursor.use_scan_query
498+
result = cursor.fetchall()
499+
assert result == [(i,) for i in range(50000) if i % 2 == 0]
500+
501+
def test_begin_does_nothing(self, connection_no_trans: sa.Connection):
502+
table = self.tables.test
503+
connection_no_trans.execution_options(ydb_scan_query=True)
504+
505+
with connection_no_trans.begin():
506+
cursor = connection_no_trans.execute(sa.select(table))
507+
508+
assert cursor.cursor.use_scan_query
509+
assert cursor.cursor.tx_context is None
510+
511+
def test_engine_option(self):
512+
table = self.tables.test
513+
engine = self.bind.execution_options(ydb_scan_query=True)
514+
515+
with engine.begin() as connection:
516+
cursor = connection.execute(sa.select(table))
517+
assert cursor.cursor.use_scan_query
518+
519+
with engine.begin() as connection:
520+
cursor = connection.execute(sa.select(table))
521+
assert cursor.cursor.use_scan_query
522+
523+
445524
class TestTransaction(TablesTest):
446525
__backend__ = True
447526

ydb_sqlalchemy/dbapi/connection.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,12 @@ def __init__(
5757
self.interactive_transaction: bool = False # AUTOCOMMIT
5858
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
5959
self.tx_context: Optional[ydb.TxContext] = None
60+
self.use_scan_query: bool = False
6061

6162
def cursor(self):
62-
return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context, self.table_path_prefix)
63+
return self._cursor_class(
64+
self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix
65+
)
6366

6467
def describe(self, table_path: str) -> ydb.TableDescription:
6568
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
@@ -115,9 +118,15 @@ def get_isolation_level(self) -> str:
115118
else:
116119
raise NotSupportedError(f"{self.tx_mode.name} is not supported")
117120

121+
def set_ydb_scan_query(self, value: bool) -> None:
122+
self.use_scan_query = value
123+
124+
def get_ydb_scan_query(self) -> bool:
125+
return self.use_scan_query
126+
118127
def begin(self):
119128
self.tx_context = None
120-
if self.interactive_transaction:
129+
if self.interactive_transaction and not self.use_scan_query:
121130
session = self._maybe_await(self.session_pool.acquire)
122131
self.tx_context = session.transaction(self.tx_mode)
123132
self._maybe_await(self.tx_context.begin)

ydb_sqlalchemy/dbapi/cursor.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,17 @@
55
import itertools
66
import logging
77
import posixpath
8-
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
8+
from collections.abc import AsyncIterator
9+
from typing import (
10+
Any,
11+
Dict,
12+
Generator,
13+
List,
14+
Mapping,
15+
Optional,
16+
Sequence,
17+
Union,
18+
)
919

1020
import ydb
1121
import ydb.aio
@@ -77,14 +87,18 @@ def wrapper(*args, **kwargs):
7787
class Cursor:
7888
def __init__(
7989
self,
90+
driver: Union[ydb.Driver, ydb.aio.Driver],
8091
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
8192
tx_mode: ydb.AbstractTransactionModeBuilder,
8293
tx_context: Optional[ydb.BaseTxContext] = None,
94+
use_scan_query: bool = False,
8395
table_path_prefix: str = "",
8496
):
97+
self.driver = driver
8598
self.session_pool = session_pool
8699
self.tx_mode = tx_mode
87100
self.tx_context = tx_context
101+
self.use_scan_query = use_scan_query
88102
self.description = None
89103
self.arraysize = 1
90104
self.rows = None
@@ -120,6 +134,8 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] =
120134
logger.info("execute sql: %s, params: %s", query, parameters)
121135
if operation.is_ddl:
122136
chunks = self._execute_ddl(query)
137+
elif self.use_scan_query:
138+
chunks = self._execute_scan_query(query, parameters)
123139
else:
124140
chunks = self._execute_dml(query, parameters)
125141

@@ -164,6 +180,21 @@ def _make_data_query(
164180
name = hashlib.sha256(yql_with_params.encode("utf-8")).hexdigest()
165181
return ydb.DataQuery(yql_text, parameters_types, name=name)
166182

183+
@_handle_ydb_errors
184+
def _execute_scan_query(
185+
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
186+
) -> Generator[ydb.convert.ResultSet, None, None]:
187+
prepared_query = query
188+
if isinstance(query, str) and parameters:
189+
prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query)
190+
191+
if isinstance(query, str):
192+
scan_query = ydb.ScanQuery(query, None)
193+
else:
194+
scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types)
195+
196+
return self._execute_scan_query_in_driver(scan_query, parameters)
197+
167198
@_handle_ydb_errors
168199
def _execute_dml(
169200
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
@@ -219,6 +250,15 @@ def _execute_in_session(
219250
) -> ydb.convert.ResultSets:
220251
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
221252

253+
def _execute_scan_query_in_driver(
254+
self,
255+
scan_query: ydb.ScanQuery,
256+
parameters: Optional[Mapping[str, Any]],
257+
) -> Generator[ydb.convert.ResultSet, None, None]:
258+
chunk: ydb.ScanQueryResult
259+
for chunk in self.driver.table_client.scan_query(scan_query, parameters):
260+
yield chunk.result_set
261+
222262
def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
223263
return callee(self.tx_context, *args, **kwargs)
224264

@@ -264,7 +304,7 @@ def executescript(self, script):
264304
return self.execute(script)
265305

266306
def fetchone(self):
267-
return next(self.rows or [], None)
307+
return next(self.rows or iter([]), None)
268308

269309
def fetchmany(self, size=None):
270310
return list(itertools.islice(self.rows, size or self.arraysize))
@@ -328,6 +368,21 @@ async def _execute_in_session(
328368
) -> ydb.convert.ResultSets:
329369
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
330370

371+
def _execute_scan_query_in_driver(
372+
self,
373+
scan_query: ydb.ScanQuery,
374+
parameters: Optional[Mapping[str, Any]],
375+
) -> Generator[ydb.convert.ResultSet, None, None]:
376+
iterator: AsyncIterator[ydb.ScanQueryResult] = self._await(
377+
self.driver.table_client.scan_query(scan_query, parameters)
378+
)
379+
while True:
380+
try:
381+
result = self._await(iterator.__anext__())
382+
yield result.result_set
383+
except StopAsyncIteration:
384+
break
385+
331386
def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
332387
return self._await(callee(self.tx_context, *args, **kwargs))
333388

ydb_sqlalchemy/sqlalchemy/__init__.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99

1010
import sqlalchemy as sa
1111
import ydb
12-
from sqlalchemy.engine import reflection
12+
from sqlalchemy import util
13+
from sqlalchemy.engine import characteristics, reflection
1314
from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect
1415
from sqlalchemy.exc import CompileError, NoSuchTableError
1516
from sqlalchemy.sql import functions, literal_column
@@ -557,6 +558,17 @@ def _get_column_info(t):
557558
return COLUMN_TYPES[t], nullable
558559

559560

561+
class YdbScanQueryCharacteristic(characteristics.ConnectionCharacteristic):
562+
def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> None:
563+
dialect.reset_ydb_scan_query(dbapi_connection)
564+
565+
def set_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection, value: bool) -> None:
566+
dialect.set_ydb_scan_query(dbapi_connection, value)
567+
568+
def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> Any:
569+
dialect.get_ydb_scan_query(dbapi_connection)
570+
571+
560572
class YqlDialect(StrCompileDialect):
561573
name = "yql"
562574
driver = "ydb"
@@ -600,6 +612,13 @@ class YqlDialect(StrCompileDialect):
600612
sa.types.DateTime: types.YqlDateTime,
601613
}
602614

615+
connection_characteristics = util.immutabledict(
616+
{
617+
"isolation_level": characteristics.IsolationLevelCharacteristic(),
618+
"ydb_scan_query": YdbScanQueryCharacteristic(),
619+
}
620+
)
621+
603622
construct_arguments = [
604623
(
605624
sa.schema.Table,
@@ -723,6 +742,15 @@ def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str:
723742
def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str:
724743
return dbapi_connection.get_isolation_level()
725744

745+
def set_ydb_scan_query(self, dbapi_connection: dbapi.Connection, value: bool) -> None:
746+
dbapi_connection.set_ydb_scan_query(value)
747+
748+
def reset_ydb_scan_query(self, dbapi_connection: dbapi.Connection):
749+
self.set_ydb_scan_query(dbapi_connection, False)
750+
751+
def get_ydb_scan_query(self, dbapi_connection: dbapi.Connection) -> str:
752+
return dbapi_connection.get_ydb_scan_query()
753+
726754
def connect(self, *cargs, **cparams):
727755
return self.loaded_dbapi.connect(*cargs, **cparams)
728756

0 commit comments

Comments
 (0)