|
| 1 | +# Copyright 2025 HuggingFace Inc. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import unittest |
| 16 | + |
| 17 | +import torch |
| 18 | + |
| 19 | +from diffusers import DreamTransformer1DModel |
| 20 | +from diffusers.utils.testing_utils import ( |
| 21 | + enable_full_determinism, |
| 22 | + require_torch_accelerator_with_training, |
| 23 | + torch_device, |
| 24 | +) |
| 25 | + |
| 26 | +from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin |
| 27 | + |
| 28 | + |
| 29 | +enable_full_determinism() |
| 30 | + |
| 31 | + |
| 32 | +class DreamTransformerTests(ModelTesterMixin, unittest.TestCase): |
| 33 | + model_class = DreamTransformer1DModel |
| 34 | + main_input_name = "text_ids" |
| 35 | + |
| 36 | + # Skip setting testing with default: AttnProcessor |
| 37 | + uses_custom_attn_processor = True |
| 38 | + |
| 39 | + @property |
| 40 | + def dummy_input(self): |
| 41 | + return self.prepare_dummy_input() |
| 42 | + |
| 43 | + @property |
| 44 | + def input_shape(self): |
| 45 | + return (48,) # (sequence_length,) |
| 46 | + |
| 47 | + @property |
| 48 | + def output_shape(self): |
| 49 | + return (48, 100) # (sequence_length, vocab_size) |
| 50 | + |
| 51 | + def prepare_dummy_input(self, batch_size: int = 1, sequence_length: int = 48): |
| 52 | + vocab_size = 100 |
| 53 | + |
| 54 | + text_ids = torch.randint(vocab_size, size=(batch_size, sequence_length), device=torch_device) |
| 55 | + # NOTE: dummy timestep input for now (not used) |
| 56 | + # timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size) |
| 57 | + |
| 58 | + inputs_dict = {"text_ids": text_ids} |
| 59 | + return inputs_dict |
| 60 | + |
| 61 | + def prepare_init_args_and_inputs_for_common(self): |
| 62 | + init_dict = { |
| 63 | + "num_layers": 1, |
| 64 | + "attention_head_dim": 16, |
| 65 | + "num_attention_heads": 4, |
| 66 | + "num_attention_kv_heads": 2, |
| 67 | + "ff_intermediate_dim": 256, # 4 * (attention_head_dim * num_attention_heads) |
| 68 | + "vocab_size": 100, |
| 69 | + "pad_token_id": 90, |
| 70 | + } |
| 71 | + |
| 72 | + inputs_dict = self.dummy_input |
| 73 | + return init_dict, inputs_dict |
| 74 | + |
| 75 | + # NOTE: override ModelTesterMixin.test_output to supply a custom expected_output_shape as the expected output |
| 76 | + # shape of the Dream transformer is not the same as the input shape |
| 77 | + def test_output(self, expected_output_shape=None): |
| 78 | + if expected_output_shape is None: |
| 79 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 80 | + vocab_size = init_dict["vocab_size"] |
| 81 | + batch_size, seq_len = inputs_dict["text_ids"].shape |
| 82 | + expected_output_shape = (batch_size, seq_len, vocab_size) |
| 83 | + super().test_output(expected_output_shape=expected_output_shape) |
| 84 | + |
| 85 | + def test_output_hidden_states_supplied(self): |
| 86 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()\ |
| 87 | + |
| 88 | + # Prepare hidden_states argument manually, remove text_ids arg. |
| 89 | + hidden_dim = init_dict["attention_head_dim"] * init_dict["num_attention_heads"] |
| 90 | + vocab_size = init_dict["vocab_size"] |
| 91 | + batch_size, seq_len = inputs_dict["text_ids"].shape |
| 92 | + hidden_states = torch.randn((batch_size, seq_len, hidden_dim), device=torch_device) |
| 93 | + inputs_dict["hidden_states"] = hidden_states |
| 94 | + del inputs_dict["text_ids"] |
| 95 | + |
| 96 | + expected_output_shape = (batch_size, seq_len, vocab_size) |
| 97 | + super().test_output(expected_output_shape=expected_output_shape) |
| 98 | + |
| 99 | + def test_output_positions_ids_supplied(self): |
| 100 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 101 | + |
| 102 | + # Prepare position_ids argument manually. |
| 103 | + vocab_size = init_dict["vocab_size"] |
| 104 | + position_ids = torch.arange(inputs_dict["text_ids"].shape[1], device=torch_device) |
| 105 | + position_ids = position_ids.unsqueeze(0).expand(inputs_dict["text_ids"].shape[0], -1) |
| 106 | + inputs_dict["position_ids"] = position_ids |
| 107 | + |
| 108 | + expected_output_shape = (inputs_dict["text_ids"].shape[0], inputs_dict["text_ids"].shape[1], vocab_size) |
| 109 | + super().test_output(expected_output_shape=expected_output_shape) |
| 110 | + |
| 111 | + @require_torch_accelerator_with_training |
| 112 | + def test_training_attention_mask_supplied(self): |
| 113 | + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
| 114 | + |
| 115 | + vocab_size = init_dict["vocab_size"] |
| 116 | + batch_size, seq_len = inputs_dict["text_ids"].shape |
| 117 | + |
| 118 | + model = self.model_class(**init_dict) |
| 119 | + model.to(torch_device) |
| 120 | + model.train() |
| 121 | + dtype = model.dtype |
| 122 | + |
| 123 | + # Prepare causal attention mask for training, specifically a transformers-style 4D additive causal mask with |
| 124 | + # the upper triangular entries filled with -inf |
| 125 | + attention_mask = None |
| 126 | + causal_mask = torch.full((seq_len, seq_len), torch.finfo(dtype).min, device=torch_device) |
| 127 | + positions = torch.arange(causal_mask.size(-1), device=torch_device) |
| 128 | + causal_mask.masked_fill_(positions < (positions + 1).view(causal_mask.size(-1), 1), 0) |
| 129 | + attention_mask = causal_mask[None, None, :, :].expand(batch_size, 1, seq_len, seq_len) |
| 130 | + inputs_dict["attention_mask"] = attention_mask |
| 131 | + |
| 132 | + logits = model(**inputs_dict) |
| 133 | + |
| 134 | + if isinstance(logits, dict): |
| 135 | + logits = logits.to_tuple()[0] |
| 136 | + |
| 137 | + input_tensor = inputs_dict[self.main_input_name] |
| 138 | + target = torch.randint(vocab_size, (input_tensor.shape[0], input_tensor.shape[1]), device=torch_device) |
| 139 | + loss = torch.nn.functional.cross_entropy(logits.view(-1, vocab_size), target.view(-1)) |
| 140 | + loss.backward() |
| 141 | + |
| 142 | + |
| 143 | +class DreamTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase): |
| 144 | + model_class = DreamTransformer1DModel |
| 145 | + # NOTE: set to None to skip TorchCompileTesterMixin.test_compile_on_different_shapes because this test |
| 146 | + # currently assumes the input is image-like (specifically, that prepare_dummy_inputs accepts `height` and `width` |
| 147 | + # argunments). We could consider overriding this test to make it specific to the Dream transformer. |
| 148 | + different_shapes_for_compilation = None |
| 149 | + |
| 150 | + def prepare_dummy_input(self, batch_size: int = 1, sequence_length: int = 48): |
| 151 | + return DreamTransformerTests().prepare_dummy_input(batch_size=batch_size, sequence_length=sequence_length) |
| 152 | + |
| 153 | + def prepare_init_args_and_inputs_for_common(self): |
| 154 | + return DreamTransformerTests().prepare_init_args_and_inputs_for_common() |
0 commit comments