
Commit b13cdbb

UNet2DModel mid_block_type (#10469)
1 parent a0acbdc commit b13cdbb

File tree

2 files changed: +49 −15 lines
  src/diffusers/models/unets/unet_2d.py
  tests/models/unets/test_models_unet_2d.py

src/diffusers/models/unets/unet_2d.py

Lines changed: 20 additions & 15 deletions
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
             Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `None`.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
             Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -103,6 +103,7 @@ def __init__(
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        mid_block_type: Optional[str] = "UNetMidBlock2D",
         up_block_types: Tuple[str, ...] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
         block_out_channels: Tuple[int, ...] = (224, 448, 672, 896),
         layers_per_block: int = 2,
@@ -194,19 +195,22 @@ def __init__(
             self.down_blocks.append(down_block)
 
         # mid
-        self.mid_block = UNetMidBlock2D(
-            in_channels=block_out_channels[-1],
-            temb_channels=time_embed_dim,
-            dropout=dropout,
-            resnet_eps=norm_eps,
-            resnet_act_fn=act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
-            resnet_groups=norm_num_groups,
-            attn_groups=attn_norm_num_groups,
-            add_attention=add_attention,
-        )
+        if mid_block_type is None:
+            self.mid_block = None
+        else:
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+                resnet_groups=norm_num_groups,
+                attn_groups=attn_norm_num_groups,
+                add_attention=add_attention,
+            )
 
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
@@ -322,7 +326,8 @@ def forward(
             down_block_res_samples += res_samples
 
         # 4. mid
-        sample = self.mid_block(sample, emb)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, emb)
 
         # 5. up
         skip_sample = None
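
For reference, a minimal usage sketch of what this change enables (not part of the commit; the model configuration below, i.e. sample_size, channel counts and block types, is illustrative, and only mid_block_type=None comes from this PR):

import torch
from diffusers import UNet2DModel

# Default behaviour is unchanged: a UNetMidBlock2D sits between the down and up paths.
model = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
)
assert model.mid_block is not None

# With mid_block_type=None the mid block is skipped entirely.
model_no_mid = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    mid_block_type=None,
)
assert model_no_mid.mid_block is None

sample = torch.randn(1, 3, 32, 32)
timestep = torch.tensor([10])
output = model_no_mid(sample, timestep).sample  # shape (1, 3, 32, 32), same as the input

With the mid block removed, forward() routes the last down-block output directly into the up path, as shown in the hunk above.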

tests/models/unets/test_models_unet_2d.py

Lines changed: 29 additions & 0 deletions
@@ -105,6 +105,35 @@ def test_mid_block_attn_groups(self):
         expected_shape = inputs_dict["sample"].shape
         self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
 
+    def test_mid_block_none(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        mid_none_init_dict, mid_none_inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        mid_none_init_dict["mid_block_type"] = None
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        mid_none_model = self.model_class(**mid_none_init_dict)
+        mid_none_model.to(torch_device)
+        mid_none_model.eval()
+
+        self.assertIsNone(mid_none_model.mid_block, "Mid block should not exist.")
+
+        with torch.no_grad():
+            output = model(**inputs_dict)
+
+            if isinstance(output, dict):
+                output = output.to_tuple()[0]
+
+        with torch.no_grad():
+            mid_none_output = mid_none_model(**mid_none_inputs_dict)
+
+            if isinstance(mid_none_output, dict):
+                mid_none_output = mid_none_output.to_tuple()[0]
+
+        self.assertFalse(torch.allclose(output, mid_none_output, rtol=1e-3), "outputs should be different.")
+
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "AttnUpBlock2D",
