Skip to content

Commit b329c27

Browse files
committed
Formatting
1 parent b076287 commit b329c27

File tree

7 files changed

+11
-4
lines changed

7 files changed

+11
-4
lines changed

.flake8

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
[flake8]
22
max-line-length = 120
3+
extend-ignore = E203
34
filename = *.py
45

56
[isort]

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ __pycache__/
2727
# Distribution / packaging
2828
.Python
2929
build/
30+
.build/
3031
develop-eggs/
3132
dist/
3233
downloads/

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ repos:
1111
- id: black
1212
name: black
1313
language: python
14+
args: ["--config", ".flake8"]
1415

1516
- repo: https://github.com/pre-commit/pre-commit-hooks
1617
rev: v4.5.0

python/src/diffusionkit/mlx/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
4040
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
4141
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
42-
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev"
42+
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
4343
}
4444

4545
T5_MAX_LENGTH = {
@@ -49,6 +49,7 @@
4949
"argmaxinc/mlx-FLUX.1-dev": 512,
5050
}
5151

52+
5253
class DiffusionKitInferenceContext(AppleSiliconContextMixin, InferenceContextSpec):
5354
def code_spec(self):
5455
return {}

python/src/diffusionkit/mlx/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def hidden_size(self) -> int:
105105
use_qk_norm=True,
106106
float16_dtype=mx.bfloat16,
107107
guidance_embed=True, # Add this line
108-
dtype=mx.bfloat16
108+
dtype=mx.bfloat16,
109109
)
110110

111111

python/src/diffusionkit/mlx/mmdit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def __init__(self, config: MMDiTConfig):
2929
self.config = config
3030

3131
if config.guidance_embed:
32-
self.guidance_in = MLPEmbedder(in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size)
32+
self.guidance_in = MLPEmbedder(
33+
in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size
34+
)
3335
else:
3436
self.guidance_in = nn.Identity()
3537

@@ -939,6 +941,7 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
939941
.flatten(-2)
940942
)
941943

944+
942945
class MLPEmbedder(nn.Module):
943946
def __init__(self, in_dim: int, hidden_dim: int):
944947
super().__init__()

python/src/diffusionkit/mlx/scripts/generate_images.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def cli():
111111
args.a16 = True
112112

113113
if "FLUX" in args.model_version and args.cfg > 0.0:
114-
logger.warning("Disabling CFG for FLUX.1-schnell model.")
114+
logger.warning(f"Disabling CFG for {args.model_version} model.")
115115
args.cfg = 0.0
116116

117117
if args.benchmark_mode:

0 commit comments

Comments
 (0)