Skip to content

Commit 0bf993b

Browse files
feat(app): refactor board record to include image & asset counts and cover image
This _substantially_ reduces the number of queries required to list all boards. A single query now gets one, all, or a page of boards, including counts and cover image name. - Add helpers to build the queries, which share a common base with some joins. - Update `BoardRecord` to include the counts. - Update `BoardDTO`, which is now identical to `BoardRecord`. I opted to not remove `BoardDTO` because it is used in many places. - Update boards high-level service and board records services accordingly.
1 parent 970bb16 commit 0bf993b

File tree

4 files changed

+125
-121
lines changed

4 files changed

+125
-121
lines changed

invokeai/app/services/board_records/board_records_common.py

Lines changed: 104 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,103 @@
11
from datetime import datetime
2-
from typing import Optional, Union
2+
from typing import Any, Optional, Union
33

4+
from attr import dataclass
45
from pydantic import BaseModel, Field
56

67
from invokeai.app.util.misc import get_iso_timestamp
78
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
89

10+
# This query is missing a GROUP BY clause, which is required for the query to be valid.
11+
BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
12+
SELECT b.board_id,
13+
b.board_name,
14+
b.created_at,
15+
b.updated_at,
16+
b.archived,
17+
COUNT(
18+
CASE
19+
WHEN i.image_category in ('general')
20+
AND i.is_intermediate = 0 THEN 1
21+
END
22+
) AS image_count,
23+
COUNT(
24+
CASE
25+
WHEN i.image_category in ('control', 'mask', 'user', 'other')
26+
AND i.is_intermediate = 0 THEN 1
27+
END
28+
) AS asset_count,
29+
(
30+
SELECT bi.image_name
31+
FROM board_images bi
32+
JOIN images i ON bi.image_name = i.image_name
33+
WHERE bi.board_id = b.board_id
34+
AND i.is_intermediate = 0
35+
ORDER BY i.created_at DESC
36+
LIMIT 1
37+
) AS cover_image_name
38+
FROM boards b
39+
LEFT JOIN board_images bi ON b.board_id = bi.board_id
40+
LEFT JOIN images i ON bi.image_name = i.image_name
41+
"""
42+
43+
44+
@dataclass
45+
class PaginatedBoardRecordsQueries:
46+
main_query: str
47+
total_count_query: str
48+
49+
50+
def get_paginated_list_board_records_query(include_archived: bool) -> PaginatedBoardRecordsQueries:
51+
"""Gets a query to retrieve a paginated list of board records."""
52+
53+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
54+
55+
# The GROUP BY must be added _after_ the WHERE clause!
56+
main_query = f"""
57+
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
58+
{archived_condition}
59+
GROUP BY b.board_id,
60+
b.board_name,
61+
b.created_at,
62+
b.updated_at
63+
ORDER BY b.created_at DESC
64+
LIMIT ? OFFSET ?;
65+
"""
66+
67+
total_count_query = f"""
68+
SELECT COUNT(*)
69+
FROM boards b
70+
{archived_condition};
71+
"""
72+
73+
return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query)
74+
75+
76+
def get_list_all_board_records_query(include_archived: bool) -> str:
77+
"""Gets a query to retrieve all board records."""
78+
79+
archived_condition = "WHERE b.archived = 0" if not include_archived else ""
80+
81+
# The GROUP BY must be added _after_ the WHERE clause!
82+
return f"""
83+
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
84+
{archived_condition}
85+
GROUP BY b.board_id,
86+
b.board_name,
87+
b.created_at,
88+
b.updated_at
89+
ORDER BY b.created_at DESC;
90+
"""
91+
92+
93+
def get_board_record_query() -> str:
94+
"""Gets a query to retrieve a board record."""
95+
96+
return f"""
97+
{BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY}
98+
WHERE b.board_id = ?;
99+
"""
100+
9101

10102
class BoardRecord(BaseModelExcludeNull):
11103
"""Deserialized board record."""
@@ -26,21 +118,25 @@ class BoardRecord(BaseModelExcludeNull):
26118
"""Whether or not the board is archived."""
27119
is_private: Optional[bool] = Field(default=None, description="Whether the board is private.")
28120
"""Whether the board is private."""
121+
image_count: int = Field(description="The number of images in the board.")
122+
asset_count: int = Field(description="The number of assets in the board.")
29123

30124

31-
def deserialize_board_record(board_dict: dict) -> BoardRecord:
125+
def deserialize_board_record(board_dict: dict[str, Any]) -> BoardRecord:
32126
"""Deserializes a board record."""
33127

34128
# Retrieve all the values, setting "reasonable" defaults if they are not present.
35129

36130
board_id = board_dict.get("board_id", "unknown")
37131
board_name = board_dict.get("board_name", "unknown")
38-
cover_image_name = board_dict.get("cover_image_name", "unknown")
132+
cover_image_name = board_dict.get("cover_image_name", None)
39133
created_at = board_dict.get("created_at", get_iso_timestamp())
40134
updated_at = board_dict.get("updated_at", get_iso_timestamp())
41135
deleted_at = board_dict.get("deleted_at", get_iso_timestamp())
42136
archived = board_dict.get("archived", False)
43137
is_private = board_dict.get("is_private", False)
138+
image_count = board_dict.get("image_count", 0)
139+
asset_count = board_dict.get("asset_count", 0)
44140

45141
return BoardRecord(
46142
board_id=board_id,
@@ -51,6 +147,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
51147
deleted_at=deleted_at,
52148
archived=archived,
53149
is_private=is_private,
150+
image_count=image_count,
151+
asset_count=asset_count,
54152
)
55153

56154

@@ -63,21 +161,21 @@ class BoardChanges(BaseModel, extra="forbid"):
63161
class BoardRecordNotFoundException(Exception):
64162
"""Raised when an board record is not found."""
65163

66-
def __init__(self, message="Board record not found"):
164+
def __init__(self, message: str = "Board record not found"):
67165
super().__init__(message)
68166

69167

70168
class BoardRecordSaveException(Exception):
71169
"""Raised when an board record cannot be saved."""
72170

73-
def __init__(self, message="Board record not saved"):
171+
def __init__(self, message: str = "Board record not saved"):
74172
super().__init__(message)
75173

76174

77175
class BoardRecordDeleteException(Exception):
78176
"""Raised when an board record cannot be deleted."""
79177

80-
def __init__(self, message="Board record not deleted"):
178+
def __init__(self, message: str = "Board record not deleted"):
81179
super().__init__(message)
82180

83181

invokeai/app/services/board_records/board_records_sqlite.py

Lines changed: 12 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
BoardRecordSaveException,
1212
UncategorizedImageCounts,
1313
deserialize_board_record,
14+
get_board_record_query,
15+
get_list_all_board_records_query,
16+
get_paginated_list_board_records_query,
1417
)
1518
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
1619
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
@@ -77,11 +80,7 @@ def get(
7780
try:
7881
self._lock.acquire()
7982
self._cursor.execute(
80-
"""--sql
81-
SELECT *
82-
FROM boards
83-
WHERE board_id = ?;
84-
""",
83+
get_board_record_query(),
8584
(board_id,),
8685
)
8786

@@ -93,7 +92,7 @@ def get(
9392
self._lock.release()
9493
if result is None:
9594
raise BoardRecordNotFoundException
96-
return BoardRecord(**dict(result))
95+
return deserialize_board_record(dict(result))
9796

9897
def update(
9998
self,
@@ -150,45 +149,17 @@ def get_many(
150149
try:
151150
self._lock.acquire()
152151

153-
# Build base query
154-
base_query = """
155-
SELECT *
156-
FROM boards
157-
{archived_filter}
158-
ORDER BY created_at DESC
159-
LIMIT ? OFFSET ?;
160-
"""
161-
162-
# Determine archived filter condition
163-
if include_archived:
164-
archived_filter = ""
165-
else:
166-
archived_filter = "WHERE archived = 0"
167-
168-
final_query = base_query.format(archived_filter=archived_filter)
152+
queries = get_paginated_list_board_records_query(include_archived=include_archived)
169153

170-
# Execute query to fetch boards
171-
self._cursor.execute(final_query, (limit, offset))
154+
self._cursor.execute(
155+
queries.main_query,
156+
(limit, offset),
157+
)
172158

173159
result = cast(list[sqlite3.Row], self._cursor.fetchall())
174160
boards = [deserialize_board_record(dict(r)) for r in result]
175161

176-
# Determine count query
177-
if include_archived:
178-
count_query = """
179-
SELECT COUNT(*)
180-
FROM boards;
181-
"""
182-
else:
183-
count_query = """
184-
SELECT COUNT(*)
185-
FROM boards
186-
WHERE archived = 0;
187-
"""
188-
189-
# Execute count query
190-
self._cursor.execute(count_query)
191-
162+
self._cursor.execute(queries.total_count_query)
192163
count = cast(int, self._cursor.fetchone()[0])
193164

194165
return OffsetPaginatedResults[BoardRecord](items=boards, offset=offset, limit=limit, total=count)
@@ -202,26 +173,9 @@ def get_many(
202173
def get_all(self, include_archived: bool = False) -> list[BoardRecord]:
203174
try:
204175
self._lock.acquire()
205-
206-
base_query = """
207-
SELECT *
208-
FROM boards
209-
{archived_filter}
210-
ORDER BY created_at DESC
211-
"""
212-
213-
if include_archived:
214-
archived_filter = ""
215-
else:
216-
archived_filter = "WHERE archived = 0"
217-
218-
final_query = base_query.format(archived_filter=archived_filter)
219-
220-
self._cursor.execute(final_query)
221-
176+
self._cursor.execute(get_list_all_board_records_query(include_archived=include_archived))
222177
result = cast(list[sqlite3.Row], self._cursor.fetchall())
223178
boards = [deserialize_board_record(dict(r)) for r in result]
224-
225179
return boards
226180

227181
except sqlite3.Error as e:
Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,8 @@
1-
from typing import Optional
2-
3-
from pydantic import Field
4-
51
from invokeai.app.services.board_records.board_records_common import BoardRecord
62

73

4+
# TODO(psyche): BoardDTO is now identical to BoardRecord. We should consider removing it.
85
class BoardDTO(BoardRecord):
9-
"""Deserialized board record with cover image URL and image count."""
10-
11-
cover_image_name: Optional[str] = Field(description="The name of the board's cover image.")
12-
"""The URL of the thumbnail of the most recent image in the board."""
13-
image_count: int = Field(description="The number of images in the board.")
14-
"""The number of images in the board."""
15-
6+
"""Deserialized board record."""
167

17-
def board_record_to_dto(board_record: BoardRecord, cover_image_name: Optional[str], image_count: int) -> BoardDTO:
18-
"""Converts a board record to a board DTO."""
19-
return BoardDTO(
20-
**board_record.model_dump(exclude={"cover_image_name"}),
21-
cover_image_name=cover_image_name,
22-
image_count=image_count,
23-
)
8+
pass
Lines changed: 6 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from invokeai.app.services.board_records.board_records_common import BoardChanges
22
from invokeai.app.services.boards.boards_base import BoardServiceABC
3-
from invokeai.app.services.boards.boards_common import BoardDTO, board_record_to_dto
3+
from invokeai.app.services.boards.boards_common import BoardDTO
44
from invokeai.app.services.invoker import Invoker
55
from invokeai.app.services.shared.pagination import OffsetPaginatedResults
66

@@ -16,32 +16,19 @@ def create(
1616
board_name: str,
1717
) -> BoardDTO:
1818
board_record = self.__invoker.services.board_records.save(board_name)
19-
return board_record_to_dto(board_record, None, 0)
19+
return BoardDTO.model_validate(board_record.model_dump())
2020

2121
def get_dto(self, board_id: str) -> BoardDTO:
2222
board_record = self.__invoker.services.board_records.get(board_id)
23-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
24-
if cover_image:
25-
cover_image_name = cover_image.image_name
26-
else:
27-
cover_image_name = None
28-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
29-
return board_record_to_dto(board_record, cover_image_name, image_count)
23+
return BoardDTO.model_validate(board_record.model_dump())
3024

3125
def update(
3226
self,
3327
board_id: str,
3428
changes: BoardChanges,
3529
) -> BoardDTO:
3630
board_record = self.__invoker.services.board_records.update(board_id, changes)
37-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(board_record.board_id)
38-
if cover_image:
39-
cover_image_name = cover_image.image_name
40-
else:
41-
cover_image_name = None
42-
43-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(board_id)
44-
return board_record_to_dto(board_record, cover_image_name, image_count)
31+
return BoardDTO.model_validate(board_record.model_dump())
4532

4633
def delete(self, board_id: str) -> None:
4734
self.__invoker.services.board_records.delete(board_id)
@@ -50,30 +37,10 @@ def get_many(
5037
self, offset: int = 0, limit: int = 10, include_archived: bool = False
5138
) -> OffsetPaginatedResults[BoardDTO]:
5239
board_records = self.__invoker.services.board_records.get_many(offset, limit, include_archived)
53-
board_dtos = []
54-
for r in board_records.items:
55-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
56-
if cover_image:
57-
cover_image_name = cover_image.image_name
58-
else:
59-
cover_image_name = None
60-
61-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
62-
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
63-
40+
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records.items]
6441
return OffsetPaginatedResults[BoardDTO](items=board_dtos, offset=offset, limit=limit, total=len(board_dtos))
6542

6643
def get_all(self, include_archived: bool = False) -> list[BoardDTO]:
6744
board_records = self.__invoker.services.board_records.get_all(include_archived)
68-
board_dtos = []
69-
for r in board_records:
70-
cover_image = self.__invoker.services.image_records.get_most_recent_image_for_board(r.board_id)
71-
if cover_image:
72-
cover_image_name = cover_image.image_name
73-
else:
74-
cover_image_name = None
75-
76-
image_count = self.__invoker.services.board_image_records.get_image_count_for_board(r.board_id)
77-
board_dtos.append(board_record_to_dto(r, cover_image_name, image_count))
78-
45+
board_dtos = [BoardDTO.model_validate(r.model_dump()) for r in board_records]
7946
return board_dtos

0 commit comments

Comments
 (0)