Skip to content

Commit c096457

Browse files
authored
[Sana 4K] (#10493)
add 4K support for Sana
1 parent b13cdbb commit c096457

File tree

2 files changed

+54
-5
lines changed

2 files changed

+54
-5
lines changed

scripts/convert_sana_to_diffusers.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
CTX = init_empty_weights if is_accelerate_available else nullcontext
2626

2727
ckpt_ids = [
28+
"Efficient-Large-Model/Sana_1600M_4Kpx_BF16/checkpoints/Sana_1600M_4Kpx_BF16.pth",
2829
"Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
2930
"Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
3031
"Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
@@ -89,7 +90,10 @@ def main(args):
8990
converted_state_dict["caption_norm.weight"] = state_dict.pop("attention_y_norm.weight")
9091

9192
# scheduler
92-
flow_shift = 3.0
93+
if args.image_size == 4096:
94+
flow_shift = 6.0
95+
else:
96+
flow_shift = 3.0
9397

9498
# model config
9599
if args.model_type == "SanaMS_1600M_P1_D20":
@@ -99,7 +103,7 @@ def main(args):
99103
else:
100104
raise ValueError(f"{args.model_type} is not supported.")
101105
# Positional embedding interpolation scale.
102-
interpolation_scale = {512: None, 1024: None, 2048: 1.0}
106+
interpolation_scale = {512: None, 1024: None, 2048: 1.0, 4096: 2.0}
103107

104108
for depth in range(layer_num):
105109
# Transformer blocks.
@@ -272,9 +276,9 @@ def main(args):
272276
"--image_size",
273277
default=1024,
274278
type=int,
275-
choices=[512, 1024, 2048],
279+
choices=[512, 1024, 2048, 4096],
276280
required=False,
277-
help="Image size of pretrained model, 512, 1024 or 2048.",
281+
help="Image size of pretrained model, 512, 1024, 2048 or 4096.",
278282
)
279283
parser.add_argument(
280284
"--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 46 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,49 @@
6363
import ftfy
6464

6565

66+
ASPECT_RATIO_4096_BIN = {
67+
"0.25": [2048.0, 8192.0],
68+
"0.26": [2048.0, 7936.0],
69+
"0.27": [2048.0, 7680.0],
70+
"0.28": [2048.0, 7424.0],
71+
"0.32": [2304.0, 7168.0],
72+
"0.33": [2304.0, 6912.0],
73+
"0.35": [2304.0, 6656.0],
74+
"0.4": [2560.0, 6400.0],
75+
"0.42": [2560.0, 6144.0],
76+
"0.48": [2816.0, 5888.0],
77+
"0.5": [2816.0, 5632.0],
78+
"0.52": [2816.0, 5376.0],
79+
"0.57": [3072.0, 5376.0],
80+
"0.6": [3072.0, 5120.0],
81+
"0.68": [3328.0, 4864.0],
82+
"0.72": [3328.0, 4608.0],
83+
"0.78": [3584.0, 4608.0],
84+
"0.82": [3584.0, 4352.0],
85+
"0.88": [3840.0, 4352.0],
86+
"0.94": [3840.0, 4096.0],
87+
"1.0": [4096.0, 4096.0],
88+
"1.07": [4096.0, 3840.0],
89+
"1.13": [4352.0, 3840.0],
90+
"1.21": [4352.0, 3584.0],
91+
"1.29": [4608.0, 3584.0],
92+
"1.38": [4608.0, 3328.0],
93+
"1.46": [4864.0, 3328.0],
94+
"1.67": [5120.0, 3072.0],
95+
"1.75": [5376.0, 3072.0],
96+
"2.0": [5632.0, 2816.0],
97+
"2.09": [5888.0, 2816.0],
98+
"2.4": [6144.0, 2560.0],
99+
"2.5": [6400.0, 2560.0],
100+
"2.89": [6656.0, 2304.0],
101+
"3.0": [6912.0, 2304.0],
102+
"3.11": [7168.0, 2304.0],
103+
"3.62": [7424.0, 2048.0],
104+
"3.75": [7680.0, 2048.0],
105+
"3.88": [7936.0, 2048.0],
106+
"4.0": [8192.0, 2048.0],
107+
}
108+
66109
EXAMPLE_DOC_STRING = """
67110
Examples:
68111
```py
@@ -734,7 +777,9 @@ def __call__(
734777

735778
# 1. Check inputs. Raise error if not correct
736779
if use_resolution_binning:
737-
if self.transformer.config.sample_size == 64:
780+
if self.transformer.config.sample_size == 128:
781+
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
782+
elif self.transformer.config.sample_size == 64:
738783
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
739784
elif self.transformer.config.sample_size == 32:
740785
aspect_ratio_bin = ASPECT_RATIO_1024_BIN

0 commit comments

Comments
 (0)