|
1 | 1 | import sqlite3
|
2 | 2 | import threading
|
| 3 | +from dataclasses import dataclass |
3 | 4 | from typing import Union, cast
|
4 | 5 |
|
5 | 6 | from invokeai.app.services.board_records.board_records_base import BoardRecordStorageBase
|
|
11 | 12 | BoardRecordSaveException,
|
12 | 13 | UncategorizedImageCounts,
|
13 | 14 | deserialize_board_record,
|
14 |
| - get_board_record_query, |
15 |
| - get_list_all_board_records_query, |
16 |
| - get_paginated_list_board_records_query, |
17 | 15 | )
|
18 | 16 | from invokeai.app.services.shared.pagination import OffsetPaginatedResults
|
19 | 17 | from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
20 | 18 | from invokeai.app.util.misc import uuid_string
|
21 | 19 |
|
| 20 | +# This query is missing a GROUP BY clause, which is required for the query to be valid. |
| 21 | +BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """ |
| 22 | + SELECT b.board_id, |
| 23 | + b.board_name, |
| 24 | + b.created_at, |
| 25 | + b.updated_at, |
| 26 | + b.archived, |
| 27 | + COUNT( |
| 28 | + CASE |
| 29 | + WHEN i.image_category in ('general') |
| 30 | + AND i.is_intermediate = 0 THEN 1 |
| 31 | + END |
| 32 | + ) AS image_count, |
| 33 | + COUNT( |
| 34 | + CASE |
| 35 | + WHEN i.image_category in ('control', 'mask', 'user', 'other') |
| 36 | + AND i.is_intermediate = 0 THEN 1 |
| 37 | + END |
| 38 | + ) AS asset_count, |
| 39 | + ( |
| 40 | + SELECT bi.image_name |
| 41 | + FROM board_images bi |
| 42 | + JOIN images i ON bi.image_name = i.image_name |
| 43 | + WHERE bi.board_id = b.board_id |
| 44 | + AND i.is_intermediate = 0 |
| 45 | + ORDER BY i.created_at DESC |
| 46 | + LIMIT 1 |
| 47 | + ) AS cover_image_name |
| 48 | + FROM boards b |
| 49 | + LEFT JOIN board_images bi ON b.board_id = bi.board_id |
| 50 | + LEFT JOIN images i ON bi.image_name = i.image_name |
| 51 | + """ |
| 52 | + |
| 53 | + |
| 54 | +@dataclass |
| 55 | +class PaginatedBoardRecordsQueries: |
| 56 | + main_query: str |
| 57 | + total_count_query: str |
| 58 | + |
| 59 | + |
| 60 | +def get_paginated_list_board_records_queries(include_archived: bool) -> PaginatedBoardRecordsQueries: |
| 61 | + """Gets a query to retrieve a paginated list of board records.""" |
| 62 | + |
| 63 | + archived_condition = "WHERE b.archived = 0" if not include_archived else "" |
| 64 | + |
| 65 | + # The GROUP BY must be added _after_ the WHERE clause! |
| 66 | + main_query = f""" |
| 67 | + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} |
| 68 | + {archived_condition} |
| 69 | + GROUP BY b.board_id, |
| 70 | + b.board_name, |
| 71 | + b.created_at, |
| 72 | + b.updated_at |
| 73 | + ORDER BY b.created_at DESC |
| 74 | + LIMIT ? OFFSET ?; |
| 75 | + """ |
| 76 | + |
| 77 | + total_count_query = f""" |
| 78 | + SELECT COUNT(*) |
| 79 | + FROM boards b |
| 80 | + {archived_condition}; |
| 81 | + """ |
| 82 | + |
| 83 | + return PaginatedBoardRecordsQueries(main_query=main_query, total_count_query=total_count_query) |
| 84 | + |
| 85 | + |
| 86 | +def get_list_all_board_records_query(include_archived: bool) -> str: |
| 87 | + """Gets a query to retrieve all board records.""" |
| 88 | + |
| 89 | + archived_condition = "WHERE b.archived = 0" if not include_archived else "" |
| 90 | + |
| 91 | + # The GROUP BY must be added _after_ the WHERE clause! |
| 92 | + return f""" |
| 93 | + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} |
| 94 | + {archived_condition} |
| 95 | + GROUP BY b.board_id, |
| 96 | + b.board_name, |
| 97 | + b.created_at, |
| 98 | + b.updated_at |
| 99 | + ORDER BY b.created_at DESC; |
| 100 | + """ |
| 101 | + |
| 102 | + |
| 103 | +def get_board_record_query() -> str: |
| 104 | + """Gets a query to retrieve a board record.""" |
| 105 | + |
| 106 | + return f""" |
| 107 | + {BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY} |
| 108 | + WHERE b.board_id = ?; |
| 109 | + """ |
| 110 | + |
22 | 111 |
|
23 | 112 | class SqliteBoardRecordStorage(BoardRecordStorageBase):
|
24 | 113 | _conn: sqlite3.Connection
|
@@ -149,7 +238,7 @@ def get_many(
|
149 | 238 | try:
|
150 | 239 | self._lock.acquire()
|
151 | 240 |
|
152 |
| - queries = get_paginated_list_board_records_query(include_archived=include_archived) |
| 241 | + queries = get_paginated_list_board_records_queries(include_archived=include_archived) |
153 | 242 |
|
154 | 243 | self._cursor.execute(
|
155 | 244 | queries.main_query,
|
|
0 commit comments