From 558dcf8ceacb2b6693fd7162ef80bd060f86d335 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Mar 2025 10:14:46 +1000 Subject: [PATCH 1/5] tests: monkeypatch secondary reference to `gguf_sd_loader()` `gguf_sd_loader()` has multiple references in the codebase. It is imported before monkeypatching, so we need to monkeypatch another reference to it. This fixes tests for `ModelOnDisk.load_state_dict()`. --- tests/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/conftest.py b/tests/conftest.py index 5cf96551e00..d2d87019739 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -96,6 +96,7 @@ def override_model_loading(monkeypatch): monkeypatch.setattr(safetensors.torch, "load", load_stripped_model) monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model) monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model) + monkeypatch.setattr("invokeai.backend.model_manager.config.gguf_sd_loader", load_stripped_model) def fake_scan(*args, **kwargs): return SimpleNamespace(infected_files=0, scan_err=None) From ae650737112d28438679cdbc2497562902c3ec17 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Mar 2025 10:20:50 +1000 Subject: [PATCH 2/5] tests: add stripped models for FLUX varieties Stripped models for: - FLUX Dev.safetensors - FLUX Schnell.safetensors - FLUX Fill.safetensors - FLUX Dev (Quantized).safetensors - FLUX Schnell (Quantized).safetensors - flux1-fill-dev-Q8_0.gguf - midjourneyReplica_flux1Dev.safetensors --- .../stripped_models/FLUX Dev (Quantized).safetensors | 3 +++ tests/test_model_probe/stripped_models/FLUX Dev.safetensors | 3 +++ tests/test_model_probe/stripped_models/FLUX Fill.safetensors | 3 +++ .../stripped_models/FLUX Schnell (Quantized).safetensors | 3 +++ .../test_model_probe/stripped_models/FLUX Schnell.safetensors | 3 +++ .../test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf | 3 +++ .../stripped_models/midjourneyReplica_flux1Dev.safetensors | 3 +++ 7 files changed, 21 insertions(+) create mode 100644 tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors create mode 100644 tests/test_model_probe/stripped_models/FLUX Dev.safetensors create mode 100644 tests/test_model_probe/stripped_models/FLUX Fill.safetensors create mode 100644 tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors create mode 100644 tests/test_model_probe/stripped_models/FLUX Schnell.safetensors create mode 100644 tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf create mode 100644 tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors diff --git a/tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors b/tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors new file mode 100644 index 00000000000..e8646cb043a --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Dev (Quantized).safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fe25212279fec351340d1c4a9da0eb902af82162350970c148bf331c1c02f3c5 +size 292730 diff --git a/tests/test_model_probe/stripped_models/FLUX Dev.safetensors b/tests/test_model_probe/stripped_models/FLUX Dev.safetensors new file mode 100644 index 00000000000..32718fac409 --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Dev.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:84850676ab6fc163b4fe3bb87b1584a5a78b523e5f6e58b6ecb2c7d34e4c0796 +size 130743 diff --git a/tests/test_model_probe/stripped_models/FLUX Fill.safetensors b/tests/test_model_probe/stripped_models/FLUX Fill.safetensors new file mode 100644 index 00000000000..d15e6cb0e0b --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Fill.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eb64744f32674cd1e8c3c09e578d18e1ca84c3deac0ef0a2fc3654ec9ac0a84d +size 130744 diff --git a/tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors b/tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors new file mode 100644 index 00000000000..30688e9c339 --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Schnell (Quantized).safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:42cd75dbd5dec6252de6f959a6ed678fb0e5bef166eca7ac38c51577a0d4e4eb +size 291091 diff --git a/tests/test_model_probe/stripped_models/FLUX Schnell.safetensors b/tests/test_model_probe/stripped_models/FLUX Schnell.safetensors new file mode 100644 index 00000000000..4f46a9fe198 --- /dev/null +++ b/tests/test_model_probe/stripped_models/FLUX Schnell.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a1533dced878ca5a8bae39bfdbed85dfd97e937ec3c97540da1e7d4011ffed98 +size 130098 diff --git a/tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf b/tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf new file mode 100644 index 00000000000..deabac76c9e --- /dev/null +++ b/tests/test_model_probe/stripped_models/flux1-fill-dev-Q8_0.gguf @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:cac069dd904e0d676baacecfeaba52bbbe808a6d755dabdd94c7281656fa0507 +size 129356 diff --git a/tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors b/tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors new file mode 100644 index 00000000000..9fd14405496 --- /dev/null +++ b/tests/test_model_probe/stripped_models/midjourneyReplica_flux1Dev.safetensors @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:98d0f54489ec096f543a9b8f88683fd960acd96521d987e027be9e23d621d96f +size 151803 From 9707724ec27abaa4891ea3fa1b33b90e27377989 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Mar 2025 10:21:41 +1000 Subject: [PATCH 3/5] docs: add reminder to FLUX variant probing to add test cases if we have a probe failure --- invokeai/backend/model_manager/legacy_probe.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/invokeai/backend/model_manager/legacy_probe.py b/invokeai/backend/model_manager/legacy_probe.py index 8a0e770d037..dd94ee0e900 100644 --- a/invokeai/backend/model_manager/legacy_probe.py +++ b/invokeai/backend/model_manager/legacy_probe.py @@ -572,6 +572,8 @@ def get_variant_type(self) -> ModelVariantType: if in_channels is None: # If we cannot find the in_channels, we assume that this is a normal variant. Log a warning. + # If this occurs, we should add a test case for the affected model here: + # tests/backend/flux/test_flux_state_dict_utils.py logger.warning( f"{self.model_path} does not have img_in.weight or model.diffusion_model.img_in.weight key. Assuming normal variant." ) From 88ea6b538d2b0ebddba97bf54dee02982cfb96fb Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Mar 2025 10:21:57 +1000 Subject: [PATCH 4/5] tests: add test for `get_flux_in_channels_from_state_dict()` --- .../flux/test_flux_state_dict_utils.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 tests/backend/flux/test_flux_state_dict_utils.py diff --git a/tests/backend/flux/test_flux_state_dict_utils.py b/tests/backend/flux/test_flux_state_dict_utils.py new file mode 100644 index 00000000000..c4540ef0d22 --- /dev/null +++ b/tests/backend/flux/test_flux_state_dict_utils.py @@ -0,0 +1,35 @@ +from pathlib import Path + +import pytest + +from invokeai.backend.flux.flux_state_dict_utils import get_flux_in_channels_from_state_dict +from invokeai.backend.model_manager.config import ModelOnDisk + +test_cases = [ + # Unquantized + ("FLUX Dev.safetensors", 64), + ("FLUX Schnell.safetensors", 64), + ("FLUX Fill.safetensors", 384), + # BNB-NF4 quantized + ("FLUX Dev (Quantized).safetensors", 1), # BNB-NF4 + ("FLUX Schnell (Quantized).safetensors", 1), # BNB-NF4 + # GGUF quantized FLUX Fill + ("flux1-fill-dev-Q8_0.gguf", 384), + # Fine-tune w/ "model.diffusion_model.img_in.weight" instead of "img_in.weight" + ("midjourneyReplica_flux1Dev.safetensors", 64), + # Not a FLUX model, testing fallback case + ("Noodles Style.safetensors", None), +] + + +@pytest.mark.parametrize("model_file_name,expected_in_channels", test_cases) +def test_get_flux_in_channels_from_state_dict(model_file_name: str, expected_in_channels: int, override_model_loading): + model_path = Path(f"tests/test_model_probe/stripped_models/{model_file_name}") + + mod = ModelOnDisk(model_path) + + state_dict = mod.load_state_dict() + + in_channels = get_flux_in_channels_from_state_dict(state_dict) + + assert in_channels == expected_in_channels From 34457fc3812c26469d5a2524d40fd4f5e3dc50f7 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Mon, 31 Mar 2025 10:39:30 +1000 Subject: [PATCH 5/5] tests: monkeypatch more references to `gguf_sd_loader()` The new util in de9f541bf60199e14d159dc3511482ae38a5cb60 alters import order, breaking some model probe tests. We need to patch more references to `gguf_sd_loader()` to fix em --- tests/conftest.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index d2d87019739..724c03938a5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -97,6 +97,8 @@ def override_model_loading(monkeypatch): monkeypatch.setattr(safetensors.torch, "load_file", load_stripped_model) monkeypatch.setattr(gguf_loaders, "gguf_sd_loader", load_stripped_model) monkeypatch.setattr("invokeai.backend.model_manager.config.gguf_sd_loader", load_stripped_model) + monkeypatch.setattr("invokeai.backend.model_manager.util.model_util.gguf_sd_loader", load_stripped_model) + monkeypatch.setattr("invokeai.backend.model_manager.legacy_probe.gguf_sd_loader", load_stripped_model) def fake_scan(*args, **kwargs): return SimpleNamespace(infected_files=0, scan_err=None)