diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index e082d524e766..f29680bc4c17 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -2557,7 +2557,8 @@ def forward( b1=self.b1, b2=self.b2, ) - + if hidden_states.device != res_hidden_states.device: + res_hidden_states = res_hidden_states.to(hidden_states.device) hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5087bd0094a5..aa6db128d5e8 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -70,7 +70,6 @@ require_torch_2, require_torch_accelerator, require_torch_accelerator_with_training, - require_torch_gpu, require_torch_multi_accelerator, run_test_in_subprocess, slow, @@ -1744,7 +1743,7 @@ def test_push_to_hub_library_name(self): delete_repo(self.repo_id, token=TOKEN) -@require_torch_gpu +@require_torch_accelerator @require_torch_2 @is_torch_compile @slow @@ -1789,7 +1788,7 @@ def test_compile_with_group_offloading(self): model.eval() # TODO: Can test for other group offloading kwargs later if needed. group_offload_kwargs = { - "onload_device": "cuda", + "onload_device": torch_device, "offload_device": "cpu", "offload_type": "block_level", "num_blocks_per_group": 1,