Skip to content

Commit 734fbc2

Browse files
Local ckpt support for python mlx (#7)
* Local ckpt support for python mlx * update README * Update lint.yml --------- Co-authored-by: Zach Nagengast <zacharynagengast@gmail.com>
1 parent 8cec0d3 commit 734fbc2

File tree

5 files changed

+27
-4
lines changed

5 files changed

+27
-4
lines changed

.github/workflows/lint.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,16 @@ jobs:
1414
- name: Checkout code
1515
uses: actions/checkout@v3
1616

17+
- name: Cache Homebrew
18+
uses: actions/cache@v3
19+
with:
20+
path: /home/linuxbrew/.linuxbrew
21+
key: ${{ runner.os }}-homebrew-${{ hashFiles('**/Brewfile.lock.json') }}
22+
restore-keys: |
23+
${{ runner.os }}-homebrew-
24+
1725
- name: Set up Homebrew
18-
id: set-up-homebrew
26+
if: steps.cache-homebrew.outputs.cache-hit != 'true'
1927
uses: Homebrew/actions/setup-homebrew@master
2028

2129
- name: Setup environment

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ Some notable optional arguments:
7474
- For image-to-image, use `--image-path` (path to input image) and `--denoise` (value between 0. and 1.)
7575
- T5 text embeddings, use `--t5`
7676
- For different resolutions, use `--height` and `--width`
77+
- For using a local checkpoint, use `--local-ckpt </path/to/ckpt.safetensors>` (e.g. `~/models/stable-diffusion-3-medium/sd3_medium.safetensors`).
7778

7879
Please refer to the help menu for all available arguments: `diffusionkit-cli -h`.
7980

python/src/mlx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def __init__(
4848
model_size: str = "2b",
4949
low_memory_mode: bool = True,
5050
a16: bool = False,
51+
local_ckpt=None,
5152
):
53+
model_io.LOCAl_SD3_CKPT = local_ckpt
5254
self.dtype = mx.float16 if w16 else mx.float32
5355
self.activation_dtype = mx.float16 if a16 else mx.float32
5456
self.use_t5 = use_t5

python/src/mlx/model_io.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
"8b": 192,
6363
}
6464

65+
LOCAl_SD3_CKPT = None
66+
6567

6668
def mmdit_state_dict_adjustments(state_dict, prefix=""):
6769
# Remove prefix
@@ -453,7 +455,8 @@ def load_mmdit(
453455
model = MMDiT(config)
454456

455457
mmdit_weights = _MMDIT[key][model_key]
456-
weights = mx.load(hf_hub_download(key, mmdit_weights))
458+
mmdit_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, mmdit_weights)
459+
weights = mx.load(mmdit_weights_ckpt)
457460
weights = mmdit_state_dict_adjustments(weights, prefix="model.diffusion_model.")
458461
weights = {k: v.astype(dtype) for k, v in weights.items()}
459462
model.update(tree_unflatten(tree_flatten(weights)))
@@ -548,7 +551,8 @@ def load_vae_decoder(
548551

549552
dtype = mx.float16 if float16 else mx.float32
550553
vae_weights = _MMDIT[key][model_key]
551-
weights = mx.load(hf_hub_download(key, vae_weights))
554+
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
555+
weights = mx.load(vae_weights_ckpt)
552556
weights = vae_decoder_state_dict_adjustments(
553557
weights, prefix="first_stage_model.decoder."
554558
)
@@ -575,7 +579,8 @@ def load_vae_encoder(
575579

576580
dtype = mx.float16 if float16 else mx.float32
577581
vae_weights = _MMDIT[key][model_key]
578-
weights = mx.load(hf_hub_download(key, vae_weights))
582+
vae_weights_ckpt = LOCAl_SD3_CKPT or hf_hub_download(key, vae_weights)
583+
weights = mx.load(vae_weights_ckpt)
579584
weights = vae_encoder_state_dict_adjustments(
580585
weights, prefix="first_stage_model.encoder."
581586
)

python/src/mlx/scripts/generate_images.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,12 @@ def cli():
9999
default=0.0,
100100
help="Denoising factor when an input image is provided. (between 0.0 and 1.0)",
101101
)
102+
parser.add_argument(
103+
"--local-ckpt",
104+
default=None,
105+
type=str,
106+
help="Path to the local mmdit checkpoint.",
107+
)
102108
args = parser.parse_args()
103109

104110
if args.benchmark_mode:
@@ -119,6 +125,7 @@ def cli():
119125
model_size=args.model_size,
120126
low_memory_mode=args.low_memory_mode,
121127
a16=args.a16,
128+
local_ckpt=args.local_ckpt,
122129
)
123130

124131
# Ensure that models are read in memory if needed

0 commit comments

Comments
 (0)