Skip to content

Commit 1800d49

Browse files
committed
🏷️ fix types
1 parent 9efa266 commit 1800d49

File tree

5 files changed

+169
-145
lines changed

5 files changed

+169
-145
lines changed

requirements_dev.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pytest-timeout
1111
pytimeparse2
1212
simplejson>=3.19.1
1313
types-simplejson
14-
sqlalchemy<2.0.0
14+
sqlalchemy
1515
sqlalchemy-utils
1616
tox
1717
tqdm>=4.65.0
@@ -20,4 +20,8 @@ packaging>=23.1
2020
tabulate
2121
types-tabulate
2222
Unidecode>=1.3.6
23-
typing_extensions
23+
typing_extensions
24+
types-Pygments
25+
types-colorama
26+
types-mock
27+
types-setuptools

tests/database.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212

1313

1414
class Database:
15-
engine: Engine = None
16-
Session: sessionmaker = None
15+
engine: Engine
16+
Session: sessionmaker
1717

1818
def __init__(self, database_uri):
1919
self.Session = sessionmaker()

tests/func/sqlite3_to_mysql_test.py

Lines changed: 52 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pytest_mock import MockFixture
1818
from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
1919
from sqlalchemy.engine import Connection, CursorResult, Engine, Inspector, Row
20+
from sqlalchemy.engine.interfaces import ReflectedIndex
2021
from sqlalchemy.sql import Select
2122
from sqlalchemy.sql.elements import TextClause
2223

@@ -337,21 +338,21 @@ def test_transfer_transfers_all_tables_in_sqlite_file(
337338

338339
""" Test if all the tables have the same indices """
339340
index_keys: t.Tuple[str, ...] = ("name", "column_names", "unique")
340-
mysql_indices: t.Tuple[t.Dict[str, t.Any], ...] = tuple(
341-
{key: index[key] for key in index_keys}
341+
mysql_indices: t.Tuple[ReflectedIndex, ...] = tuple(
342+
t.cast(ReflectedIndex, {key: index[key] for key in index_keys})
342343
for index in (chain.from_iterable(mysql_inspect.get_indexes(table_name) for table_name in mysql_tables))
343344
)
344345

345346
for table_name in sqlite_tables:
346-
sqlite_indices: t.List[t.Dict[str, t.Any]] = sqlite_inspect.get_indexes(table_name)
347+
sqlite_indices: t.List[ReflectedIndex] = sqlite_inspect.get_indexes(table_name)
347348
if with_rowid:
348349
sqlite_indices.insert(
349350
0,
350-
{
351-
"name": "{}_rowid".format(table_name),
352-
"column_names": ["rowid"],
353-
"unique": 1,
354-
},
351+
ReflectedIndex(
352+
name="{}_rowid".format(table_name),
353+
column_names=["rowid"],
354+
unique=True,
355+
),
355356
)
356357
for sqlite_index in sqlite_indices:
357358
sqlite_index["unique"] = bool(sqlite_index["unique"])
@@ -381,13 +382,19 @@ def test_transfer_transfers_all_tables_in_sqlite_file(
381382
constraint_type="FOREIGN KEY",
382383
)
383384
mysql_fk_result: CursorResult = mysql_cnx.execute(mysql_fk_stmt)
384-
mysql_foreign_keys: t.List[t.Dict[str, t.Any]] = [dict(row) for row in mysql_fk_result]
385-
386-
sqlite_fk_stmt: str = 'PRAGMA foreign_key_list("{table}")'.format(table=table_name)
385+
mysql_foreign_keys: t.List[t.Dict[str, t.Any]] = [
386+
{
387+
"table": fk["table"],
388+
"from": fk["from"],
389+
"to": fk["to"],
390+
}
391+
for fk in mysql_fk_result.mappings()
392+
]
393+
394+
sqlite_fk_stmt: TextClause = text('PRAGMA foreign_key_list("{table}")'.format(table=table_name))
387395
sqlite_fk_result = sqlite_cnx.execute(sqlite_fk_stmt)
388396
if sqlite_fk_result.returns_rows:
389-
for row in sqlite_fk_result:
390-
fk: t.Dict[str, t.Any] = dict(row)
397+
for fk in sqlite_fk_result.mappings():
391398
assert {
392399
"table": fk["table"],
393400
"from": fk["from"],
@@ -398,23 +405,28 @@ def test_transfer_transfers_all_tables_in_sqlite_file(
398405
sqlite_results: t.List[t.Tuple[t.Tuple[t.Any, ...], ...]] = []
399406
mysql_results: t.List[t.Tuple[t.Tuple[t.Any, ...], ...]] = []
400407

401-
meta: MetaData = MetaData(bind=None)
408+
meta: MetaData = MetaData()
402409
for table_name in sqlite_tables:
403-
sqlite_table: Table = Table(table_name, meta, autoload=True, autoload_with=sqlite_engine)
404-
sqlite_stmt: Select = select([sqlite_table])
405-
sqlite_result: t.List[Row] = sqlite_cnx.execute(sqlite_stmt).fetchall()
410+
sqlite_table: Table = Table(table_name, meta, autoload_with=sqlite_engine)
411+
sqlite_stmt: Select = select(sqlite_table)
412+
sqlite_result: t.List[Row[t.Any]] = list(sqlite_cnx.execute(sqlite_stmt).fetchall())
406413
sqlite_result.sort()
407414
sqlite_results.append(tuple(tuple(data for data in row) for row in sqlite_result))
408415

409416
for table_name in mysql_tables:
410-
mysql_table: Table = Table(table_name, meta, autoload=True, autoload_with=mysql_engine)
411-
mysql_stmt: Select = select([mysql_table])
412-
mysql_result: t.List[Row] = mysql_cnx.execute(mysql_stmt).fetchall()
417+
mysql_table: Table = Table(table_name, meta, autoload_with=mysql_engine)
418+
mysql_stmt: Select = select(mysql_table)
419+
mysql_result: t.List[Row[t.Any]] = list(mysql_cnx.execute(mysql_stmt).fetchall())
413420
mysql_result.sort()
414421
mysql_results.append(tuple(tuple(data for data in row) for row in mysql_result))
415422

416423
assert sqlite_results == mysql_results
417424

425+
mysql_cnx.close()
426+
sqlite_cnx.close()
427+
mysql_engine.dispose()
428+
sqlite_engine.dispose()
429+
418430
@pytest.mark.transfer
419431
@pytest.mark.parametrize(
420432
"chunk, with_rowid, mysql_insert_method, ignore_duplicate_keys",
@@ -518,21 +530,21 @@ def test_transfer_specific_tables_transfers_only_specified_tables_from_sqlite_fi
518530

519531
""" Test if all the tables have the same indices """
520532
index_keys: t.Tuple[str, ...] = ("name", "column_names", "unique")
521-
mysql_indices: t.Tuple[t.Dict[str, t.Any], ...] = tuple(
522-
{key: index[key] for key in index_keys}
533+
mysql_indices: t.Tuple[ReflectedIndex, ...] = tuple(
534+
t.cast(ReflectedIndex, {key: index[key] for key in index_keys})
523535
for index in (chain.from_iterable(mysql_inspect.get_indexes(table_name) for table_name in mysql_tables))
524536
)
525537

526538
for table_name in random_sqlite_tables:
527-
sqlite_indices: t.List[t.Dict[str, t.Any]] = sqlite_inspect.get_indexes(table_name)
539+
sqlite_indices: t.List[ReflectedIndex] = sqlite_inspect.get_indexes(table_name)
528540
if with_rowid:
529541
sqlite_indices.insert(
530542
0,
531-
{
532-
"name": "{}_rowid".format(table_name),
533-
"column_names": ["rowid"],
534-
"unique": 1,
535-
},
543+
ReflectedIndex(
544+
name="{}_rowid".format(table_name),
545+
column_names=["rowid"],
546+
unique=True,
547+
),
536548
)
537549
for sqlite_index in sqlite_indices:
538550
sqlite_index["unique"] = bool(sqlite_index["unique"])
@@ -544,19 +556,24 @@ def test_transfer_specific_tables_transfers_only_specified_tables_from_sqlite_fi
544556
sqlite_results: t.List[t.Tuple[t.Tuple[t.Any, ...], ...]] = []
545557
mysql_results: t.List[t.Tuple[t.Tuple[t.Any, ...], ...]] = []
546558

547-
meta: MetaData = MetaData(bind=None)
559+
meta: MetaData = MetaData()
548560
for table_name in random_sqlite_tables:
549-
sqlite_table: Table = Table(table_name, meta, autoload=True, autoload_with=sqlite_engine)
550-
sqlite_stmt: Select = select([sqlite_table])
551-
sqlite_result: t.List[Row] = sqlite_cnx.execute(sqlite_stmt).fetchall()
561+
sqlite_table: Table = Table(table_name, meta, autoload_with=sqlite_engine)
562+
sqlite_stmt: Select = select(sqlite_table)
563+
sqlite_result: t.List[Row[t.Any]] = list(sqlite_cnx.execute(sqlite_stmt).fetchall())
552564
sqlite_result.sort()
553565
sqlite_results.append(tuple(tuple(data for data in row) for row in sqlite_result))
554566

555567
for table_name in mysql_tables:
556-
mysql_table: Table = Table(table_name, meta, autoload=True, autoload_with=mysql_engine)
557-
mysql_stmt: Select = select([mysql_table])
558-
mysql_result: t.List[Row] = mysql_cnx.execute(mysql_stmt).fetchall()
568+
mysql_table: Table = Table(table_name, meta, autoload_with=mysql_engine)
569+
mysql_stmt: Select = select(mysql_table)
570+
mysql_result: t.List[Row[t.Any]] = list(mysql_cnx.execute(mysql_stmt).fetchall())
559571
mysql_result.sort()
560572
mysql_results.append(tuple(tuple(data for data in row) for row in mysql_result))
561573

562574
assert sqlite_results == mysql_results
575+
576+
mysql_cnx.close()
577+
sqlite_cnx.close()
578+
mysql_engine.dispose()
579+
sqlite_engine.dispose()

tests/models.py

Lines changed: 52 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -12,57 +12,53 @@
1212
TIMESTAMP,
1313
VARCHAR,
1414
BigInteger,
15-
Boolean,
1615
Column,
17-
Date,
18-
DateTime,
16+
Dialect,
1917
ForeignKey,
2018
Integer,
2119
SmallInteger,
2220
String,
2321
Table,
2422
Text,
25-
Time,
2623
)
27-
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
28-
from sqlalchemy.orm import backref, declarative_base, relationship
29-
from sqlalchemy.orm.decl_api import DeclarativeMeta
24+
from sqlalchemy.orm import DeclarativeBase, Mapped, backref, mapped_column, relationship
3025
from sqlalchemy.sql.functions import current_timestamp
3126

3227

3328
class SQLiteNumeric(types.TypeDecorator):
3429
impl: t.Type[String] = types.String
3530

36-
def load_dialect_impl(self, dialect: SQLiteDialect) -> t.Any:
31+
def load_dialect_impl(self, dialect: Dialect) -> t.Any:
3732
return dialect.type_descriptor(types.VARCHAR(100))
3833

39-
def process_bind_param(self, value: t.Any, dialect: SQLiteDialect) -> str:
34+
def process_bind_param(self, value: t.Any, dialect: Dialect) -> str:
4035
return str(value)
4136

42-
def process_result_value(self, value: t.Any, dialect: SQLiteDialect) -> Decimal:
37+
def process_result_value(self, value: t.Any, dialect: Dialect) -> Decimal:
4338
return Decimal(value)
4439

4540

4641
class MyCustomType(types.TypeDecorator):
4742
impl: t.Type[String] = types.String
4843

49-
def load_dialect_impl(self, dialect: SQLiteDialect) -> t.Any:
44+
def load_dialect_impl(self, dialect: Dialect) -> t.Any:
5045
return dialect.type_descriptor(types.VARCHAR(self.length))
5146

52-
def process_bind_param(self, value: t.Any, dialect: SQLiteDialect) -> str:
47+
def process_bind_param(self, value: t.Any, dialect: Dialect) -> str:
5348
return str(value)
5449

55-
def process_result_value(self, value: t.Any, dialect: SQLiteDialect) -> str:
50+
def process_result_value(self, value: t.Any, dialect: Dialect) -> str:
5651
return str(value)
5752

5853

59-
Base: DeclarativeMeta = declarative_base()
54+
class Base(DeclarativeBase):
55+
pass
6056

6157

6258
class Author(Base):
6359
__tablename__ = "authors"
64-
id: int = Column(Integer, primary_key=True)
65-
name: str = Column(String(128), nullable=False, index=True)
60+
id: Mapped[int] = mapped_column(primary_key=True)
61+
name: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
6662

6763
def __repr__(self):
6864
return "<Author(id='{id}', name='{name}')>".format(id=self.id, name=self.name)
@@ -78,9 +74,9 @@ def __repr__(self):
7874

7975
class Image(Base):
8076
__tablename__ = "images"
81-
id: int = Column(Integer, primary_key=True)
82-
path: str = Column(String(255), index=True)
83-
description: str = Column(String(255), nullable=True)
77+
id: Mapped[int] = mapped_column(primary_key=True)
78+
path: Mapped[str] = mapped_column(String(255), index=True)
79+
description: Mapped[str] = mapped_column(String(255), nullable=True)
8480

8581
def __repr__(self):
8682
return "<Image(id='{id}', path='{path}')>".format(id=self.id, path=self.path)
@@ -96,8 +92,8 @@ def __repr__(self):
9692

9793
class Tag(Base):
9894
__tablename__ = "tags"
99-
id: int = Column(Integer, primary_key=True)
100-
name: str = Column(String(128), nullable=False, index=True)
95+
id: Mapped[int] = mapped_column(primary_key=True)
96+
name: Mapped[str] = mapped_column(String(128), nullable=False, index=True)
10197

10298
def __repr__(self):
10399
return "<Tag(id='{id}', name='{name}')>".format(id=self.id, name=self.name)
@@ -115,27 +111,27 @@ class Misc(Base):
115111
"""This model contains all possible MySQL types"""
116112

117113
__tablename__ = "misc"
118-
id: int = Column(Integer, primary_key=True)
119-
big_integer_field: int = Column(BigInteger, default=0)
120-
blob_field: bytes = Column(BLOB, nullable=True, index=True)
121-
boolean_field: bool = Column(Boolean, default=False)
122-
char_field: str = Column(CHAR(255), nullable=True)
123-
date_field: date = Column(Date, nullable=True)
124-
date_time_field: datetime = Column(DateTime, nullable=True)
125-
decimal_field: Decimal = Column(SQLiteNumeric(10, 2), nullable=True)
126-
float_field: Decimal = Column(SQLiteNumeric(12, 4), default=0)
127-
integer_field: int = Column(Integer, default=0)
114+
id: Mapped[int] = mapped_column(primary_key=True)
115+
big_integer_field: Mapped[int] = mapped_column(BigInteger, default=0)
116+
blob_field: Mapped[bytes] = mapped_column(BLOB, nullable=True, index=True)
117+
boolean_field: Mapped[bool] = mapped_column(default=False)
118+
char_field: Mapped[str] = mapped_column(CHAR(255), nullable=True)
119+
date_field: Mapped[date] = mapped_column(nullable=True)
120+
date_time_field: Mapped[datetime] = mapped_column(nullable=True)
121+
decimal_field: Mapped[Decimal] = mapped_column(SQLiteNumeric(10, 2), nullable=True)
122+
float_field: Mapped[Decimal] = mapped_column(SQLiteNumeric(12, 4), default=0)
123+
integer_field: Mapped[int] = mapped_column(default=0)
128124
if environ.get("LEGACY_DB", "0") == "0":
129-
json_field: t.Dict[str, t.Any] = Column(JSON, nullable=True)
130-
numeric_field: Decimal = Column(SQLiteNumeric(12, 4), default=0)
131-
real_field: float = Column(REAL(12, 4), default=0)
132-
small_integer_field: int = Column(SmallInteger, default=0)
133-
string_field: str = Column(String(255), nullable=True)
134-
text_field: str = Column(Text, nullable=True)
135-
time_field: time = Column(Time, nullable=True)
136-
varchar_field: str = Column(VARCHAR(255), nullable=True)
137-
timestamp_field: datetime = Column(TIMESTAMP, default=current_timestamp())
138-
my_type_field: t.Any = Column(MyCustomType(255), nullable=True)
125+
json_field: Mapped[t.Dict[str, t.Any]] = mapped_column(JSON, nullable=True)
126+
numeric_field: Mapped[Decimal] = mapped_column(SQLiteNumeric(12, 4), default=0)
127+
real_field: Mapped[float] = mapped_column(REAL(12, 4), default=0)
128+
small_integer_field: Mapped[int] = mapped_column(SmallInteger, default=0)
129+
string_field: Mapped[str] = mapped_column(String(255), nullable=True)
130+
text_field: Mapped[str] = mapped_column(Text, nullable=True)
131+
time_field: Mapped[time] = mapped_column(nullable=True)
132+
varchar_field: Mapped[str] = mapped_column(VARCHAR(255), nullable=True)
133+
timestamp_field: Mapped[datetime] = mapped_column(TIMESTAMP, default=current_timestamp())
134+
my_type_field: Mapped[t.Any] = mapped_column(MyCustomType(255), nullable=True)
139135

140136

141137
article_misc: Table = Table(
@@ -148,9 +144,9 @@ class Misc(Base):
148144

149145
class Media(Base):
150146
__tablename__ = "media"
151-
id: str = Column(CHAR(64), primary_key=True)
152-
title: str = Column(String(255), index=True)
153-
description: str = Column(String(255), nullable=True)
147+
id: Mapped[str] = mapped_column(CHAR(64), primary_key=True)
148+
title: Mapped[str] = mapped_column(String(255), index=True)
149+
description: Mapped[str] = mapped_column(String(255), nullable=True)
154150

155151
def __repr__(self):
156152
return "<Media(id='{id}', title='{title}')>".format(id=self.id, title=self.title)
@@ -166,39 +162,39 @@ def __repr__(self):
166162

167163
class Article(Base):
168164
__tablename__ = "articles"
169-
id: int = Column(Integer, primary_key=True)
170-
hash: str = Column(String(32), unique=True)
171-
slug: str = Column(String(255), index=True)
172-
title: str = Column(String(255), index=True)
173-
content: str = Column(Text, nullable=True, index=True)
174-
status: str = Column(CHAR(1), index=True)
175-
published: datetime = Column(DateTime, nullable=True)
165+
id: Mapped[int] = mapped_column(primary_key=True)
166+
hash: Mapped[str] = mapped_column(String(32), unique=True)
167+
slug: Mapped[str] = mapped_column(String(255), index=True)
168+
title: Mapped[str] = mapped_column(String(255), index=True)
169+
content: Mapped[str] = mapped_column(Text, nullable=True, index=True)
170+
status: Mapped[str] = mapped_column(CHAR(1), index=True)
171+
published: Mapped[datetime] = mapped_column(nullable=True)
176172
# relationships
177-
authors: t.List[Author] = relationship(
173+
authors: Mapped[t.List[Author]] = relationship(
178174
"Author",
179175
secondary=article_authors,
180176
backref=backref("authors", lazy="dynamic"),
181177
lazy="dynamic",
182178
)
183-
tags: t.List[Tag] = relationship(
179+
tags: Mapped[t.List[Tag]] = relationship(
184180
"Tag",
185181
secondary=article_tags,
186182
backref=backref("tags", lazy="dynamic"),
187183
lazy="dynamic",
188184
)
189-
images: t.List[Image] = relationship(
185+
images: Mapped[t.List[Image]] = relationship(
190186
"Image",
191187
secondary=article_images,
192188
backref=backref("images", lazy="dynamic"),
193189
lazy="dynamic",
194190
)
195-
media: t.List[Media] = relationship(
191+
media: Mapped[t.List[Media]] = relationship(
196192
"Media",
197193
secondary=article_media,
198194
backref=backref("media", lazy="dynamic"),
199195
lazy="dynamic",
200196
)
201-
misc: t.List[Misc] = relationship(
197+
misc: Mapped[t.List[Misc]] = relationship(
202198
"Misc",
203199
secondary=article_misc,
204200
backref=backref("misc", lazy="dynamic"),

0 commit comments

Comments
 (0)