
Commit ad8b50c

Merge pull request #10 from argmaxinc/coreml-converter-helpers
Add CoreML conversion helpers
2 parents 84e44fd + f748a99 commit ad8b50c

File tree

9 files changed (+156, -40 lines)

README.md

Lines changed: 2 additions & 2 deletions
@@ -48,13 +48,13 @@ huggingface-cli login --token YOUR_HF_HUB_TOKEN
**Step 3:** Prepare the denoise model (MMDiT) Core ML model files (`.mlpackage`)

```shell
-python -m tests.torch2coreml.test_mmdit --sd3-ckpt-path stabilityai/stable-diffusion-3-medium --model-version 2b -o <output-mlpackages-directory> --latent-size {64, 128}
+python -m python.src.diffusionkit.tests.torch2coreml.test_mmdit --sd3-ckpt-path stabilityai/stable-diffusion-3-medium --model-version 2b -o <output-mlpackages-directory> --latent-size {64, 128}
```

**Step 4:** Prepare the VAE Decoder Core ML model files (`.mlpackage`)

```shell
-python -m tests.torch2coreml.test_vae --sd3-ckpt-path stabilityai/stable-diffusion-3-medium -o <output-mlpackages-directory> --latent-size {64, 128}
+python -m python.src.diffusionkit.tests.torch2coreml.test_vae --sd3-ckpt-path stabilityai/stable-diffusion-3-medium -o <output-mlpackages-directory> --latent-size {64, 128}
```

Note:

python/src/diffusionkit/tests/__init__.py

Whitespace-only changes.

python/src/diffusionkit/tests/mlx/__init__.py

Whitespace-only changes.

tests/mlx/test_diffusion_pipeline.py renamed to python/src/diffusionkit/tests/mlx/test_diffusion_pipeline.py

Lines changed: 2 additions & 3 deletions
@@ -10,12 +10,11 @@
import mlx.core as mx
import numpy as np
from argmaxtools.utils import get_logger
+from diffusionkit.mlx import DiffusionPipeline
+from diffusionkit.utils import image_psnr
from huggingface_hub import hf_hub_download
from PIL import Image

-from python.src.diffusionkit.mlx import DiffusionPipeline
-from python.src.diffusionkit.utils import image_psnr
-
logger = get_logger(__name__)

W16 = False
python/src/diffusionkit/tests/torch2coreml/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .test_mmdit import convert_mmdit_to_mlpackage
+from .test_vae import convert_vae_to_mlpackage
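The new torch2coreml test package `__init__.py` re-exports the two conversion entry points added in this PR, so conversions can be triggered from Python instead of the CLI. A minimal sketch of such a call; the import path assumes the tests ship inside the installed `diffusionkit` package (per the `find_packages(where="python/src")` change in setup.py), and the repo, latent size, and output directory below are placeholder values:

```python
# Sketch only: programmatic use of the new Core ML conversion helpers.
# Import path assumes the tests are packaged under diffusionkit.tests;
# all argument values are illustrative placeholders.
from diffusionkit.tests.torch2coreml import (
    convert_mmdit_to_mlpackage,
    convert_vae_to_mlpackage,
)

# Convert the MMDiT and the VAE decoder for 512x512 generation (64x64 latents).
mmdit_path = convert_mmdit_to_mlpackage(
    model_version="stabilityai/stable-diffusion-3-medium",  # forwarded to TEST_SD3_HF_REPO
    latent_h=64,
    latent_w=64,
    output_dir="external/mlpackages",  # persisted instead of a temporary cache dir
)
vae_path = convert_vae_to_mlpackage(
    model_version="stabilityai/stable-diffusion-3-medium",
    latent_h=64,
    latent_w=64,
    output_dir="external/mlpackages",
)
print(mmdit_path, vae_path)  # paths to the generated .mlpackage bundles
```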

tests/torch2coreml/test_mmdit.py renamed to python/src/diffusionkit/tests/torch2coreml/test_mmdit.py

Lines changed: 68 additions & 16 deletions
@@ -11,11 +11,10 @@
import torch
from argmaxtools import test_utils as argmaxtools_test_utils
from argmaxtools.utils import get_fastest_device, get_logger
+from diffusionkit.torch import mmdit
+from diffusionkit.torch.model_io import _load_mmdit_weights
from huggingface_hub import hf_hub_download

-from python.src.diffusionkit.torch import mmdit
-from python.src.diffusionkit.torch.model_io import _load_mmdit_weights
-
torch.set_grad_enabled(False)
logger = get_logger(__name__)

@@ -27,24 +26,38 @@
TEST_TORCH_DTYPE = torch.float32
TEST_PSNR_THR = 35
TEST_LATENT_SIZE = 64  # 64 latent -> 512 image, 128 latent -> 1024 image
-
-# Test configuration
-argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = 3.0
-argmaxtools_test_utils.TEST_COREML_PRECISION = ct.precision.FLOAT32
-argmaxtools_test_utils.TEST_COMPUTE_UNIT = ct.ComputeUnit.CPU_AND_GPU
-argmaxtools_test_utils.TEST_COMPRESSION_MIN_SPEEDUP = 0.2
-argmaxtools_test_utils.TEST_DEFAULT_NBITS = None
-argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = True
+TEST_LATENT_HEIGHT = TEST_LATENT_SIZE
+TEST_LATENT_WIDTH = TEST_LATENT_SIZE

TEST_MODELS = {
    "2b": mmdit.SD3_2b,
    "8b": mmdit.SD3_8b,
}


+def setup_test_config(
+    min_speedup_vs_cpu=3.0,
+    compute_precision=ct.precision.FLOAT32,
+    compute_unit=ct.ComputeUnit.CPU_AND_GPU,
+    compression_min_speedup=0.2,
+    default_nbits=None,
+    skip_speed_tests=True,
+    compile_coreml=False,
+):
+    argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = min_speedup_vs_cpu
+    argmaxtools_test_utils.TEST_COREML_PRECISION = compute_precision
+    argmaxtools_test_utils.TEST_COMPUTE_UNIT = compute_unit
+    argmaxtools_test_utils.TEST_COMPRESSION_MIN_SPEEDUP = compression_min_speedup
+    argmaxtools_test_utils.TEST_DEFAULT_NBITS = default_nbits
+    argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = skip_speed_tests
+    argmaxtools_test_utils.TEST_COMPILE_COREML = compile_coreml
+
+
class TestSD3MMDiT(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
    """Unit tests for stable_duffusion_3.mmdit.MMDiT module"""

+    model_version = "2b"
+
    @classmethod
    def setUpClass(cls):
        global TEST_SD3_CKPT_PATH

@@ -55,7 +68,7 @@ def setUpClass(cls):
        # Base test model
        logger.info("Initializing SD3 model")
        cls.test_torch_model = (
-            mmdit.MMDiT(TEST_MODELS[args.model_version])
+            mmdit.MMDiT(TEST_MODELS[cls.model_version])
            .to(TEST_DEV)
            .to(TEST_TORCH_DTYPE)
            .eval()

@@ -75,7 +88,7 @@ def setUpClass(cls):

        # Sample inputs
        # TODO(atiorh): CLI configurable model version
-        cls.test_torch_inputs = get_test_inputs(TEST_MODELS[args.model_version])
+        cls.test_torch_inputs = get_test_inputs(TEST_MODELS[cls.model_version])

        super().setUpClass()

@@ -89,13 +102,14 @@ def tearDownClass(cls):
def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]:
    """Generate random inputs for the SD3 MMDiT model"""
    batch_size = 2  # classifier-free guidance
-    assert TEST_LATENT_SIZE < cfg.max_latent_resolution
+    assert TEST_LATENT_HEIGHT <= cfg.max_latent_resolution
+    assert TEST_LATENT_WIDTH <= cfg.max_latent_resolution

    latent_image_embeddings_dims = (
        batch_size,
        cfg.vae_latent_dim,
-        TEST_LATENT_SIZE,
-        TEST_LATENT_SIZE,
+        TEST_LATENT_HEIGHT,
+        TEST_LATENT_WIDTH,
    )
    pooled_text_embeddings_dims = (batch_size, cfg.pooled_text_embed_dim, 1, 1)
    token_level_text_embeddings_dims = (

@@ -118,6 +132,42 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]:
    }


+def convert_mmdit_to_mlpackage(
+    model_version: str,
+    latent_h: int,
+    latent_w: int,
+    output_dir: str = None,
+    **test_config_kwargs,
+) -> str:
+    """Converts a MMDiT model to a CoreML package.
+
+    Returns:
+        `str`: path to the converted model.
+    """
+    global TEST_SD3_CKPT_PATH, TEST_SD3_HF_REPO, TEST_LATENT_WIDTH, TEST_LATENT_HEIGHT, TEST_CACHE_DIR
+
+    # Convert to CoreML
+    TEST_SD3_HF_REPO = model_version
+    TEST_LATENT_HEIGHT = latent_h or TEST_LATENT_SIZE
+    TEST_LATENT_WIDTH = latent_w or TEST_LATENT_SIZE
+
+    setup_test_config(compile_coreml=False, **test_config_kwargs)
+
+    with argmaxtools_test_utils._get_test_cache_dir(
+        persistent_cache_dir=output_dir
+    ) as TEST_CACHE_DIR:
+        suite = unittest.TestSuite()
+        suite.addTest(TestSD3MMDiT("test_torch2coreml_correctness_and_speedup"))
+
+        if os.getenv("DEBUG", False):
+            suite.debug()
+        else:
+            runner = unittest.TextTestRunner()
+            runner.run(suite)
+
+    return os.path.join(TEST_CACHE_DIR, f"{TestSD3MMDiT.model_name}.mlpackage")
+
+
if __name__ == "__main__":
    import argparse

@@ -142,6 +192,8 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]:
    TEST_LATENT_SIZE = args.latent_size
    TEST_CKPT_FILE_NAME = args.ckpt_file_name

+    setup_test_config()
+
    with argmaxtools_test_utils._get_test_cache_dir(args.o) as TEST_CACHE_DIR:
        suite = unittest.TestSuite()
        suite.addTest(TestSD3MMDiT("test_torch2coreml_correctness_and_speedup"))
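Because `convert_mmdit_to_mlpackage` forwards `**test_config_kwargs` to the new `setup_test_config`, callers can override the conversion settings that were previously hard-coded module globals. A hedged sketch: the keyword names come from the `setup_test_config` signature above, the argument values are illustrative, and whether a given compute unit or precision suits this model is not claimed here.

```python
# Sketch: overriding the Core ML conversion settings exposed by setup_test_config().
# Keyword names match the function signature in the diff; values are illustrative.
import coremltools as ct

from diffusionkit.tests.torch2coreml.test_mmdit import convert_mmdit_to_mlpackage

mlpackage_path = convert_mmdit_to_mlpackage(
    model_version="stabilityai/stable-diffusion-3-medium",
    latent_h=128,  # 128x128 latents -> 1024x1024 images
    latent_w=128,
    output_dir="external/mlpackages",
    compute_precision=ct.precision.FLOAT16,  # MMDiT default in this file is FLOAT32
    compute_unit=ct.ComputeUnit.CPU_AND_NE,  # default is CPU_AND_GPU
    skip_speed_tests=True,
)
```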

tests/torch2coreml/test_vae.py renamed to python/src/diffusionkit/tests/torch2coreml/test_vae.py

Lines changed: 62 additions & 14 deletions
@@ -11,11 +11,10 @@
import torch
from argmaxtools import test_utils as argmaxtools_test_utils
from argmaxtools.utils import get_fastest_device, get_logger
+from diffusionkit.torch import vae
+from diffusionkit.torch.model_io import _load_vae_decoder_weights
from huggingface_hub import hf_hub_download

-from python.src.diffusionkit.torch import vae
-from python.src.diffusionkit.torch.model_io import _load_vae_decoder_weights
-
torch.set_grad_enabled(False)
logger = get_logger(__name__)

@@ -26,20 +25,31 @@
TEST_TORCH_DTYPE = torch.float32
TEST_PSNR_THR = 35
TEST_LATENT_SIZE = 64  # 64 latent -> 512 image, 128 latent -> 1024 image
-
-# Test configuration
-# argmaxtools_test_utils.TEST_DEFAULT_NBITS = 8
-argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = 3.0
-argmaxtools_test_utils.TEST_COREML_PRECISION = ct.precision.FLOAT16
-argmaxtools_test_utils.TEST_COMPRESSION_MIN_SPEEDUP = 0.5
-argmaxtools_test_utils.TEST_COMPUTE_UNIT = ct.ComputeUnit.CPU_AND_GPU
-argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = True
-
+TEST_LATENT_HEIGHT = TEST_LATENT_SIZE
+TEST_LATENT_WIDTH = TEST_LATENT_SIZE

SD3_8b = vae.VAEDecoderConfig(resolution=1024)
SD3_2b = vae.VAEDecoderConfig(resolution=512)


+def setup_test_config(
+    min_speedup_vs_cpu=3.0,
+    compute_precision=ct.precision.FLOAT16,
+    compute_unit=ct.ComputeUnit.CPU_AND_GPU,
+    compression_min_speedup=0.5,
+    default_nbits=None,
+    skip_speed_tests=True,
+    compile_coreml=False,
+):
+    argmaxtools_test_utils.TEST_MIN_SPEEDUP_VS_CPU = min_speedup_vs_cpu
+    argmaxtools_test_utils.TEST_COREML_PRECISION = compute_precision
+    argmaxtools_test_utils.TEST_COMPUTE_UNIT = compute_unit
+    argmaxtools_test_utils.TEST_COMPRESSION_MIN_SPEEDUP = compression_min_speedup
+    argmaxtools_test_utils.TEST_DEFAULT_NBITS = default_nbits
+    argmaxtools_test_utils.TEST_SKIP_SPEED_TESTS = skip_speed_tests
+    argmaxtools_test_utils.TEST_COMPILE_COREML = compile_coreml
+
+
class TestSD3VAEDecoder(argmaxtools_test_utils.CoreMLTestsMixin, unittest.TestCase):
    """Unit tests for stable_duffusion_3.vae.VAEDecoder module"""

@@ -90,13 +100,49 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]:
    if TEST_LATENT_SIZE != config_expected_latent_resolution:
        logger.warning(
            f"TEST_LATENT_SIZE ({TEST_LATENT_SIZE}) does not match the implied "
-            "latent resolution from the model config "
+            f"latent resolution ({config_expected_latent_resolution}) from the model config "
        )

-    z_dims = (1, config.in_channels, TEST_LATENT_SIZE, TEST_LATENT_SIZE)
+    z_dims = (1, config.in_channels, TEST_LATENT_HEIGHT, TEST_LATENT_WIDTH)
    return {"z": torch.randn(*z_dims).to(TEST_DEV).to(TEST_TORCH_DTYPE)}


+def convert_vae_to_mlpackage(
+    model_version: str,
+    latent_h: int,
+    latent_w: int,
+    output_dir: str = None,
+    **test_config_kwargs,
+) -> str:
+    """Converts a VAE decoder model to a CoreML package.
+
+    Returns:
+        `str`: path to the converted model.
+    """
+    global TEST_SD3_CKPT_PATH, TEST_SD3_HF_REPO, TEST_LATENT_WIDTH, TEST_LATENT_HEIGHT, TEST_CACHE_DIR
+
+    # Convert to CoreML
+    TEST_SD3_HF_REPO = model_version
+    TEST_LATENT_HEIGHT = latent_h or TEST_LATENT_SIZE
+    TEST_LATENT_WIDTH = latent_w or TEST_LATENT_SIZE
+
+    setup_test_config(compile_coreml=False, **test_config_kwargs)
+
+    with argmaxtools_test_utils._get_test_cache_dir(
+        persistent_cache_dir=output_dir
+    ) as TEST_CACHE_DIR:
+        suite = unittest.TestSuite()
+        suite.addTest(TestSD3VAEDecoder("test_torch2coreml_correctness_and_speedup"))
+
+        if os.getenv("DEBUG", False):
+            suite.debug()
+        else:
+            runner = unittest.TextTestRunner()
+            runner.run(suite)
+
+    return os.path.join(TEST_CACHE_DIR, f"{TestSD3VAEDecoder.model_name}.mlpackage")
+
+
if __name__ == "__main__":
    import argparse

@@ -112,6 +158,8 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]:
    TEST_SD3_HF_REPO = args.sd3_ckpt_path
    TEST_LATENT_SIZE = args.latent_size

+    setup_test_config()
+
    with argmaxtools_test_utils._get_test_cache_dir(args.o) as TEST_CACHE_DIR:
        suite = unittest.TestSuite()
        suite.addTest(TestSD3VAEDecoder("test_torch2coreml_correctness_and_speedup"))
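The switch from a single `TEST_LATENT_SIZE` to separate `TEST_LATENT_HEIGHT`/`TEST_LATENT_WIDTH` globals means `get_test_inputs` can trace a non-square latent grid. A sketch under that assumption, with placeholder values; whether the resulting Core ML package is useful end-to-end for non-square generation is not claimed here. The comment uses the 8x latent-to-pixel ratio implied by the `64 latent -> 512 image` note above.

```python
# Sketch: converting the VAE decoder with a non-square latent grid,
# which the new TEST_LATENT_HEIGHT / TEST_LATENT_WIDTH globals allow.
# Values are illustrative; 64x128 latents correspond to roughly 512x1024 pixels.
from diffusionkit.tests.torch2coreml.test_vae import convert_vae_to_mlpackage

vae_path = convert_vae_to_mlpackage(
    model_version="stabilityai/stable-diffusion-3-medium",
    latent_h=64,
    latent_w=128,
    output_dir="external/mlpackages",
)
print(vae_path)
```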

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-argmaxtools
+argmaxtools>=0.1.13
torch
safetensors
mlx

setup.py

Lines changed: 19 additions & 4 deletions
@@ -1,6 +1,18 @@
+import os
+
from setuptools import find_packages, setup
+from setuptools.command.install import install
+
+VERSION = "0.2.16"
+
+
+class VersionInstallCommand(install):
+    def run(self):
+        install.run(self)
+        version_file = os.path.join(self.install_lib, "diffusionkit", "version.py")
+        with open(version_file, "w") as f:
+            f.write(f"__version__ = '{VERSION}'\n")

-VERSION = "0.2.0"

with open("README.md") as f:
    readme = f.read()

@@ -14,7 +26,7 @@
    long_description_content_type="text/markdown",
    author="Argmax, Inc.",
    install_requires=[
-        "argmaxtools",
+        "argmaxtools>=0.1.13",
        "torch",
        "safetensors",
        "mlx",

@@ -23,13 +35,16 @@
        "pillow",
        "sentencepiece",
    ],
-    packages=["diffusionkit"],
-    package_dir={"": "python/src", "diffusionkit": "python/src/diffusionkit"},
+    packages=find_packages(where="python/src"),
+    package_dir={"": "python/src"},
    entry_points={
        "console_scripts": [
            "diffusionkit-cli=diffusionkit.mlx.scripts.generate_images:cli",
        ],
    },
+    cmdclass={
+        "install": VersionInstallCommand,
+    },
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
