Skip to content

Commit 5c149ec

Browse files
authored
fix sdxl lora load (#108)
* fix sdxl lora load * fix device
1 parent 44df014 commit 5c149ec

File tree

6 files changed

+6
-10
lines changed

6 files changed

+6
-10
lines changed

diffsynth_engine/models/sd/sd_controlnet.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
AttentionBlock,
1111
PushBlock,
1212
DownSampler,
13-
PopBlock,
14-
UpSampler,
1513
)
1614

1715
class ControlNetConditioningLayer(nn.Module):
@@ -565,7 +563,6 @@ def forward(
565563
time_emb = self.time_embedding(timestep, dtype=sample.dtype)
566564

567565
# 2. pre-process
568-
height, width = sample.shape[2], sample.shape[3]
569566
hidden_states = self.conv_in(sample) + self.controlnet_conv_in(conditioning)
570567
text_emb = encoder_hidden_states
571568
res_stack = [hidden_states]

diffsynth_engine/models/sdxl/sdxl_controlnet.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
import torch
2-
import torch.nn as nn
32
from typing import Optional, Dict
43
from diffsynth_engine.models.basic.unet_helper import (
54
ResnetBlock,
65
AttentionBlock,
76
PushBlock,
87
DownSampler,
9-
PopBlock,
10-
UpSampler,
118
)
129
from diffsynth_engine.models.sd.sd_controlnet import ControlNetConditioningLayer
1310
from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
@@ -283,7 +280,6 @@ def forward(
283280
time_emb = t_emb + add_embeds + control_embeds
284281

285282
# 2. pre-process
286-
height, width = sample.shape[2], sample.shape[3]
287283
hidden_states = self.conv_in(sample)
288284
hidden_states = self.fuse_condition_to_input(hidden_states, task_id, conditioning)
289285
text_emb = encoder_hidden_states

diffsynth_engine/pipelines/controlnet_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
import torch.nn as nn
3-
from typing import Dict, List, Tuple, Union, Optional
3+
from typing import List, Union, Optional
44
from PIL import Image
55
from dataclasses import dataclass
66

diffsynth_engine/pipelines/flux_image.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import os
33
import json
44
import torch
5-
import torch.nn as nn
65
import torch.distributed as dist
76
import math
87
from einops import rearrange
98
from enum import Enum
10-
from typing import Callable, Dict, List, Tuple, Optional, Union
9+
from typing import Callable, Dict, List, Tuple, Optional
1110
from tqdm import tqdm
1211
from PIL import Image
1312
from dataclasses import dataclass

diffsynth_engine/pipelines/sd_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from diffsynth_engine.algorithm.sampler import EulerSampler
1919
from diffsynth_engine.utils.prompt import tokenize_long_prompt
2020
from diffsynth_engine.utils.constants import SDXL_TOKENIZER_CONF_PATH
21+
from diffsynth_engine.utils.platform import empty_cache
2122
from diffsynth_engine.utils import logging
2223

2324
logger = logging.get_logger(__name__)

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from diffsynth_engine.algorithm.sampler import EulerSampler
2727
from diffsynth_engine.utils.prompt import tokenize_long_prompt
2828
from diffsynth_engine.utils.constants import SDXL_TOKENIZER_CONF_PATH, SDXL_TOKENIZER_2_CONF_PATH
29+
from diffsynth_engine.utils.platform import empty_cache
2930
from diffsynth_engine.utils import logging
3031

3132
logger = logging.get_logger(__name__)
@@ -89,6 +90,8 @@ def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dic
8990
unet_dict[key] = lora_args
9091
else:
9192
raise ValueError(f"Unsupported key: {key}")
93+
# clip skip
94+
te1_dict = {k: v for k, v in te1_dict.items() if not k.startswith('encoders.11')}
9295
return {"unet": unet_dict, "text_encoder": te1_dict, "text_encoder_2": te2_dict}
9396

9497
def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:

0 commit comments

Comments
 (0)