Skip to content

Commit 641d81e

Browse files
committed
🏷️ add types to tests
1 parent a3bd19a commit 641d81e

File tree

8 files changed

+684
-495
lines changed

8 files changed

+684
-495
lines changed

requirements_dev.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,16 @@ python-slugify>=7.0.0
1313
types-python-slugify
1414
simplejson>=3.19.1
1515
types-simplejson
16-
sqlalchemy<2.0.0
16+
sqlalchemy>=2.0.0
1717
sqlalchemy-utils
18+
types-sqlalchemy-utils
1819
tox
1920
tqdm>=4.65.0
2021
types-tqdm
2122
packaging>=23.1
2223
tabulate
2324
types-tabulate
24-
typing_extensions
25+
typing_extensions
26+
requests
27+
types-requests
28+
mypy>=1.3.0

tests/conftest.py

Lines changed: 65 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,38 @@
11
import json
2+
import os
23
import socket
4+
import typing as t
35
from codecs import open
4-
from collections import namedtuple
5-
from contextlib import closing, contextmanager
6+
from contextlib import contextmanager
67
from os.path import abspath, dirname, isfile, join
8+
from pathlib import Path
79
from random import choice
810
from string import ascii_lowercase, ascii_uppercase, digits
911
from time import sleep
1012

1113
import docker
1214
import mysql.connector
1315
import pytest
16+
from _pytest._py.path import LocalPath
17+
from _pytest.config import Config
18+
from _pytest.config.argparsing import Parser
19+
from _pytest.legacypath import TempdirFactory
1420
from click.testing import CliRunner
21+
from docker import DockerClient
1522
from docker.errors import NotFound
16-
from mysql.connector import errorcode
23+
from docker.models.containers import Container
24+
from faker import Faker
25+
from mysql.connector import CMySQLConnection, MySQLConnection, errorcode
26+
from mysql.connector.pooling import PooledMySQLConnection
1727
from requests import HTTPError
1828
from sqlalchemy.exc import IntegrityError
29+
from sqlalchemy.orm import Session
1930
from sqlalchemy_utils import database_exists, drop_database
2031

21-
from .database import Database
22-
from .factories import ArticleFactory, AuthorFactory, CrazyNameFactory, ImageFactory, MiscFactory, TagFactory
32+
from . import database, factories, models
2333

2434

25-
def pytest_addoption(parser):
35+
def pytest_addoption(parser: "Parser"):
2636
parser.addoption(
2737
"--mysql-user",
2838
dest="mysql_user",
@@ -78,9 +88,9 @@ def pytest_addoption(parser):
7888

7989

8090
@pytest.fixture(scope="session", autouse=True)
81-
def cleanup_hanged_docker_containers():
91+
def cleanup_hanged_docker_containers() -> None:
8292
try:
83-
client = docker.from_env()
93+
client: DockerClient = docker.from_env()
8494
for container in client.containers.list():
8595
if container.name == "pytest_mysql_to_sqlite3":
8696
container.kill()
@@ -89,9 +99,9 @@ def cleanup_hanged_docker_containers():
8999
pass
90100

91101

92-
def pytest_keyboard_interrupt():
102+
def pytest_keyboard_interrupt() -> None:
93103
try:
94-
client = docker.from_env()
104+
client: DockerClient = docker.from_env()
95105
for container in client.containers.list():
96106
if container.name == "pytest_mysql_to_sqlite3":
97107
container.kill()
@@ -103,17 +113,17 @@ def pytest_keyboard_interrupt():
103113
class Helpers:
104114
@staticmethod
105115
@contextmanager
106-
def not_raises(exception):
116+
def not_raises(exception: t.Type[Exception]) -> t.Generator:
107117
try:
108118
yield
109119
except exception:
110120
raise pytest.fail("DID RAISE {0}".format(exception))
111121

112122
@staticmethod
113123
@contextmanager
114-
def session_scope(db):
124+
def session_scope(db: database.Database) -> t.Generator:
115125
"""Provide a transactional scope around a series of operations."""
116-
session = db.Session()
126+
session: Session = db.Session()
117127
try:
118128
yield session
119129
session.commit()
@@ -125,29 +135,37 @@ def session_scope(db):
125135

126136

127137
@pytest.fixture
128-
def helpers():
138+
def helpers() -> t.Type[Helpers]:
129139
return Helpers
130140

131141

132142
@pytest.fixture()
133-
def sqlite_database(tmpdir):
134-
db_name = "".join(choice(ascii_uppercase + ascii_lowercase + digits) for _ in range(32))
135-
return str(tmpdir.join("{}.sqlite3".format(db_name)))
143+
def sqlite_database(tmpdir: LocalPath) -> t.Union[str, Path, "os.PathLike[t.Any]"]:
144+
db_name: str = "".join(choice(ascii_uppercase + ascii_lowercase + digits) for _ in range(32))
145+
return Path(tmpdir.join(Path("{}.sqlite3".format(db_name))))
136146

137147

138-
def is_port_in_use(port, host="0.0.0.0"):
148+
def is_port_in_use(port: int, host: str = "0.0.0.0") -> bool:
139149
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
140150
return s.connect_ex((host, port)) == 0
141151

142152

143-
@pytest.fixture(scope="session")
144-
def mysql_credentials(pytestconfig):
145-
MySQLCredentials = namedtuple("MySQLCredentials", ["user", "password", "host", "port", "database"])
153+
class MySQLCredentials(t.NamedTuple):
154+
"""MySQL credentials."""
146155

147-
db_credentials_file = abspath(join(dirname(__file__), "db_credentials.json"))
156+
user: str
157+
password: str
158+
host: str
159+
port: int
160+
database: str
161+
162+
163+
@pytest.fixture(scope="session")
164+
def mysql_credentials(pytestconfig: Config) -> MySQLCredentials:
165+
db_credentials_file: str = abspath(join(dirname(__file__), "db_credentials.json"))
148166
if isfile(db_credentials_file):
149167
with open(db_credentials_file, "r", "utf-8") as fh:
150-
db_credentials = json.load(fh)
168+
db_credentials: t.Dict[str, t.Any] = json.load(fh)
151169
return MySQLCredentials(
152170
user=db_credentials["mysql_user"],
153171
password=db_credentials["mysql_password"],
@@ -156,16 +174,13 @@ def mysql_credentials(pytestconfig):
156174
port=db_credentials["mysql_port"],
157175
)
158176

159-
port = pytestconfig.getoption("mysql_port") or 3306
177+
port: int = pytestconfig.getoption("mysql_port") or 3306
160178
if pytestconfig.getoption("use_docker"):
161179
while is_port_in_use(port, pytestconfig.getoption("mysql_host")):
162180
if port >= 2**16 - 1:
163181
pytest.fail(
164182
"No ports appear to be available on the host {}".format(pytestconfig.getoption("mysql_host"))
165183
)
166-
raise ConnectionError(
167-
"No ports appear to be available on the host {}".format(pytestconfig.getoption("mysql_host"))
168-
)
169184
port += 1
170185

171186
return MySQLCredentials(
@@ -178,11 +193,11 @@ def mysql_credentials(pytestconfig):
178193

179194

180195
@pytest.fixture(scope="session")
181-
def mysql_instance(mysql_credentials, pytestconfig):
182-
container = None
183-
mysql_connection = None
184-
mysql_available = False
185-
mysql_connection_retries = 15 # failsafe
196+
def mysql_instance(mysql_credentials: MySQLCredentials, pytestconfig: Config) -> t.Iterator[MySQLConnection]:
197+
container: t.Optional[Container] = None
198+
mysql_connection: t.Optional[t.Union[PooledMySQLConnection, MySQLConnection, CMySQLConnection]] = None
199+
mysql_available: bool = False
200+
mysql_connection_retries: int = 15 # failsafe
186201

187202
db_credentials_file = abspath(join(dirname(__file__), "db_credentials.json"))
188203
if isfile(db_credentials_file):
@@ -198,7 +213,6 @@ def mysql_instance(mysql_credentials, pytestconfig):
198213
client = docker.from_env()
199214
except Exception as err:
200215
pytest.fail(str(err))
201-
raise
202216

203217
docker_mysql_image = pytestconfig.getoption("docker_mysql_image") or "mysql:latest"
204218

@@ -208,7 +222,6 @@ def mysql_instance(mysql_credentials, pytestconfig):
208222
client.images.pull(docker_mysql_image)
209223
except (HTTPError, NotFound) as err:
210224
pytest.fail(str(err))
211-
raise
212225

213226
container = client.containers.run(
214227
image=docker_mysql_image,
@@ -256,17 +269,22 @@ def mysql_instance(mysql_credentials, pytestconfig):
256269
if not mysql_available and mysql_connection_retries <= 0:
257270
raise ConnectionAbortedError("Maximum MySQL connection retries exhausted! Are you sure MySQL is running?")
258271

259-
yield
272+
yield # type: ignore[misc]
260273

261-
if use_docker:
274+
if use_docker and container is not None:
262275
container.kill()
263276

264277

265278
@pytest.fixture(scope="session")
266-
def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_faker):
267-
temp_image_dir = tmpdir_factory.mktemp("images")
268-
269-
db = Database(
279+
def mysql_database(
280+
tmpdir_factory: TempdirFactory,
281+
mysql_instance: MySQLConnection,
282+
mysql_credentials: MySQLCredentials,
283+
_session_faker: Faker,
284+
) -> t.Iterator[database.Database]:
285+
temp_image_dir: LocalPath = tmpdir_factory.mktemp("images")
286+
287+
db: database.Database = database.Database(
270288
"mysql+mysqldb://{user}:{password}@{host}:{port}/{database}".format(
271289
user=mysql_credentials.user,
272290
password=mysql_credentials.password,
@@ -278,13 +296,13 @@ def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_f
278296

279297
with Helpers.session_scope(db) as session:
280298
for _ in range(_session_faker.pyint(min_value=12, max_value=24)):
281-
article = ArticleFactory()
282-
article.authors.append(AuthorFactory())
283-
article.tags.append(TagFactory())
284-
article.misc.append(MiscFactory())
299+
article: models.Article = factories.ArticleFactory()
300+
article.authors.append(factories.AuthorFactory())
301+
article.tags.append(factories.TagFactory())
302+
article.misc.append(factories.MiscFactory())
285303
for _ in range(_session_faker.pyint(min_value=1, max_value=4)):
286304
article.images.append(
287-
ImageFactory(
305+
factories.ImageFactory(
288306
path=join(
289307
str(temp_image_dir),
290308
_session_faker.year(),
@@ -297,7 +315,7 @@ def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_f
297315
session.add(article)
298316

299317
for _ in range(_session_faker.pyint(min_value=12, max_value=24)):
300-
session.add(CrazyNameFactory())
318+
session.add(factories.CrazyNameFactory())
301319
try:
302320
session.commit()
303321
except IntegrityError:
@@ -310,5 +328,5 @@ def mysql_database(tmpdir_factory, mysql_instance, mysql_credentials, _session_f
310328

311329

312330
@pytest.fixture()
313-
def cli_runner():
331+
def cli_runner() -> t.Iterator[CliRunner]:
314332
yield CliRunner()

tests/database.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1+
import typing as t
12
from datetime import datetime, timedelta
23
from decimal import Decimal
34

45
import simplejson as json
56
from sqlalchemy import create_engine
7+
from sqlalchemy.engine import Engine
68
from sqlalchemy.orm import sessionmaker
79
from sqlalchemy_utils import create_database, database_exists
810

911
from .models import Base
1012

1113

1214
class Database:
13-
engine = None
14-
Session = None
15+
engine: Engine
16+
Session: sessionmaker
1517

1618
def __init__(self, database_uri):
1719
self.Session = sessionmaker()
@@ -21,15 +23,15 @@ def __init__(self, database_uri):
2123
self._create_db_tables()
2224
self.Session.configure(bind=self.engine)
2325

24-
def _create_db_tables(self):
26+
def _create_db_tables(self) -> None:
2527
Base.metadata.create_all(self.engine)
2628

2729
@classmethod
28-
def dumps(cls, data):
30+
def dumps(cls, data: t.Any) -> str:
2931
return json.dumps(data, default=cls.json_serializer)
3032

3133
@staticmethod
32-
def json_serializer(data):
34+
def json_serializer(data: t.Any) -> t.Optional[str]:
3335
if isinstance(data, datetime):
3436
return data.isoformat()
3537
if isinstance(data, Decimal):
@@ -38,3 +40,4 @@ def json_serializer(data):
3840
hours, remainder = divmod(data.total_seconds(), 3600)
3941
minutes, seconds = divmod(remainder, 60)
4042
return "{:02}:{:02}:{:02}".format(int(hours), int(minutes), int(seconds))
43+
return None

0 commit comments

Comments
 (0)