@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
58
58
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
59
59
Tuple of downsample block types.
60
60
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
61
- Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D `.
61
+ Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None `.
62
62
up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
63
63
Tuple of upsample block types.
64
64
block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ def __init__(
103
103
freq_shift : int = 0 ,
104
104
flip_sin_to_cos : bool = True ,
105
105
down_block_types : Tuple [str , ...] = ("DownBlock2D" , "AttnDownBlock2D" , "AttnDownBlock2D" , "AttnDownBlock2D" ),
106
+ mid_block_type : Optional [str ] = "UNetMidBlock2D" ,
106
107
up_block_types : Tuple [str , ...] = ("AttnUpBlock2D" , "AttnUpBlock2D" , "AttnUpBlock2D" , "UpBlock2D" ),
107
108
block_out_channels : Tuple [int , ...] = (224 , 448 , 672 , 896 ),
108
109
layers_per_block : int = 2 ,
@@ -194,19 +195,22 @@ def __init__(
194
195
self .down_blocks .append (down_block )
195
196
196
197
# mid
197
- self .mid_block = UNetMidBlock2D (
198
- in_channels = block_out_channels [- 1 ],
199
- temb_channels = time_embed_dim ,
200
- dropout = dropout ,
201
- resnet_eps = norm_eps ,
202
- resnet_act_fn = act_fn ,
203
- output_scale_factor = mid_block_scale_factor ,
204
- resnet_time_scale_shift = resnet_time_scale_shift ,
205
- attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
206
- resnet_groups = norm_num_groups ,
207
- attn_groups = attn_norm_num_groups ,
208
- add_attention = add_attention ,
209
- )
198
+ if mid_block_type is None :
199
+ self .mid_block = None
200
+ else :
201
+ self .mid_block = UNetMidBlock2D (
202
+ in_channels = block_out_channels [- 1 ],
203
+ temb_channels = time_embed_dim ,
204
+ dropout = dropout ,
205
+ resnet_eps = norm_eps ,
206
+ resnet_act_fn = act_fn ,
207
+ output_scale_factor = mid_block_scale_factor ,
208
+ resnet_time_scale_shift = resnet_time_scale_shift ,
209
+ attention_head_dim = attention_head_dim if attention_head_dim is not None else block_out_channels [- 1 ],
210
+ resnet_groups = norm_num_groups ,
211
+ attn_groups = attn_norm_num_groups ,
212
+ add_attention = add_attention ,
213
+ )
210
214
211
215
# up
212
216
reversed_block_out_channels = list (reversed (block_out_channels ))
@@ -322,7 +326,8 @@ def forward(
322
326
down_block_res_samples += res_samples
323
327
324
328
# 4. mid
325
- sample = self .mid_block (sample , emb )
329
+ if self .mid_block is not None :
330
+ sample = self .mid_block (sample , emb )
326
331
327
332
# 5. up
328
333
skip_sample = None
0 commit comments