Skip to content

Commit f5abc62

Browse files
committed
convert to coreml from local ckpt
1 parent 11096c7 commit f5abc62

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

README.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,17 @@ huggingface-cli login --token YOUR_HF_HUB_TOKEN
4747
**Step 2:** Prepare the denoise model (MMDiT) Core ML model files (`.mlpackage`)
4848

4949
```shell
50-
python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> --model-version {2b} -o <output-mlpackages-directory> --latent-size {64, 128}
50+
python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path <path-to-sd3-mmdit.safetensors or model-version-string-from-hub> --model-version {2b} -o <output-mlpackages-directory> --latent-size {64, 128}
5151
```
5252

5353
**Step 3:** Prepare the VAE Decoder Core ML model files (`.mlpackage`)
5454

5555
```shell
56-
python -m tests.torch2coreml.test_vae --sd3-ckpt-path <path-to-sd3-mmdit.safetensors> -o <output-mlpackages-directory> --latent-size {64, 128}
56+
python -m tests.torch2coreml.test_vae --sd3-ckpt-path <path-to-sd3-mmdit.safetensors or model-version-string-from-hub> -o <output-mlpackages-directory> --latent-size {64, 128}
5757
```
58+
59+
Note:
60+
- `--sd3-ckpt-path` can be a path to a local `.safetensors` file or a HuggingFace repo (e.g. `stabilityai/stable-diffusion-3-medium`).
5861
</details>
5962

6063
## <a name="image-generation-with-python-mlx"></a> Image Generation with Python MLX

tests/torch2coreml/test_mmdit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class TestSD3MMDiT(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
4747

4848
@classmethod
4949
def setUpClass(cls):
50+
global TEST_SD3_CKPT_PATH
5051
cls.model_name = "MultiModalDiffusionTransformer"
5152
cls.test_output_names = ["denoiser_output"]
5253
cls.test_cache_dir = TEST_CACHE_DIR
@@ -60,7 +61,7 @@ def setUpClass(cls):
6061
.eval()
6162
)
6263
logger.info("Initialized.")
63-
TEST_SD3_CKPT_PATH = hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
64+
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
6465
if TEST_SD3_CKPT_PATH is not None:
6566

6667
logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}")
@@ -133,6 +134,7 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]:
133134
parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int)
134135
args = parser.parse_args()
135136

137+
TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
136138
TEST_SD3_HF_REPO = args.sd3_ckpt_path
137139
TEST_LATENT_SIZE = args.latent_size
138140
TEST_CKPT_FILE_NAME = args.ckpt_file_name

tests/torch2coreml/test_vae.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class TestSD3VAEDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCa
4545

4646
@classmethod
4747
def setUpClass(cls):
48+
global TEST_SD3_CKPT_PATH
4849
cls.model_name = "VAEDecoder"
4950
cls.test_output_names = ["image"]
5051
cls.test_cache_dir = TEST_CACHE_DIR
@@ -56,7 +57,7 @@ def setUpClass(cls):
5657
)
5758
logger.info("Initialized.")
5859

59-
TEST_SD3_CKPT_PATH = hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
60+
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
6061
if TEST_SD3_CKPT_PATH is not None:
6162
logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}")
6263
_load_vae_decoder_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH)
@@ -103,6 +104,7 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]:
103104
parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int)
104105
args = parser.parse_args()
105106

107+
TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
106108
TEST_SD3_HF_REPO = args.sd3_ckpt_path
107109
TEST_LATENT_SIZE = args.latent_size
108110

0 commit comments

Comments
 (0)