Skip to content

Commit f108ad8

Browse files
authored
Update modeling imports (#11129)
update
1 parent: e30d3bf · commit: f108ad8

8 files changed

+29
-39
lines changed

src/diffusers/models/controlnets/controlnet_flux.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,12 @@
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
2222
from ...loaders import PeftAdapterMixin
23-
from ...models.attention_processor import AttentionProcessor
24-
from ...models.modeling_utils import ModelMixin
2523
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
24+
from ..attention_processor import AttentionProcessor
2625
from ..controlnets.controlnet import ControlNetConditioningEmbedding, zero_module
2726
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
2827
from ..modeling_outputs import Transformer2DModelOutput
28+
from ..modeling_utils import ModelMixin
2929
from ..transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
3030

3131

src/diffusers/models/controlnets/multicontrolnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import torch
55
from torch import nn
66

7-
from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
8-
from ...models.modeling_utils import ModelMixin
97
from ...utils import logging
8+
from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
9+
from ..modeling_utils import ModelMixin
1010

1111

1212
logger = logging.get_logger(__name__)

src/diffusers/models/controlnets/multicontrolnet_union.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import torch
55
from torch import nn
66

7-
from ...models.controlnets.controlnet import ControlNetOutput
8-
from ...models.controlnets.controlnet_union import ControlNetUnionModel
9-
from ...models.modeling_utils import ModelMixin
107
from ...utils import logging
8+
from ..controlnets.controlnet import ControlNetOutput
9+
from ..controlnets.controlnet_union import ControlNetUnionModel
10+
from ..modeling_utils import ModelMixin
1111

1212

1313
logger = logging.get_logger(__name__)

src/diffusers/models/transformers/latte_transformer_3d.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,9 @@
1818
from torch import nn
1919

2020
from ...configuration_utils import ConfigMixin, register_to_config
21-
from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
2221
from ..attention import BasicTransformerBlock
2322
from ..cache_utils import CacheMixin
24-
from ..embeddings import PatchEmbed
23+
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
2524
from ..modeling_outputs import Transformer2DModelOutput
2625
from ..modeling_utils import ModelMixin
2726
from ..normalization import AdaLayerNormSingle

src/diffusers/models/transformers/stable_audio_transformer.py

+4-8
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,12 @@
2121
import torch.utils.checkpoint
2222

2323
from ...configuration_utils import ConfigMixin, register_to_config
24-
from ...models.attention import FeedForward
25-
from ...models.attention_processor import (
26-
Attention,
27-
AttentionProcessor,
28-
StableAudioAttnProcessor2_0,
29-
)
30-
from ...models.modeling_utils import ModelMixin
31-
from ...models.transformers.transformer_2d import Transformer2DModelOutput
3224
from ...utils import logging
3325
from ...utils.torch_utils import maybe_allow_in_graph
26+
from ..attention import FeedForward
27+
from ..attention_processor import Attention, AttentionProcessor, StableAudioAttnProcessor2_0
28+
from ..modeling_utils import ModelMixin
29+
from ..transformers.transformer_2d import Transformer2DModelOutput
3430

3531

3632
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

src/diffusers/models/transformers/transformer_cogview3plus.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,13 @@
1919
import torch.nn as nn
2020

2121
from ...configuration_utils import ConfigMixin, register_to_config
22-
from ...models.attention import FeedForward
23-
from ...models.attention_processor import (
24-
Attention,
25-
AttentionProcessor,
26-
CogVideoXAttnProcessor2_0,
27-
)
28-
from ...models.modeling_utils import ModelMixin
29-
from ...models.normalization import AdaLayerNormContinuous
3022
from ...utils import logging
23+
from ..attention import FeedForward
24+
from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
3125
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
3226
from ..modeling_outputs import Transformer2DModelOutput
33-
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
27+
from ..modeling_utils import ModelMixin
28+
from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage
3429

3530

3631
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

src/diffusers/models/transformers/transformer_flux.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -21,22 +21,22 @@
2121

2222
from ...configuration_utils import ConfigMixin, register_to_config
2323
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
24-
from ...models.attention import FeedForward
25-
from ...models.attention_processor import (
24+
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
25+
from ...utils.import_utils import is_torch_npu_available
26+
from ...utils.torch_utils import maybe_allow_in_graph
27+
from ..attention import FeedForward
28+
from ..attention_processor import (
2629
Attention,
2730
AttentionProcessor,
2831
FluxAttnProcessor2_0,
2932
FluxAttnProcessor2_0_NPU,
3033
FusedFluxAttnProcessor2_0,
3134
)
32-
from ...models.modeling_utils import ModelMixin
33-
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
34-
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
35-
from ...utils.import_utils import is_torch_npu_available
36-
from ...utils.torch_utils import maybe_allow_in_graph
3735
from ..cache_utils import CacheMixin
3836
from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
3937
from ..modeling_outputs import Transformer2DModelOutput
38+
from ..modeling_utils import ModelMixin
39+
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
4040

4141

4242
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

src/diffusers/models/transformers/transformer_sd3.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@
1818

1919
from ...configuration_utils import ConfigMixin, register_to_config
2020
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, SD3Transformer2DLoadersMixin
21-
from ...models.attention import FeedForward, JointTransformerBlock
22-
from ...models.attention_processor import (
21+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
22+
from ...utils.torch_utils import maybe_allow_in_graph
23+
from ..attention import FeedForward, JointTransformerBlock
24+
from ..attention_processor import (
2325
Attention,
2426
AttentionProcessor,
2527
FusedJointAttnProcessor2_0,
2628
JointAttnProcessor2_0,
2729
)
28-
from ...models.modeling_utils import ModelMixin
29-
from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero
30-
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
31-
from ...utils.torch_utils import maybe_allow_in_graph
3230
from ..embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
3331
from ..modeling_outputs import Transformer2DModelOutput
32+
from ..modeling_utils import ModelMixin
33+
from ..normalization import AdaLayerNormContinuous, AdaLayerNormZero
3434

3535

3636
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

0 commit comments

Comments (0)