diff --git a/test/test_core.py b/test/test_core.py index 3f7d808..ab8db17 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -15,11 +15,9 @@ from ydb_sqlalchemy.sqlalchemy import types if sa.__version__ >= "2.": - from sqlalchemy import NullPool - from sqlalchemy import QueuePool + from sqlalchemy import NullPool, QueuePool else: - from sqlalchemy.pool import NullPool - from sqlalchemy.pool import QueuePool + from sqlalchemy.pool import NullPool, QueuePool def clear_sql(stm): diff --git a/test/test_suite.py b/test/test_suite.py index 0ffa6df..e67af8a 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -68,7 +68,6 @@ from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest - from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest from sqlalchemy.testing.suite.test_types import StringTest as _StringTest from sqlalchemy.testing.suite.test_types import ( diff --git a/ydb_sqlalchemy/__init__.py b/ydb_sqlalchemy/__init__.py index 55ade24..62f8b58 100644 --- a/ydb_sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/__init__.py @@ -1,4 +1,5 @@ -from ._version import VERSION # noqa: F401 +import ydb_dbapi as dbapi from ydb_dbapi import IsolationLevel # noqa: F401 + +from ._version import VERSION # noqa: F401 from .sqlalchemy import Upsert, types, upsert # noqa: F401 -import ydb_dbapi as dbapi diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 7bdf8f8..d68a009 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -9,23 +9,25 @@ import sqlalchemy as sa import ydb +import ydb_dbapi from sqlalchemy import util from sqlalchemy.engine import characteristics, reflection from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect from sqlalchemy.exc import NoSuchTableError from sqlalchemy.sql import functions - from sqlalchemy.sql.elements import ClauseList -import ydb_dbapi +from ydb_sqlalchemy.sqlalchemy.compiler import ( + YqlCompiler, + YqlDDLCompiler, + YqlIdentifierPreparer, + YqlTypeCompiler, +) from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncConnection from ydb_sqlalchemy.sqlalchemy.dml import Upsert -from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler - from . import types - OLD_SA = sa.__version__ < "2." diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py b/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py index 31affdd..7e7316e 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/__init__.py @@ -3,14 +3,18 @@ sa_version = sa.__version__ if sa_version.startswith("2."): - from .sa20 import YqlCompiler - from .sa20 import YqlDDLCompiler - from .sa20 import YqlTypeCompiler - from .sa20 import YqlIdentifierPreparer + from .sa20 import ( + YqlCompiler, + YqlDDLCompiler, + YqlIdentifierPreparer, + YqlTypeCompiler, + ) elif sa_version.startswith("1.4."): - from .sa14 import YqlCompiler - from .sa14 import YqlDDLCompiler - from .sa14 import YqlTypeCompiler - from .sa14 import YqlIdentifierPreparer + from .sa14 import ( + YqlCompiler, + YqlDDLCompiler, + YqlIdentifierPreparer, + YqlTypeCompiler, + ) else: raise RuntimeError("Unsupported SQLAlchemy version.") diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/base.py b/ydb_sqlalchemy/sqlalchemy/compiler/base.py index c522765..e0d7088 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/base.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/base.py @@ -1,8 +1,8 @@ import collections +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union + import sqlalchemy as sa import ydb -from ydb_dbapi import NotSupportedError - from sqlalchemy.exc import CompileError from sqlalchemy.sql import ddl from sqlalchemy.sql.compiler import ( @@ -12,22 +12,10 @@ StrSQLTypeCompiler, selectable, ) -from typing import ( - Any, - Dict, - List, - Mapping, - Sequence, - Optional, - Tuple, - Type, - Union, -) - +from ydb_dbapi import NotSupportedError from .. import types - OLD_SA = sa.__version__ < "2." if OLD_SA: from sqlalchemy import bindparam as _bindparam diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py index 598fc29..a6ad1d4 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa14.py @@ -1,4 +1,5 @@ from typing import Union + import sqlalchemy as sa import ydb diff --git a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py index 702d7aa..a838ad7 100644 --- a/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py +++ b/ydb_sqlalchemy/sqlalchemy/compiler/sa20.py @@ -1,6 +1,7 @@ +from typing import Union + import sqlalchemy as sa import ydb - from sqlalchemy.exc import CompileError from sqlalchemy.sql import literal_column from sqlalchemy.util.compat import inspect_getfullargspec @@ -11,7 +12,6 @@ BaseYqlIdentifierPreparer, BaseYqlTypeCompiler, ) -from typing import Union class YqlTypeCompiler(BaseYqlTypeCompiler): @@ -89,4 +89,17 @@ def visit_upsert(self, insert_stmt, visited_bindparam=None, **kw): class YqlDDLCompiler(BaseYqlDDLCompiler): - ... + def visit_foreign_key_constraint(self, constraint, **kwargs): + return None + + def visit_primary_key_constraint(self, constraint, **kwargs): + if len(constraint) == 0: + return "" + text = "" + text += "PRIMARY KEY " + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) + for c in (constraint.columns_autoinc_first if constraint._implicit_generated else constraint.columns) + ) + text += self.define_constraint_deferrability(constraint) + return text diff --git a/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py b/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py index 66f72ae..6972954 100644 --- a/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py +++ b/ydb_sqlalchemy/sqlalchemy/dbapi_adapter.py @@ -1,8 +1,7 @@ +import ydb from sqlalchemy.engine.interfaces import AdaptedConnection - from sqlalchemy.util.concurrency import await_only from ydb_dbapi import AsyncConnection, AsyncCursor -import ydb class AdaptedAsyncConnection(AdaptedConnection):