
Commit c83abf5

Add modeling tests for the Dream transformer
1 parent ff03fc2 commit c83abf5

File tree

1 file changed: +154 -0 lines changed

@@ -0,0 +1,154 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import DreamTransformer1DModel
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    require_torch_accelerator_with_training,
    torch_device,
)

from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


enable_full_determinism()


class DreamTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = DreamTransformer1DModel
    main_input_name = "text_ids"

    # Skip tests that set the default AttnProcessor, since this model uses a custom attention processor.
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        return self.prepare_dummy_input()

    @property
    def input_shape(self):
        return (48,)  # (sequence_length,)

    @property
    def output_shape(self):
        return (48, 100)  # (sequence_length, vocab_size)

    def prepare_dummy_input(self, batch_size: int = 1, sequence_length: int = 48):
        vocab_size = 100

        text_ids = torch.randint(vocab_size, size=(batch_size, sequence_length), device=torch_device)
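        # Random token ids in [0, vocab_size) serve as the dummy text input.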
        # NOTE: dummy timestep input for now (not used)
        # timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)

        inputs_dict = {"text_ids": text_ids}
        return inputs_dict

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "num_layers": 1,
            "attention_head_dim": 16,
            "num_attention_heads": 4,
            "num_attention_kv_heads": 2,
            "ff_intermediate_dim": 256,  # 4 * (attention_head_dim * num_attention_heads)
            "vocab_size": 100,
            "pad_token_id": 90,
        }

        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    # NOTE: override ModelTesterMixin.test_output to supply a custom expected_output_shape, since the expected
    # output shape of the Dream transformer is not the same as the input shape.
    def test_output(self, expected_output_shape=None):
        if expected_output_shape is None:
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
            vocab_size = init_dict["vocab_size"]
            batch_size, seq_len = inputs_dict["text_ids"].shape
            expected_output_shape = (batch_size, seq_len, vocab_size)
        super().test_output(expected_output_shape=expected_output_shape)

    def test_output_hidden_states_supplied(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Prepare the hidden_states argument manually and remove the text_ids argument.
        hidden_dim = init_dict["attention_head_dim"] * init_dict["num_attention_heads"]
        vocab_size = init_dict["vocab_size"]
        batch_size, seq_len = inputs_dict["text_ids"].shape
        hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=torch_device)
        inputs_dict["hidden_states"] = hidden_states
        del inputs_dict["text_ids"]

        expected_output_shape = (batch_size, seq_len, vocab_size)
        super().test_output(expected_output_shape=expected_output_shape)

    def test_output_positions_ids_supplied(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        # Prepare the position_ids argument manually.
        vocab_size = init_dict["vocab_size"]
        position_ids = torch.arange(inputs_dict["text_ids"].shape[1], device=torch_device)
        position_ids = position_ids.unsqueeze(0).expand(inputs_dict["text_ids"].shape[0], -1)
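        # position_ids has shape (batch_size, sequence_length) and holds 0, 1, ..., sequence_length - 1 per sequence.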
        inputs_dict["position_ids"] = position_ids

        expected_output_shape = (inputs_dict["text_ids"].shape[0], inputs_dict["text_ids"].shape[1], vocab_size)
        super().test_output(expected_output_shape=expected_output_shape)

    @require_torch_accelerator_with_training
    def test_training_attention_mask_supplied(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        vocab_size = init_dict["vocab_size"]
        batch_size, seq_len = inputs_dict["text_ids"].shape

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.train()
        dtype = model.dtype

        # Prepare causal attention mask for training, specifically a transformers-style 4D additive causal mask with
        # the upper triangular entries filled with -inf
        attention_mask = None
        causal_mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, device=torch_device)
        positions = torch.arange(causal_mask.size(-1), device=torch_device)
        causal_mask.masked_fill_(positions < (positions + 1).view(causal_mask.size(-1), 1), 0)
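        # masked_fill_ zeroes the entries where the key index <= the query index (lower triangle plus diagonal),
        # so each token attends only to itself and earlier positions; all other entries keep the -inf fill value.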
        attention_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len)
        inputs_dict["attention_mask"] = attention_mask

        logits = model(**inputs_dict)

        if isinstance(logits, dict):
            logits = logits.to_tuple()[0]

        input_tensor = inputs_dict[self.main_input_name]
        target = torch.randint(vocab_size, (input_tensor.shape[0], input_tensor.shape[1]), device=torch_device)
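        # Random targets suffice here; the test only exercises the forward pass, loss computation, and backward pass.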
        loss = torch.nn.functional.cross_entropy(logits.view(-1, vocab_size), target.view(-1))
        loss.backward()


class DreamTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = DreamTransformer1DModel
    # NOTE: set to None to skip TorchCompileTesterMixin.test_compile_on_different_shapes because that test
    # currently assumes the input is image-like (specifically, that prepare_dummy_inputs accepts `height` and
    # `width` arguments). We could consider overriding this test to make it specific to the Dream transformer.
    different_shapes_for_compilation = None

    def prepare_dummy_input(self, batch_size: int = 1, sequence_length: int = 48):
        return DreamTransformerTests().prepare_dummy_input(batch_size=batch_size, sequence_length=sequence_length)

    def prepare_init_args_and_inputs_for_common(self):
        return DreamTransformerTests().prepare_init_args_and_inputs_for_common()
