Utilities and documentation for porting PyTorch models to Swift/MLX while keeping feature parity with custom layers and Metal pipelines.
- 🇬🇧 English guide: `docs/pytorch_to_mlx_migration_guide_en.md`
- 🇯🇵 Japanese guide: `docs/pytorch_to_mlx_migration_guide_ja.md`
Both guides walk through environment setup, module porting with `@ModuleInfo`, verification via safetensors, and performance tuning on Apple Silicon.
`tools/pytorch_mlx/` contains reusable scripts:
- `convert_checkpoint_to_mlx.py`: convert PyTorch checkpoints into MLX-friendly `weights.npz` + `config.json`.
- `export_state_metadata.py`: dump `state_dict` metadata (`pytorch_analysis.json`).
- `weight_mapping_analyzer.py`: compare PyTorch/MLX parameter shapes and generate transpose plans.
- `structure_comparator.py`: report missing/extra keys between PyTorch and MLX graphs.
- `comparison/safetensor_visualizer.py`: visualise safetensor differences with statistics and plots.
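The idea behind `structure_comparator.py` and `weight_mapping_analyzer.py` can be sketched in plain Python. This is an illustration, not the scripts' actual code: it assumes both models have already been reduced to hypothetical `{parameter_name: shape}` dictionaries, reports missing and extra keys, and, for shape mismatches, searches for an axis permutation that maps the PyTorch shape onto the MLX one. For example, PyTorch stores `Conv2d` weights as `(out, in, kH, kW)` while MLX uses `(out, kH, kW, in)`, so the expected plan is `(0, 2, 3, 1)`.

```python
from itertools import permutations
from typing import Dict, Optional, Tuple

Shape = Tuple[int, ...]


def find_transpose(src: Shape, dst: Shape) -> Optional[Tuple[int, ...]]:
    """Return the first axis permutation p with tuple(src[i] for i in p) == dst, or None.

    When several axes share a size, more than one permutation matches; this
    sketch returns the first one in lexicographic order.
    """
    if len(src) != len(dst) or sorted(src) != sorted(dst):
        return None  # different rank or different sizes: no pure transpose exists
    for perm in permutations(range(len(src))):
        if tuple(src[i] for i in perm) == tuple(dst):
            return perm
    return None


def compare_structures(pt: Dict[str, Shape], mlx: Dict[str, Shape]) -> dict:
    """Compare two hypothetical {name: shape} maps and build a transpose plan."""
    report = {
        "missing_in_mlx": sorted(set(pt) - set(mlx)),
        "extra_in_mlx": sorted(set(mlx) - set(pt)),
        "transpose_plan": {},
        "unresolved": [],
    }
    for name in sorted(set(pt) & set(mlx)):
        src, dst = tuple(pt[name]), tuple(mlx[name])
        if src == dst:
            continue  # shapes already agree; nothing to do
        perm = find_transpose(src, dst)
        if perm is None:
            report["unresolved"].append(name)  # needs a manual reshape or a key rename
        else:
            report["transpose_plan"][name] = perm
    return report


if __name__ == "__main__":
    pt_shapes = {"conv.weight": (64, 3, 7, 7), "conv.bias": (64,), "fc.weight": (10, 64)}
    mlx_shapes = {"conv.weight": (64, 7, 7, 3), "conv.bias": (64,), "head.weight": (10, 64)}
    print(compare_structures(pt_shapes, mlx_shapes))
```

For the sample shapes, the report lists `fc.weight` as missing, `head.weight` as extra, and the plan `(0, 2, 3, 1)` for `conv.weight`. Note that `(0, 3, 2, 1)` would also match here because `kH == kW`, so shape matching alone cannot always pick the right permutation; knowledge of the layer type is still needed, and the brute-force search is only practical for low-rank tensors.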
Sample analysis JSON files reside in `tools/pytorch_mlx/examples/`.
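As a rough illustration of the kind of metadata `export_state_metadata.py` collects, the sketch below walks a `state_dict`-like mapping and emits one JSON entry per parameter. The layout (`{"parameters": {name: {"shape", "dtype", "numel"}}, "total_params": ...}`) is a hypothetical schema chosen for this example; the real `pytorch_analysis.json` files in `tools/pytorch_mlx/examples/` may be structured differently. Any object with `.shape` and `.dtype` attributes works, and the `FakeTensor` class is a hypothetical stand-in that keeps the sketch free of a PyTorch dependency.

```python
import json
from dataclasses import dataclass
from typing import Any, Mapping, Tuple


def state_metadata(state: Mapping[str, Any]) -> dict:
    """Collect shape/dtype/element-count metadata from a state_dict-like mapping."""
    params = {}
    for name, tensor in state.items():
        shape = [int(d) for d in tensor.shape]  # works for torch.Size and plain tuples
        numel = 1
        for d in shape:
            numel *= d
        params[name] = {"shape": shape, "dtype": str(tensor.dtype), "numel": numel}
    return {"parameters": params, "total_params": sum(p["numel"] for p in params.values())}


@dataclass
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor, used only in this sketch."""
    shape: Tuple[int, ...]
    dtype: str = "torch.float32"


if __name__ == "__main__":
    state = {"conv.weight": FakeTensor((64, 3, 7, 7)), "conv.bias": FakeTensor((64,))}
    print(json.dumps(state_metadata(state), indent=2))
```

With PyTorch installed, a loaded `state_dict` (for example the result of `torch.load(path, map_location="cpu")`, or its `"state_dict"` entry for PyTorch Lightning checkpoints) can be passed in place of the fake mapping.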
- Export PyTorch metadata: `python3 tools/pytorch_mlx/export_state_metadata.py checkpoint.ckpt`
- Convert checkpoints (optional): `python3 tools/pytorch_mlx/convert_checkpoint_to_mlx.py checkpoint.ckpt -o ./mlx_model`
- Analyse shape differences: `python3 tools/pytorch_mlx/weight_mapping_analyzer.py --pytorch pytorch_analysis.json --mlx mlx_analysis.json`
- Keep the PyTorch JSON alongside the Swift app bundle to perform on-device structure comparisons.
- Compare intermediate results with safetensors using `comparison/safetensor_visualizer.py`.
- The workflow is applied in TripoSRMlx, the Swift/MLX port of TripoSR.
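When comparing intermediate results between the PyTorch and MLX runs, a few summary statistics usually say more than a single pass/fail flag. The sketch below computes them for two flattened lists of floats using only the standard library. It is in the spirit of the statistics `safetensor_visualizer.py` reports but is not that script's implementation; loading the tensors (for example with the `safetensors` package) is left out, and the tolerance defaults `rtol=1e-5`, `atol=1e-6` are assumptions to tune per model and dtype.

```python
import math
from typing import Sequence


def compare_outputs(ref: Sequence[float], test: Sequence[float],
                    rtol: float = 1e-5, atol: float = 1e-6) -> dict:
    """Summarise how closely `test` (e.g. MLX output) matches `ref` (e.g. PyTorch output)."""
    if len(ref) != len(test):
        raise ValueError(f"length mismatch: {len(ref)} vs {len(test)}")
    if not ref:
        raise ValueError("empty inputs")
    abs_diffs = [abs(r - t) for r, t in zip(ref, test)]
    # Element-wise criterion in the style of numpy.allclose(test, ref):
    # |test - ref| <= atol + rtol * |ref|. Writing it as `not d <= tol`
    # makes NaN differences count as failures.
    failures = sum(1 for r, d in zip(ref, abs_diffs) if not d <= atol + rtol * abs(r))
    valid = [d for d in abs_diffs if not math.isnan(d)]
    dot = sum(r * t for r, t in zip(ref, test))
    norm_ref = math.sqrt(sum(r * r for r in ref))
    norm_test = math.sqrt(sum(t * t for t in test))
    cosine = dot / (norm_ref * norm_test) if norm_ref and norm_test else math.nan
    return {
        "max_abs_diff": max(valid, default=math.nan),
        "mean_abs_diff": sum(valid) / len(valid) if valid else math.nan,
        "nan_count": len(abs_diffs) - len(valid),
        "cosine_similarity": cosine,
        "num_out_of_tolerance": failures,
        "all_close": failures == 0,
    }


if __name__ == "__main__":
    print(compare_outputs([0.1, 0.2, 0.3], [0.1, 0.2000001, 0.31]))
```

Reporting the maximum and mean difference together helps separate a single outlier (often a layout or transpose bug in one channel) from a uniform drift (often a dtype or epsilon mismatch).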
With the guides and scripts in this repository, you can port PyTorch models to Swift/MLX efficiently while maintaining quality through numerical comparisons at every step.