Skip to content

Commit 7be08f4

Browse files
authored
Merge branch 'main' into refactor-hooks
2 parents 1d84505 + ef1e628 commit 7be08f4

File tree

1 file changed

+7
-8
lines changed

1 file changed

+7
-8
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
971971

972972
def __init__(
973973
self,
974+
args,
974975
instance_data_root,
975976
instance_prompt,
976977
class_prompt,
@@ -980,10 +981,8 @@ def __init__(
980981
class_num=None,
981982
size=1024,
982983
repeats=1,
983-
center_crop=False,
984984
):
985985
self.size = size
986-
self.center_crop = center_crop
987986

988987
self.instance_prompt = instance_prompt
989988
self.custom_instance_prompts = None
@@ -1058,7 +1057,7 @@ def __init__(
10581057
if interpolation is None:
10591058
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
10601059
train_resize = transforms.Resize(size, interpolation=interpolation)
1061-
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
1060+
train_crop = transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size)
10621061
train_flip = transforms.RandomHorizontalFlip(p=1.0)
10631062
train_transforms = transforms.Compose(
10641063
[
@@ -1075,11 +1074,11 @@ def __init__(
10751074
# flip
10761075
image = train_flip(image)
10771076
if args.center_crop:
1078-
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
1079-
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
1077+
y1 = max(0, int(round((image.height - self.size) / 2.0)))
1078+
x1 = max(0, int(round((image.width - self.size) / 2.0)))
10801079
image = train_crop(image)
10811080
else:
1082-
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
1081+
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
10831082
image = crop(image, y1, x1, h, w)
10841083
image = train_transforms(image)
10851084
self.pixel_values.append(image)
@@ -1102,7 +1101,7 @@ def __init__(
11021101
self.image_transforms = transforms.Compose(
11031102
[
11041103
transforms.Resize(size, interpolation=interpolation),
1105-
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
1104+
transforms.CenterCrop(size) if args.center_crop else transforms.RandomCrop(size),
11061105
transforms.ToTensor(),
11071106
transforms.Normalize([0.5], [0.5]),
11081107
]
@@ -1827,6 +1826,7 @@ def load_model_hook(models, input_dir):
18271826

18281827
# Dataset and DataLoaders creation:
18291828
train_dataset = DreamBoothDataset(
1829+
args=args,
18301830
instance_data_root=args.instance_data_dir,
18311831
instance_prompt=args.instance_prompt,
18321832
train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1836,7 +1836,6 @@ def load_model_hook(models, input_dir):
18361836
class_num=args.num_class_images,
18371837
size=args.resolution,
18381838
repeats=args.repeats,
1839-
center_crop=args.center_crop,
18401839
)
18411840

18421841
train_dataloader = torch.utils.data.DataLoader(

0 commit comments

Comments
 (0)