1
1
from datetime import datetime
2
- from typing import Optional , Union
2
+ from typing import Any , Optional , Union
3
3
4
+ from attr import dataclass
4
5
from pydantic import BaseModel , Field
5
6
6
7
from invokeai .app .util .misc import get_iso_timestamp
7
8
from invokeai .app .util .model_exclude_null import BaseModelExcludeNull
8
9
10
+ # This query is missing a GROUP BY clause, which is required for the query to be valid.
11
+ BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY = """
12
+ SELECT b.board_id,
13
+ b.board_name,
14
+ b.created_at,
15
+ b.updated_at,
16
+ b.archived,
17
+ COUNT(
18
+ CASE
19
+ WHEN i.image_category in ('general')
20
+ AND i.is_intermediate = 0 THEN 1
21
+ END
22
+ ) AS image_count,
23
+ COUNT(
24
+ CASE
25
+ WHEN i.image_category in ('control', 'mask', 'user', 'other')
26
+ AND i.is_intermediate = 0 THEN 1
27
+ END
28
+ ) AS asset_count,
29
+ (
30
+ SELECT bi.image_name
31
+ FROM board_images bi
32
+ JOIN images i ON bi.image_name = i.image_name
33
+ WHERE bi.board_id = b.board_id
34
+ AND i.is_intermediate = 0
35
+ ORDER BY i.created_at DESC
36
+ LIMIT 1
37
+ ) AS cover_image_name
38
+ FROM boards b
39
+ LEFT JOIN board_images bi ON b.board_id = bi.board_id
40
+ LEFT JOIN images i ON bi.image_name = i.image_name
41
+ """
42
+
43
+
44
+ @dataclass
45
+ class PaginatedBoardRecordsQueries :
46
+ main_query : str
47
+ total_count_query : str
48
+
49
+
50
+ def get_paginated_list_board_records_query (include_archived : bool ) -> PaginatedBoardRecordsQueries :
51
+ """Gets a query to retrieve a paginated list of board records."""
52
+
53
+ archived_condition = "WHERE b.archived = 0" if not include_archived else ""
54
+
55
+ # The GROUP BY must be added _after_ the WHERE clause!
56
+ main_query = f"""
57
+ { BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY }
58
+ { archived_condition }
59
+ GROUP BY b.board_id,
60
+ b.board_name,
61
+ b.created_at,
62
+ b.updated_at
63
+ ORDER BY b.created_at DESC
64
+ LIMIT ? OFFSET ?;
65
+ """
66
+
67
+ total_count_query = f"""
68
+ SELECT COUNT(*)
69
+ FROM boards b
70
+ { archived_condition } ;
71
+ """
72
+
73
+ return PaginatedBoardRecordsQueries (main_query = main_query , total_count_query = total_count_query )
74
+
75
+
76
+ def get_list_all_board_records_query (include_archived : bool ) -> str :
77
+ """Gets a query to retrieve all board records."""
78
+
79
+ archived_condition = "WHERE b.archived = 0" if not include_archived else ""
80
+
81
+ # The GROUP BY must be added _after_ the WHERE clause!
82
+ return f"""
83
+ { BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY }
84
+ { archived_condition }
85
+ GROUP BY b.board_id,
86
+ b.board_name,
87
+ b.created_at,
88
+ b.updated_at
89
+ ORDER BY b.created_at DESC;
90
+ """
91
+
92
+
93
+ def get_board_record_query () -> str :
94
+ """Gets a query to retrieve a board record."""
95
+
96
+ return f"""
97
+ { BASE_UNTERMINATED_AND_MISSING_GROUP_BY_BOARD_RECORDS_QUERY }
98
+ WHERE b.board_id = ?;
99
+ """
100
+
9
101
10
102
class BoardRecord (BaseModelExcludeNull ):
11
103
"""Deserialized board record."""
@@ -26,21 +118,25 @@ class BoardRecord(BaseModelExcludeNull):
26
118
"""Whether or not the board is archived."""
27
119
is_private : Optional [bool ] = Field (default = None , description = "Whether the board is private." )
28
120
"""Whether the board is private."""
121
+ image_count : int = Field (description = "The number of images in the board." )
122
+ asset_count : int = Field (description = "The number of assets in the board." )
29
123
30
124
31
- def deserialize_board_record (board_dict : dict ) -> BoardRecord :
125
+ def deserialize_board_record (board_dict : dict [ str , Any ] ) -> BoardRecord :
32
126
"""Deserializes a board record."""
33
127
34
128
# Retrieve all the values, setting "reasonable" defaults if they are not present.
35
129
36
130
board_id = board_dict .get ("board_id" , "unknown" )
37
131
board_name = board_dict .get ("board_name" , "unknown" )
38
- cover_image_name = board_dict .get ("cover_image_name" , "unknown" )
132
+ cover_image_name = board_dict .get ("cover_image_name" , None )
39
133
created_at = board_dict .get ("created_at" , get_iso_timestamp ())
40
134
updated_at = board_dict .get ("updated_at" , get_iso_timestamp ())
41
135
deleted_at = board_dict .get ("deleted_at" , get_iso_timestamp ())
42
136
archived = board_dict .get ("archived" , False )
43
137
is_private = board_dict .get ("is_private" , False )
138
+ image_count = board_dict .get ("image_count" , 0 )
139
+ asset_count = board_dict .get ("asset_count" , 0 )
44
140
45
141
return BoardRecord (
46
142
board_id = board_id ,
@@ -51,6 +147,8 @@ def deserialize_board_record(board_dict: dict) -> BoardRecord:
51
147
deleted_at = deleted_at ,
52
148
archived = archived ,
53
149
is_private = is_private ,
150
+ image_count = image_count ,
151
+ asset_count = asset_count ,
54
152
)
55
153
56
154
@@ -63,21 +161,21 @@ class BoardChanges(BaseModel, extra="forbid"):
63
161
class BoardRecordNotFoundException (Exception ):
64
162
"""Raised when an board record is not found."""
65
163
66
- def __init__ (self , message = "Board record not found" ):
164
+ def __init__ (self , message : str = "Board record not found" ):
67
165
super ().__init__ (message )
68
166
69
167
70
168
class BoardRecordSaveException (Exception ):
71
169
"""Raised when an board record cannot be saved."""
72
170
73
- def __init__ (self , message = "Board record not saved" ):
171
+ def __init__ (self , message : str = "Board record not saved" ):
74
172
super ().__init__ (message )
75
173
76
174
77
175
class BoardRecordDeleteException (Exception ):
78
176
"""Raised when an board record cannot be deleted."""
79
177
80
- def __init__ (self , message = "Board record not deleted" ):
178
+ def __init__ (self , message : str = "Board record not deleted" ):
81
179
super ().__init__ (message )
82
180
83
181
0 commit comments