
Commit 75a636d

baymax591, 白超, and hlky authored

bugfix for npu not support float64 (#10123)

* bugfix for npu not support float64
* is_mps is_npu

---------
Co-authored-by: 白超 <baichao19@huawei.com>
Co-authored-by: hlky <hlky@hlky.ac>
1 parent 4842f5d commit 75a636d

21 files changed (+63 / −42 lines)
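Every file below receives the same two-line change: the timestep-to-tensor conversion previously fell back to 32-bit dtypes only on Apple MPS and defaulted to float64/int64 everywhere else, which fails on Ascend NPU devices because they do not support float64. The following is a minimal standalone sketch of the guarded conversion; the helper name timestep_to_tensor is hypothetical and not part of the diffusers API.

import torch


def timestep_to_tensor(timestep, sample: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper mirroring the patched logic in each file.

    MPS and NPU backends do not support float64, so 32-bit dtypes are used
    on those devices; everywhere else the original 64-bit defaults remain.
    """
    if torch.is_tensor(timestep):
        return timestep
    is_mps = sample.device.type == "mps"
    is_npu = sample.device.type == "npu"
    if isinstance(timestep, float):
        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
    else:
        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
    return torch.tensor([timestep], dtype=dtype, device=sample.device)


# On CPU/CUDA the behaviour is unchanged; on "mps"/"npu" samples the
# returned tensor is float32/int32 instead of float64/int64.
sample = torch.zeros(1, 4, 8, 8)              # CPU sample
print(timestep_to_tensor(0.5, sample).dtype)  # torch.float64
print(timestep_to_tensor(10, sample).dtype)   # torch.int64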

examples/community/fresco_v2v.py

Lines changed: 3 additions & 2 deletions
@@ -404,10 +404,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

examples/community/matryoshka.py

Lines changed: 3 additions & 2 deletions
@@ -2806,10 +2806,11 @@ def get_time_embed(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py

Lines changed: 3 additions & 2 deletions
@@ -1031,10 +1031,11 @@ def __call__(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = latent_model_input.device.type == "mps"
+            is_npu = latent_model_input.device.type == "npu"
             if isinstance(current_timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
         elif len(current_timestep.shape) == 0:
             current_timestep = current_timestep[None].to(latent_model_input.device)

examples/research_projects/promptdiffusion/promptdiffusioncontrolnet.py

Lines changed: 3 additions & 2 deletions
@@ -258,10 +258,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet.py

Lines changed: 3 additions & 2 deletions
@@ -740,10 +740,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_sparsectrl.py

Lines changed: 3 additions & 2 deletions
@@ -671,10 +671,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_union.py

Lines changed: 3 additions & 2 deletions
@@ -681,10 +681,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/controlnets/controlnet_xs.py

Lines changed: 3 additions & 2 deletions
@@ -1088,10 +1088,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/unets/unet_2d_condition.py

Lines changed: 3 additions & 2 deletions
@@ -915,10 +915,11 @@ def get_time_embed(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)

src/diffusers/models/unets/unet_3d_condition.py

Lines changed: 3 additions & 2 deletions
@@ -624,10 +624,11 @@ def forward(
             # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
             # This would be a good case for the `match` statement (Python 3.10+)
             is_mps = sample.device.type == "mps"
+            is_npu = sample.device.type == "npu"
             if isinstance(timestep, float):
-                dtype = torch.float32 if is_mps else torch.float64
+                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
             else:
-                dtype = torch.int32 if is_mps else torch.int64
+                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
             timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
         elif len(timesteps.shape) == 0:
             timesteps = timesteps[None].to(sample.device)
