Skip to content

Commit 0fc06bc

Browse files
tidy(app): move sqlite-specific objects to sqlite file
1 parent fcc3f7f commit 0fc06bc

File tree

2 files changed

+93
-96
lines changed

2 files changed

+93
-96
lines changed

invokeai/app/services/board_records/board_records_common.py

Lines changed: 0 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,11 @@
11
from datetime import datetime
22
from typing import Any, Optional, Union
33

4-
from attr import dataclass
54
from pydantic import BaseModel, Field
65

76
from invokeai.app.util.misc import get_iso_timestamp
87
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
98

10-
# This query is missing a GROUP BY clause, which is required for the query to be valid.
11-
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
12-
SELECT b.board_id,
13-
b.board_name,
14-
b.created_at,
15-
b.updated_at,
16-
b.archived,
17-
COUNT(
18-
CASE
19-
WHEN i.image_category in ('general')
20-
AND i.is_intermediate = 0 THEN 1
21-
END
22-
) AS image_count,
23-
COUNT(
24-
CASE
25-
WHEN i.image_category in ('control', 'mask', 'user', 'other')
26-
AND i.is_intermediate = 0 THEN 1
27-
END
28-
) AS asset_count,
29-
(
30-
SELECT bi.image_name
31-
FROM board_images bi
32-
JOIN images i ON bi.image_name = i.image_name
33-
WHERE bi.board_id = b.board_id
34-
AND i.is_intermediate = 0
35-
ORDER BY i.created_at DESC
36-
LIMIT 1
37-
) AS cover_image_name
38-
FROM boards b
39-
LEFT JOIN board_images bi ON b.board_id = bi.board_id
40-
LEFT JOIN images i ON bi.image_name = i.image_name
41-
"""
42-
43-
44-
@dataclass
45-
class PaginatedBoardRecordsQueries:
46-
main_query: str
47-
total_count_query: str
48-
49-
50-
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
51-
"""Gets a query to retrieve a paginated list of board records."""
52-
53-
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
54-
55-
# The GROUP BY must be added _after_ the WHERE clause!
56-
main_query = f"""
57-
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
58-
{archived_condition}
59-
GROUP BY b.board_id,
60-
b.board_name,
61-
b.created_at,
62-
b.updated_at
63-
ORDER BY b.created_at DESC
64-
LIMIT ? OFFSET ?;
65-
"""
66-
67-
total_count_query = f"""
68-
SELECT COUNT(*)
69-
FROM boards b
70-
{archived_condition};
71-
"""
72-
73-
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
74-
75-
76-
def get_list_all_board_records_query(include_archived: bool) -> str:
77-
"""Gets a query to retrieve all board records."""
78-
79-
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
80-
81-
# The GROUP BY must be added _after_ the WHERE clause!
82-
return f"""
83-
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
84-
{archived_condition}
85-
GROUP BY b.board_id,
86-
b.board_name,
87-
b.created_at,
88-
b.updated_at
89-
ORDER BY b.created_at DESC;
90-
"""
91-
92-
93-
def get_board_record_query() -> str:
94-
"""Gets a query to retrieve a board record."""
95-
96-
return f"""
97-
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
98-
WHERE b.board_id = ?;
99-
"""
100-
1019

10210
class BoardRecord(BaseModelExcludeNull):
10311
"""Deserialized board record."""

invokeai/app/services/board_records/board_records_sqlite.py

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sqlite3
22
import threading
3+
from dataclasses import dataclass
34
from typing import Union, cast
45

56
from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
@@ -11,14 +12,102 @@
1112
BoardRecordSaveException,
1213
UncategorizedImageCounts,
1314
deserialize_board_record,
14-
get_board_record_query,
15-
get_list_all_board_records_query,
16-
get_paginated_list_board_records_query,
1715
)
1816
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
1917
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
2018
from invokeai.app.util.misc import uuid_string
2119

20+
# This query is missing a GROUP BY clause, which is required for the query to be valid.
21+
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
22+
SELECT b.board_id,
23+
b.board_name,
24+
b.created_at,
25+
b.updated_at,
26+
b.archived,
27+
COUNT(
28+
CASE
29+
WHEN i.image_category in ('general')
30+
AND i.is_intermediate = 0 THEN 1
31+
END
32+
) AS image_count,
33+
COUNT(
34+
CASE
35+
WHEN i.image_category in ('control', 'mask', 'user', 'other')
36+
AND i.is_intermediate = 0 THEN 1
37+
END
38+
) AS asset_count,
39+
(
40+
SELECT bi.image_name
41+
FROM board_images bi
42+
JOIN images i ON bi.image_name = i.image_name
43+
WHERE bi.board_id = b.board_id
44+
AND i.is_intermediate = 0
45+
ORDER BY i.created_at DESC
46+
LIMIT 1
47+
) AS cover_image_name
48+
FROM boards b
49+
LEFT JOIN board_images bi ON b.board_id = bi.board_id
50+
LEFT JOIN images i ON bi.image_name = i.image_name
51+
"""
52+
53+
54+
@dataclass
55+
class PaginatedBoardRecordsQueries:
56+
main_query: str
57+
total_count_query: str
58+
59+
60+
def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries:
61+
"""Gets a query to retrieve a paginated list of board records."""
62+
63+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
64+
65+
# The GROUP BY must be added _after_ the WHERE clause!
66+
main_query = f"""
67+
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
68+
{archived_condition}
69+
GROUP BY b.board_id,
70+
b.board_name,
71+
b.created_at,
72+
b.updated_at
73+
ORDER BY b.created_at DESC
74+
LIMIT ? OFFSET ?;
75+
"""
76+
77+
total_count_query = f"""
78+
SELECT COUNT(*)
79+
FROM boards b
80+
{archived_condition};
81+
"""
82+
83+
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
84+
85+
86+
def get_list_all_board_records_query(include_archived: bool) -> str:
87+
"""Gets a query to retrieve all board records."""
88+
89+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
90+
91+
# The GROUP BY must be added _after_ the WHERE clause!
92+
return f"""
93+
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
94+
{archived_condition}
95+
GROUP BY b.board_id,
96+
b.board_name,
97+
b.created_at,
98+
b.updated_at
99+
ORDER BY b.created_at DESC;
100+
"""
101+
102+
103+
def get_board_record_query() -> str:
104+
"""Gets a query to retrieve a board record."""
105+
106+
return f"""
107+
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
108+
WHERE b.board_id = ?;
109+
"""
110+
22111

23112
class SqliteBoardRecordStorage(BoardRecordStorageBase):
24113
_conn: sqlite3.Connection
@@ -149,7 +238,7 @@ def get_many(
149238
try:
150239
self._lock.acquire()
151240

152-
queries = get_paginated_list_board_records_query(include_archived=include_archived)
241+
queries = get_paginated_list_board_records_queries(include_archived=include_archived)
153242

154243
self._cursor.execute(
155244
queries.main_query,

0 commit comments

Comments
 (0)