Skip to content

Commit 95c5ce4

Browse files
hlky and sayakpaul authored
PyTorch/XLA support (#10498)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
1 parent c096457 commit 95c5ce4

File tree

111 files changed

+1369
-34
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

111 files changed

+1369
-34
lines changed

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
deprecate,
3434
is_bs4_available,
3535
is_ftfy_available,
36+
is_torch_xla_available,
3637
logging,
3738
replace_example_docstring,
3839
)
@@ -41,6 +42,14 @@
4142
from .pipeline_output import AllegroPipelineOutput
4243

4344

45+
if is_torch_xla_available():
46+
import torch_xla.core.xla_model as xm
47+
48+
XLA_AVAILABLE = True
49+
else:
50+
XLA_AVAILABLE = False
51+
52+
4453
logger = logging.get_logger(__name__)
4554

4655
if is_bs4_available():
@@ -921,6 +930,9 @@ def __call__(
921930
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
922931
progress_bar.update()
923932

933+
if XLA_AVAILABLE:
934+
xm.mark_step()
935+
924936
if not output_type == "latent":
925937
latents = latents.to(self.vae.dtype)
926938
video = self.decode_latents(latents)

src/diffusers/pipelines/amused/pipeline_amused.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,18 @@
2020
from ...image_processor import VaeImageProcessor
2121
from ...models import UVit2DModel, VQModel
2222
from ...schedulers import AmusedScheduler
23-
from ...utils import replace_example_docstring
23+
from ...utils import is_torch_xla_available, replace_example_docstring
2424
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2525

2626

27+
if is_torch_xla_available():
28+
import torch_xla.core.xla_model as xm
29+
30+
XLA_AVAILABLE = True
31+
else:
32+
XLA_AVAILABLE = False
33+
34+
2735
EXAMPLE_DOC_STRING = """
2836
Examples:
2937
```py
@@ -299,6 +307,9 @@ def __call__(
299307
step_idx = i // getattr(self.scheduler, "order", 1)
300308
callback(step_idx, timestep, latents)
301309

310+
if XLA_AVAILABLE:
311+
xm.mark_step()
312+
302313
if output_type == "latent":
303314
output = latents
304315
else:

src/diffusers/pipelines/amused/pipeline_amused_img2img.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,18 @@
2020
from ...image_processor import PipelineImageInput, VaeImageProcessor
2121
from ...models import UVit2DModel, VQModel
2222
from ...schedulers import AmusedScheduler
23-
from ...utils import replace_example_docstring
23+
from ...utils import is_torch_xla_available, replace_example_docstring
2424
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2525

2626

27+
if is_torch_xla_available():
28+
import torch_xla.core.xla_model as xm
29+
30+
XLA_AVAILABLE = True
31+
else:
32+
XLA_AVAILABLE = False
33+
34+
2735
EXAMPLE_DOC_STRING = """
2836
Examples:
2937
```py
@@ -325,6 +333,9 @@ def __call__(
325333
step_idx = i // getattr(self.scheduler, "order", 1)
326334
callback(step_idx, timestep, latents)
327335

336+
if XLA_AVAILABLE:
337+
xm.mark_step()
338+
328339
if output_type == "latent":
329340
output = latents
330341
else:

src/diffusers/pipelines/amused/pipeline_amused_inpaint.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,18 @@
2121
from ...image_processor import PipelineImageInput, VaeImageProcessor
2222
from ...models import UVit2DModel, VQModel
2323
from ...schedulers import AmusedScheduler
24-
from ...utils import replace_example_docstring
24+
from ...utils import is_torch_xla_available, replace_example_docstring
2525
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2626

2727

28+
if is_torch_xla_available():
29+
import torch_xla.core.xla_model as xm
30+
31+
XLA_AVAILABLE = True
32+
else:
33+
XLA_AVAILABLE = False
34+
35+
2836
EXAMPLE_DOC_STRING = """
2937
Examples:
3038
```py
@@ -356,6 +364,9 @@ def __call__(
356364
step_idx = i // getattr(self.scheduler, "order", 1)
357365
callback(step_idx, timestep, latents)
358366

367+
if XLA_AVAILABLE:
368+
xm.mark_step()
369+
359370
if output_type == "latent":
360371
output = latents
361372
else:

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from ...utils import (
3535
USE_PEFT_BACKEND,
3636
deprecate,
37+
is_torch_xla_available,
3738
logging,
3839
replace_example_docstring,
3940
scale_lora_layers,
@@ -47,8 +48,16 @@
4748
from .pipeline_output import AnimateDiffPipelineOutput
4849

4950

51+
if is_torch_xla_available():
52+
import torch_xla.core.xla_model as xm
53+
54+
XLA_AVAILABLE = True
55+
else:
56+
XLA_AVAILABLE = False
57+
5058
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5159

60+
5261
EXAMPLE_DOC_STRING = """
5362
Examples:
5463
```py
@@ -844,6 +853,9 @@ def __call__(
844853
if callback is not None and i % callback_steps == 0:
845854
callback(i, t, latents)
846855

856+
if XLA_AVAILABLE:
857+
xm.mark_step()
858+
847859
# 9. Post processing
848860
if output_type == "latent":
849861
video = latents

src/diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
from ...models.lora import adjust_lora_scale_text_encoder
3333
from ...models.unets.unet_motion_model import MotionAdapter
3434
from ...schedulers import KarrasDiffusionSchedulers
35-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
35+
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
3636
from ...utils.torch_utils import is_compiled_module, randn_tensor
3737
from ...video_processor import VideoProcessor
3838
from ..free_init_utils import FreeInitMixin
@@ -41,8 +41,16 @@
4141
from .pipeline_output import AnimateDiffPipelineOutput
4242

4343

44+
if is_torch_xla_available():
45+
import torch_xla.core.xla_model as xm
46+
47+
XLA_AVAILABLE = True
48+
else:
49+
XLA_AVAILABLE = False
50+
4451
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4552

53+
4654
EXAMPLE_DOC_STRING = """
4755
Examples:
4856
```py
@@ -1090,6 +1098,9 @@ def __call__(
10901098
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
10911099
progress_bar.update()
10921100

1101+
if XLA_AVAILABLE:
1102+
xm.mark_step()
1103+
10931104
# 9. Post processing
10941105
if output_type == "latent":
10951106
video = latents

src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
)
4949
from ...utils import (
5050
USE_PEFT_BACKEND,
51+
is_torch_xla_available,
5152
logging,
5253
replace_example_docstring,
5354
scale_lora_layers,
@@ -60,8 +61,16 @@
6061
from .pipeline_output import AnimateDiffPipelineOutput
6162

6263

64+
if is_torch_xla_available():
65+
import torch_xla.core.xla_model as xm
66+
67+
XLA_AVAILABLE = True
68+
else:
69+
XLA_AVAILABLE = False
70+
6371
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6472

73+
6574
EXAMPLE_DOC_STRING = """
6675
Examples:
6776
```py
@@ -1265,6 +1274,9 @@ def __call__(
12651274

12661275
progress_bar.update()
12671276

1277+
if XLA_AVAILABLE:
1278+
xm.mark_step()
1279+
12681280
# make sure the VAE is in float32 mode, as it overflows in float16
12691281
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
12701282

src/diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ...schedulers import KarrasDiffusionSchedulers
3131
from ...utils import (
3232
USE_PEFT_BACKEND,
33+
is_torch_xla_available,
3334
logging,
3435
replace_example_docstring,
3536
scale_lora_layers,
@@ -42,8 +43,16 @@
4243
from .pipeline_output import AnimateDiffPipelineOutput
4344

4445

46+
if is_torch_xla_available():
47+
import torch_xla.core.xla_model as xm
48+
49+
XLA_AVAILABLE = True
50+
else:
51+
XLA_AVAILABLE = False
52+
4553
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4654

55+
4756
EXAMPLE_DOC_STRING = """
4857
Examples:
4958
```python
@@ -994,6 +1003,9 @@ def __call__(
9941003
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
9951004
progress_bar.update()
9961005

1006+
if XLA_AVAILABLE:
1007+
xm.mark_step()
1008+
9971009
# 11. Post processing
9981010
if output_type == "latent":
9991011
video = latents

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
LMSDiscreteScheduler,
3232
PNDMScheduler,
3333
)
34-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
34+
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
3535
from ...utils.torch_utils import randn_tensor
3636
from ...video_processor import VideoProcessor
3737
from ..free_init_utils import FreeInitMixin
@@ -40,8 +40,16 @@
4040
from .pipeline_output import AnimateDiffPipelineOutput
4141

4242

43+
if is_torch_xla_available():
44+
import torch_xla.core.xla_model as xm
45+
46+
XLA_AVAILABLE = True
47+
else:
48+
XLA_AVAILABLE = False
49+
4350
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4451

52+
4553
EXAMPLE_DOC_STRING = """
4654
Examples:
4755
```py
@@ -1037,6 +1045,9 @@ def __call__(
10371045
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
10381046
progress_bar.update()
10391047

1048+
if XLA_AVAILABLE:
1049+
xm.mark_step()
1050+
10401051
# 10. Post-processing
10411052
if output_type == "latent":
10421053
video = latents

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
LMSDiscreteScheduler,
4040
PNDMScheduler,
4141
)
42-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
42+
from ...utils import USE_PEFT_BACKEND, is_torch_xla_available, logging, scale_lora_layers, unscale_lora_layers
4343
from ...utils.torch_utils import is_compiled_module, randn_tensor
4444
from ...video_processor import VideoProcessor
4545
from ..free_init_utils import FreeInitMixin
@@ -48,8 +48,16 @@
4848
from .pipeline_output import AnimateDiffPipelineOutput
4949

5050

51+
if is_torch_xla_available():
52+
import torch_xla.core.xla_model as xm
53+
54+
XLA_AVAILABLE = True
55+
else:
56+
XLA_AVAILABLE = False
57+
5158
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5259

60+
5361
EXAMPLE_DOC_STRING = """
5462
Examples:
5563
```py
@@ -1325,6 +1333,9 @@ def __call__(
13251333
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
13261334
progress_bar.update()
13271335

1336+
if XLA_AVAILABLE:
1337+
xm.mark_step()
1338+
13281339
# 11. Post-processing
13291340
if output_type == "latent":
13301341
video = latents

0 commit comments

Comments (0)