Skip to content

Commit a0a691c

Browse files
committed
Managed install: fix arch packages not being selectable when there is more than 1 workload selected
1 parent 8cc7c73 commit a0a691c

File tree

1 file changed

+16
-17
lines changed

1 file changed

+16
-17
lines changed

ai_diffusion/ui/server.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def __init__(
6464
parent=None,
6565
):
6666
super().__init__(parent)
67-
self._workload = Arch.all
67+
self._workloads: list[Arch] = []
6868
self._backend = settings.server_backend
6969
self._is_checkable = False
7070

@@ -180,7 +180,7 @@ def _backend_supports(self, item: PackageItem):
180180
def _workload_matches(self, item: PackageItem):
181181
return (
182182
not isinstance(item.package, ModelResource)
183-
or Arch.match(self._workload, item.package.arch)
183+
or item.package.arch in self.workloads
184184
or item.package.arch not in [Arch.sd15, Arch.sdxl, Arch.flux, Arch.flux_k, Arch.chroma]
185185
)
186186

@@ -204,12 +204,12 @@ def selected_packages(self):
204204
]
205205

206206
@property
207-
def workload(self):
208-
return self._workload
207+
def workloads(self):
208+
return self._workloads
209209

210-
@workload.setter
211-
def workload(self, workload: Arch):
212-
self._workload = workload
210+
@workloads.setter
211+
def workloads(self, workloads: list[Arch]):
212+
self._workloads = workloads
213213
self._update()
214214

215215
@property
@@ -783,9 +783,9 @@ def show_error(self, error: str):
783783
self._status_label.setStyleSheet(f"color:{red}")
784784

785785
def change_workload(self):
786-
if self.selected_workload is Arch.sd15:
786+
if self._workload_group.values[0] is PackageState.selected:
787787
self._packages["sd15"].expand()
788-
elif self.selected_workload is Arch.flux:
788+
if self._workload_group.values[2] is PackageState.selected:
789789
self._packages["flux"].expand()
790790
self.update_ui()
791791

@@ -821,7 +821,7 @@ def update_optional(self):
821821
]
822822

823823
for widget in self._packages.values():
824-
widget.workload = self.selected_workload
824+
widget.workloads = self.selected_workloads
825825
widget.backend = self._server.backend
826826
widget.set_installed([self._server.is_installed(p) for p in widget.package_names])
827827

@@ -839,17 +839,16 @@ def requires_install(self):
839839
return install_required or install_optional
840840

841841
@property
842-
def selected_workload(self):
842+
def selected_workloads(self):
843843
selected_or_installed = [
844844
state in [PackageState.selected, PackageState.installed]
845845
for state in self._workload_group.values
846846
]
847-
if all(selected_or_installed):
848-
return Arch.all
847+
result = []
849848
if selected_or_installed[0]:
850-
return Arch.sd15
849+
result.append(Arch.sd15)
851850
if selected_or_installed[1]:
852-
return Arch.sdxl
851+
result.append(Arch.sdxl)
853852
if selected_or_installed[2]:
854-
return Arch.flux
855-
return Arch.auto
853+
result.append(Arch.flux)
854+
return result

0 commit comments

Comments
 (0)