Skip to content

Commit 416fb76

Browse files
committed
black style change
1 parent f5abc62 commit 416fb76

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

tests/torch2coreml/test_mmdit.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def setUpClass(cls):
6161
.eval()
6262
)
6363
logger.info("Initialized.")
64-
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
64+
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(
65+
TEST_SD3_HF_REPO, "sd3_medium.safetensors"
66+
)
6567
if TEST_SD3_CKPT_PATH is not None:
6668

6769
logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}")
@@ -134,7 +136,9 @@ def get_test_inputs(cfg: mmdit.MMDiTConfig) -> Dict[str, torch.Tensor]:
134136
parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int)
135137
args = parser.parse_args()
136138

137-
TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
139+
TEST_SD3_CKPT_PATH = (
140+
args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
141+
)
138142
TEST_SD3_HF_REPO = args.sd3_ckpt_path
139143
TEST_LATENT_SIZE = args.latent_size
140144
TEST_CKPT_FILE_NAME = args.ckpt_file_name

tests/torch2coreml/test_vae.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,9 @@ def setUpClass(cls):
5757
)
5858
logger.info("Initialized.")
5959

60-
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(TEST_SD3_HF_REPO, "sd3_medium.safetensors")
60+
TEST_SD3_CKPT_PATH = TEST_SD3_CKPT_PATH or hf_hub_download(
61+
TEST_SD3_HF_REPO, "sd3_medium.safetensors"
62+
)
6163
if TEST_SD3_CKPT_PATH is not None:
6264
logger.info(f"Loading SD3 model checkpoint from {TEST_SD3_CKPT_PATH}")
6365
_load_vae_decoder_weights(cls.test_torch_model, TEST_SD3_CKPT_PATH)
@@ -104,7 +106,9 @@ def get_test_inputs(config: vae.VAEDecoderConfig) -> Dict[str, torch.Tensor]:
104106
parser.add_argument("--latent-size", default=TEST_LATENT_SIZE, type=int)
105107
args = parser.parse_args()
106108

107-
TEST_SD3_CKPT_PATH = args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
109+
TEST_SD3_CKPT_PATH = (
110+
args.sd3_ckpt_path if os.path.exists(args.sd3_ckpt_path) else None
111+
)
108112
TEST_SD3_HF_REPO = args.sd3_ckpt_path
109113
TEST_LATENT_SIZE = args.latent_size
110114

0 commit comments

Comments
 (0)