# Copyright (C) 2024 Argmax, Inc. All Rights Reserved.
#

- import gc
from functools import partial

import mlx.core as mx
import mlx.nn as nn
- import mlx.utils as utils
import numpy as np
from argmaxtools.utils import get_logger
from beartype.typing import Dict, List, Optional, Tuple
@@ -77,35 +75,31 @@ def cache_modulation_params(
        by offloading all adaLN_modulation parameters
        """
        y_embed = self.y_embedder(pooled_text_embeddings)
+         batch_size = pooled_text_embeddings.shape[0]

        offload_size = 0
        to_offload = []

        for timestep in timesteps:
            final_timestep = timestep.item() == timesteps[-1].item()
-             modulation_inputs = y_embed + self.t_embedder(timestep[None] * 1000.0)
+             timestep_key = timestep.item()
+             modulation_inputs = y_embed[:, None, None, :] + self.t_embedder(
+                 mx.repeat(timestep[None], batch_size, axis=0)
+             )

            for block in self.multimodal_transformer_blocks:
                if not hasattr(block.image_transformer_block, "_modulation_params"):
                    block.image_transformer_block._modulation_params = dict()
                    block.text_transformer_block._modulation_params = dict()

                block.image_transformer_block._modulation_params[
-                     (timestep * 1000).item()
+                     timestep_key
                ] = block.image_transformer_block.adaLN_modulation(modulation_inputs)
                block.text_transformer_block._modulation_params[
-                     (timestep * 1000).item()
+                     timestep_key
                ] = block.text_transformer_block.adaLN_modulation(modulation_inputs)
-                 mx.eval(
-                     block.image_transformer_block._modulation_params[
-                         (timestep * 1000).item()
-                     ]
-                 )
-                 mx.eval(
-                     block.text_transformer_block._modulation_params[
-                         (timestep * 1000).item()
-                     ]
-                 )
+                 mx.eval(block.image_transformer_block._modulation_params[timestep_key])
+                 mx.eval(block.text_transformer_block._modulation_params[timestep_key])

                if final_timestep:
                    offload_size += (
@@ -131,33 +125,34 @@ def cache_modulation_params(
                        ]
                    )

-             for block in self.unified_transformer_blocks:
-                 if not hasattr(block.transformer_block, "_modulation_params"):
-                     block.transformer_block._modulation_params = dict()
-                 block.transformer_block._modulation_params[
-                     (timestep * 1000).item()
-                 ] = block.transformer_block.adaLN_modulation(modulation_inputs)
-                 mx.eval(
-                     block.transformer_block._modulation_params[(timestep * 1000).item()]
-                 )
-
-                 if final_timestep:
-                     offload_size += (
-                         block.transformer_block.adaLN_modulation.layers[1].weight.size
-                         * block.transformer_block.adaLN_modulation.layers[
-                             1
-                         ].weight.dtype.size
-                     )
-                     to_offload.extend(
-                         [block.transformer_block.adaLN_modulation.layers[1]]
-                     )
+             if self.config.depth_unified > 0:
+                 for block in self.unified_transformer_blocks:
+                     if not hasattr(block.transformer_block, "_modulation_params"):
+                         block.transformer_block._modulation_params = dict()
+                     block.transformer_block._modulation_params[
+                         timestep_key
+                     ] = block.transformer_block.adaLN_modulation(modulation_inputs)
+                     mx.eval(block.transformer_block._modulation_params[timestep_key])
+
+                     if final_timestep:
+                         offload_size += (
+                             block.transformer_block.adaLN_modulation.layers[
+                                 1
+                             ].weight.size
+                             * block.transformer_block.adaLN_modulation.layers[
+                                 1
+                             ].weight.dtype.size
+                         )
+                         to_offload.extend(
+                             [block.transformer_block.adaLN_modulation.layers[1]]
+                         )

            if not hasattr(self.final_layer, "_modulation_params"):
                self.final_layer._modulation_params = dict()
            self.final_layer._modulation_params[
-                 (timestep * 1000).item()
+                 timestep_key
            ] = self.final_layer.adaLN_modulation(modulation_inputs)
-             mx.eval(self.final_layer._modulation_params[(timestep * 1000).item()])
+             mx.eval(self.final_layer._modulation_params[timestep_key])

            if final_timestep:
                offload_size += (
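Note: the hunks above switch the modulation cache keys from `(timestep * 1000).item()` to a precomputed `timestep_key` and batch the timestep embedding with `mx.repeat`. A minimal sketch of the same caching pattern, assuming stand-in names (`blocks`, `t_embedder`, `y_embed`, `adaLN_modulation` mirror the attributes used above but are not the exact DiffusionKit signatures):

import mlx.core as mx

def cache_adaln_params(blocks, timesteps, y_embed, t_embedder):
    # Precompute adaLN modulation outputs once per denoising timestep so the
    # sampling loop only performs dictionary lookups.
    cache = {}
    batch_size = y_embed.shape[0]
    for t in timesteps:
        key = t.item()  # scalar float key, analogous to timestep_key above
        # Broadcast the pooled text embedding against a batched timestep embedding.
        mod_in = y_embed[:, None, None, :] + t_embedder(
            mx.repeat(t[None], batch_size, axis=0)
        )
        cache[key] = [blk.adaLN_modulation(mod_in) for blk in blocks]
        mx.eval(cache[key])  # force evaluation now; later steps only index the dict
    return cache

Sampling then only indexes the returned dict, which is what the lookups in `pre_sdpa` and the final layer below rely on.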
@@ -246,6 +241,7 @@ def __call__(
            latent_image_embeddings,
            timestep,
        )
+
        if self.config.patchify_via_reshape:
            latent_image_embeddings = self.x_embedder.unpack(
                latent_image_embeddings, (latent_height, latent_width)
@@ -437,7 +433,10 @@ def pre_sdpa(
        tensor: mx.array,
        timestep: mx.array,
    ) -> Dict[str, mx.array]:
+         if timestep.size > 1:
+             timestep = timestep[0]
        modulation_params = self._modulation_params[timestep.item()]
+
        modulation_params = mx.split(
            modulation_params, self.num_modulation_params, axis=-1
        )
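For reference, the cached tensor concatenates all adaLN parameters along the last axis, and `mx.split` with an integer count slices it back into that many equal chunks. A quick standalone check, with hypothetical sizes rather than the model's real dimensions:

import mlx.core as mx

# Hypothetical sizes: 6 modulation chunks, hidden size 4, batch 2.
params = mx.zeros((2, 6 * 4))
chunks = mx.split(params, 6, axis=-1)  # six (2, 4) slices, one per parameter
print(len(chunks), chunks[0].shape)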
@@ -771,6 +770,8 @@ def __call__(
        latent_image_embeddings: mx.array,
        timestep: mx.array,
    ) -> mx.array:
+         if timestep.size > 1:
+             timestep = timestep[0]
        modulation_params = self._modulation_params[timestep.item()]

        shift, residual_scale = mx.split(modulation_params, 2, axis=-1)
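The guard added here mirrors the one in `pre_sdpa`: when a batched `timestep` arrives (presumably the same scalar repeated per sample, as built with `mx.repeat` above), element 0 recovers the scalar cache key. A minimal lookup sketch, with a hypothetical `params_by_timestep` dict standing in for `_modulation_params`:

import mlx.core as mx

def get_modulation(params_by_timestep: dict, timestep: mx.array) -> mx.array:
    # A batched timestep repeats one scalar, so element 0 is a valid cache key.
    if timestep.size > 1:
        timestep = timestep[0]
    return params_by_timestep[timestep.item()]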