From 6001899f30abd180bd1b74839ace5372f67cc8cf Mon Sep 17 00:00:00 2001
From: YAO Matrix
Date: Tue, 10 Jun 2025 23:21:44 +0000
Subject: [PATCH 1/6] xx

---
 src/diffusers/models/model_loading_utils.py         | 1 +
 src/diffusers/models/modeling_utils.py              | 2 ++
 tests/models/test_modeling_common.py                | 4 ++--
 tests/models/unets/test_models_unet_2d_condition.py | 5 +++++
 4 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index ebc7d79aeb28..3330cab61655 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -107,6 +107,7 @@ def _determine_device_map(
 
         device_map_kwargs["max_memory"] = max_memory
         device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
+        print(f"333333 device_map: {device_map}")
 
     if hf_quantizer is not None:
         hf_quantizer.validate_environment(device_map=device_map)
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 55ce0cf79fb9..1b2a57b35eb9 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1201,9 +1201,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
 
         # Now that the model is loaded, we can determine the device_map
+        print(f"111111 device_map: {device_map}")
         device_map = _determine_device_map(
             model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
         )
+        print(f"222222 device_map: {device_map}")
 
         if hf_quantizer is not None:
             hf_quantizer.validate_environment(device_map=device_map)
diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 5087bd0094a5..448e4d076b71 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -1744,7 +1744,7 @@ def test_push_to_hub_library_name(self):
         delete_repo(self.repo_id, token=TOKEN)
 
 
-@require_torch_gpu
+@require_torch_accelerator
 @require_torch_2
 @is_torch_compile
 @slow
@@ -1789,7 +1789,7 @@ def test_compile_with_group_offloading(self):
         model.eval()
         # TODO: Can test for other group offloading kwargs later if needed.
         group_offload_kwargs = {
-            "onload_device": "cuda",
+            "onload_device": torch_device,
             "offload_device": "cpu",
             "offload_type": "block_level",
             "num_blocks_per_group": 1,
         }
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index ab0dcbc1de11..c542c19b9b79 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1015,6 +1015,8 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
         loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
+        from torchviz import make_dot
+        make_dot(yhat, params=dict(list(loaded_model.named_parameters()))).render("unet_torchviz", format="png")
         loaded_model = loaded_model.to(torch_device)
 
         new_output = loaded_model(**inputs_dict)
@@ -1067,8 +1069,11 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
         loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
+        from torchviz import make_dot
 
         new_output = loaded_model(**inputs_dict)
+        make_dot(new_output.sample, params=dict(loaded_model.named_parameters())).render("unet", format="png")
+
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

From 8a1d6e5d885f35feed34fc15b314924dba9e10ac Mon Sep 17 00:00:00 2001
From: YAO Matrix
Date: Wed, 11 Jun 2025 01:14:14 +0000
Subject: [PATCH 2/6] fix

Signed-off-by: YAO Matrix
---
 src/diffusers/models/modeling_utils.py       | 2 --
 src/diffusers/models/unets/unet_2d_blocks.py | 3 ++-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index 1b2a57b35eb9..55ce0cf79fb9 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -1201,11 +1201,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
 
         # Now that the model is loaded, we can determine the device_map
-        print(f"111111 device_map: {device_map}")
         device_map = _determine_device_map(
             model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
         )
-        print(f"222222 device_map: {device_map}")
 
         if hf_quantizer is not None:
             hf_quantizer.validate_environment(device_map=device_map)
diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py
index e082d524e766..f29680bc4c17 100644
--- a/src/diffusers/models/unets/unet_2d_blocks.py
+++ b/src/diffusers/models/unets/unet_2d_blocks.py
@@ -2557,7 +2557,8 @@ def forward(
                     b1=self.b1,
                     b2=self.b2,
                 )
-
+            if hidden_states.device != res_hidden_states.device:
+                res_hidden_states = res_hidden_states.to(hidden_states.device)
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:

From 603257bd997e982c3e6515a35bb0a84f702480c1 Mon Sep 17 00:00:00 2001
From: Yao Matrix
Date: Wed, 11 Jun 2025 09:16:48 +0800
Subject: [PATCH 3/6] Update model_loading_utils.py

---
 src/diffusers/models/model_loading_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/model_loading_utils.py b/src/diffusers/models/model_loading_utils.py
index 3330cab61655..ebc7d79aeb28 100644
--- a/src/diffusers/models/model_loading_utils.py
+++ b/src/diffusers/models/model_loading_utils.py
@@ -107,7 +107,6 @@ def _determine_device_map(
 
         device_map_kwargs["max_memory"] = max_memory
         device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
-        print(f"333333 device_map: {device_map}")
 
     if hf_quantizer is not None:
         hf_quantizer.validate_environment(device_map=device_map)

From 8cdfdd8e775bd94cd23dd5915d0f923a73495b3c Mon Sep 17 00:00:00 2001
From: Yao Matrix
Date: Wed, 11 Jun 2025 09:17:42 +0800
Subject: [PATCH 4/6] Update test_models_unet_2d_condition.py

---
 tests/models/unets/test_models_unet_2d_condition.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index c542c19b9b79..0518a5c8e2c1 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1015,8 +1015,6 @@ def test_load_sharded_checkpoint_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
         loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True)
-        from torchviz import make_dot
-        make_dot(yhat, params=dict(list(loaded_model.named_parameters()))).render("unet_torchviz", format="png")
         loaded_model = loaded_model.to(torch_device)
 
         new_output = loaded_model(**inputs_dict)

From 45e29bdff51db8e94c6200eb88ac982b5df2a86a Mon Sep 17 00:00:00 2001
From: Yao Matrix
Date: Wed, 11 Jun 2025 09:18:16 +0800
Subject: [PATCH 5/6] Update test_models_unet_2d_condition.py

---
 tests/models/unets/test_models_unet_2d_condition.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 0518a5c8e2c1..ab0dcbc1de11 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -1067,11 +1067,8 @@ def test_load_sharded_checkpoint_device_map_from_hub_local(self):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         ckpt_path = snapshot_download("hf-internal-testing/unet2d-sharded-dummy")
         loaded_model = self.model_class.from_pretrained(ckpt_path, local_files_only=True, device_map="auto")
-        from torchviz import make_dot
 
         new_output = loaded_model(**inputs_dict)
-        make_dot(new_output.sample, params=dict(loaded_model.named_parameters())).render("unet", format="png")
-
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

From fae7c7064b6cd7b14d1bd6c83f4a721e1a54816c Mon Sep 17 00:00:00 2001
From: YAO Matrix
Date: Wed, 11 Jun 2025 13:29:46 +0000
Subject: [PATCH 6/6] fix style

Signed-off-by: YAO Matrix
---
 tests/models/test_modeling_common.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py
index 448e4d076b71..aa6db128d5e8 100644
--- a/tests/models/test_modeling_common.py
+++ b/tests/models/test_modeling_common.py
@@ -70,7 +70,6 @@
     require_torch_2,
     require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_accelerator,
     run_test_in_subprocess,
     slow,
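
Note on the substantive fix in PATCH 2/6: when a sharded checkpoint is loaded with device_map="auto", the skip-connection tensor produced by a down block can sit on a different accelerator than the up block's current hidden states, so the patch moves it onto the hidden states' device before the channel-wise torch.cat. The snippet below is a minimal, standalone sketch of that device-alignment pattern; the helper name cat_skip_connection and the tensor shapes are invented for illustration and are not part of the diffusers API.

    import torch

    def cat_skip_connection(hidden_states: torch.Tensor, res_hidden_states: torch.Tensor) -> torch.Tensor:
        # Under device_map="auto", the cached skip-connection tensor may live on a
        # different device than the current hidden states; align them before concatenating.
        if hidden_states.device != res_hidden_states.device:
            res_hidden_states = res_hidden_states.to(hidden_states.device)
        return torch.cat([hidden_states, res_hidden_states], dim=1)

    # Example: channel-wise concatenation, as the UNet up blocks do.
    hidden_states = torch.randn(1, 4, 16, 16)      # e.g. dispatched to "cuda:1" in a sharded run
    res_hidden_states = torch.randn(1, 4, 16, 16)  # e.g. dispatched to "cuda:0"
    assert cat_skip_connection(hidden_states, res_hidden_states).shape == (1, 8, 16, 16)

The test-side changes in PATCH 1/6 and 6/6 follow the same accelerator-agnostic direction: require_torch_accelerator and torch_device replace require_torch_gpu and the hard-coded "cuda" string, so the group-offloading compile test can run on any supported accelerator backend.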