From 89b99f8ff31507e7e78bc70a4b14c6944a2c8638 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=92=D0=B8=D0=BA=D1=82=D0=BE=D1=80=20=D0=96=D0=B8=D1=80?= =?UTF-8?q?=D0=BD=D0=BE=D0=B2?= Date: Thu, 19 Jun 2025 18:46:55 +0300 Subject: [PATCH 1/2] add alembic support --- ydb_sqlalchemy/sqlalchemy/compiler/sa20.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py index 702d7aa..4ede43c 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py @@ -89,4 +89,21 @@ def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw): class YqlDDLCompiler(BaseYqlDDLCompiler): - ... + def visit_foreign_key_constraint(self, constraint, **kwargs): + return None + + def visit_primary_key_constraint(self, constraint, **kwargs): + if len(constraint) == 0: + return "" + text = "" + text += "PRIMARY KEY " + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in ( + constraint.columns_autoinc_first + if constraint._implicit_generated + else constraint.columns + ) + ) + text += self.define_constraint_deferrability(constraint) + return text From 50deffdc6312954dbe3eb707043f0ca19dc76f75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=92=D0=B8=D0=BA=D1=82=D0=BE=D1=80=20=D0=96=D0=B8=D1=80?= =?UTF-8?q?=D0=BD=D0=BE=D0=B2?= Date: Thu, 19 Jun 2025 19:09:26 +0300 Subject: [PATCH 2/2] reformat code --- test/test_core.py | 6 ++---- test/test_suite.py | 1 - ydb_sqlalchemy/__init__.py | 5 +++-- ydb_sqlalchemy/sqlalchemy/__init__.py | 12 ++++++----- .../sqlalchemy/compiler/__init__.py | 20 +++++++++++-------- ydb_sqlalchemy/sqlalchemy/compiler/base.py | 18 +++-------------- ydb_sqlalchemy/sqlalchemy/compiler/sa14.py | 1 + ydb_sqlalchemy/sqlalchemy/compiler/sa20.py | 10 +++------- ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py | 3 +-- 9 files changed, 32 insertions(+), 44 deletions(-) diff --git a/test/test_core.py b/test/test_core.py index 3f7d808..ab8db17 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -15,11 +15,9 @@ from ydb_sqlalchemy.sqlalchemy import types if sa.__version__ >= "2.": - from sqlalchemy import NullPool - from sqlalchemy import QueuePool + from sqlalchemy import NullPool, QueuePool else: - from sqlalchemy.pool import NullPool - from sqlalchemy.pool import QueuePool + from sqlalchemy.pool import NullPool, QueuePool def clear_sql(stm): diff --git a/test/test_suite.py b/test/test_suite.py index 0ffa6df..e67af8a 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -68,7 +68,6 @@ from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest - from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest from sqlalchemy.testing.suite.test_types import StringTest as _StringTest from sqlalchemy.testing.suite.test_types import ( diff --git a/ydb_sqlalchemy/__init__.py b/ydb_sqlalchemy/__init__.py index 55ade24..62f8b58 100644 --- a/ydb_sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/__init__.py @@ -1,4 +1,5 @@ -from ._version import VERSION # noqa: F401 +import ydb_dbapi as dbapi from ydb_dbapi import IsolationLevel # noqa: F401 + +from ._version import VERSION # noqa: F401 from .sqlalchemy import Upsert, types, upsert # noqa: F401 -import ydb_dbapi as dbapi diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 7bdf8f8..d68a009 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -9,23 +9,25 @@ import sqlalchemy as sa import ydb +import ydb_dbapi from sqlalchemy import util from sqlalchemy.engine import characteristics, reflection from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect from sqlalchemy.exc import NoSuchTableError from sqlalchemy.sql import functions - from sqlalchemy.sql.elements import ClauseList -import ydb_dbapi +from ydb_sqlalchemy.sqlalchemy.compiler import ( + YqlCompiler, + YqlDDLCompiler, + YqlIdentifierPreparer, + YqlTypeCompiler, +) from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncConnection from ydb_sqlalchemy.sqlalchemy.dml import Upsert -from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler - from . import types - OLD_SA = sa.__version__ < "2." diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py b/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py index 31affdd..7e7316e 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py @@ -3,14 +3,18 @@ sa_version = sa.__version__ if sa_version.startswith("2."): - from .sa20 import YqlCompiler - from .sa20 import YqlDDLCompiler - from .sa20 import YqlTypeCompiler - from .sa20 import YqlIdentifierPreparer + from .sa20 import ( + YqlCompiler, + YqlDDLCompiler, + YqlIdentifierPreparer, + YqlTypeCompiler, + ) elif sa_version.startswith("1.4."): - from .sa14 import YqlCompiler - from .sa14 import YqlDDLCompiler - from .sa14 import YqlTypeCompiler - from .sa14 import YqlIdentifierPreparer + from .sa14 import ( + YqlCompiler, + YqlDDLCompiler, + YqlIdentifierPreparer, + YqlTypeCompiler, + ) else: raise RuntimeError("Unsupported SQLAlchemy version.") diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/base.py b/ydb_sqlalchemy/sqlalchemy/compiler/base.py index c522765..e0d7088 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/base.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/base.py @@ -1,8 +1,8 @@ import collections +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union + import sqlalchemy as sa import ydb -from ydb_dbapi import NotSupportedError - from sqlalchemy.exc import CompileError from sqlalchemy.sql import ddl from sqlalchemy.sql.compiler import ( @@ -12,22 +12,10 @@ StrSQLTypeCompiler, selectable, ) -from typing import ( - Any, - Dict, - List, - Mapping, - Sequence, - Optional, - Tuple, - Type, - Union, -) - +from ydb_dbapi import NotSupportedError from .. import types - OLD_SA = sa.__version__ < "2." if OLD_SA: from sqlalchemy import bindparam as _bindparam diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py index 598fc29..a6ad1d4 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py @@ -1,4 +1,5 @@ from typing import Union + import sqlalchemy as sa import ydb diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py index 4ede43c..a838ad7 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py @@ -1,6 +1,7 @@ +from typing import Union + import sqlalchemy as sa import ydb - from sqlalchemy.exc import CompileError from sqlalchemy.sql import literal_column from sqlalchemy.util.compat import inspect_getfullargspec @@ -11,7 +12,6 @@ BaseYqlIdentifierPreparer, BaseYqlTypeCompiler, ) -from typing import Union class YqlTypeCompiler(BaseYqlTypeCompiler): @@ -99,11 +99,7 @@ def visit_primary_key_constraint(self, constraint, **kwargs): text += "PRIMARY KEY " text += "(%s)" % ", ".join( self.preparer.quote(c.name) - for c in ( - constraint.columns_autoinc_first - if constraint._implicit_generated - else constraint.columns - ) + for c in (constraint.columns_autoinc_first if constraint._implicit_generated else constraint.columns) ) text += self.define_constraint_deferrability(constraint) return text diff --git a/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py b/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py index 66f72ae..6972954 100644 --- a/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py +++ b/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py @@ -1,8 +1,7 @@ +import ydb from sqlalchemy.engine.interfaces import AdaptedConnection - from sqlalchemy.util.concurrency import await_only from ydb_dbapi import AsyncConnection, AsyncCursor -import ydb class AdaptedAsyncConnection(AdaptedConnection):