@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
971
971
972
972
def __init__ (
973
973
self ,
974
+ args ,
974
975
instance_data_root ,
975
976
instance_prompt ,
976
977
class_prompt ,
@@ -980,10 +981,8 @@ def __init__(
980
981
class_num = None ,
981
982
size = 1024 ,
982
983
repeats = 1 ,
983
- center_crop = False ,
984
984
):
985
985
self .size = size
986
- self .center_crop = center_crop
987
986
988
987
self .instance_prompt = instance_prompt
989
988
self .custom_instance_prompts = None
@@ -1058,7 +1057,7 @@ def __init__(
1058
1057
if interpolation is None :
1059
1058
raise ValueError (f"Unsupported interpolation mode { interpolation = } ." )
1060
1059
train_resize = transforms .Resize (size , interpolation = interpolation )
1061
- train_crop = transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size )
1060
+ train_crop = transforms .CenterCrop (size ) if args . center_crop else transforms .RandomCrop (size )
1062
1061
train_flip = transforms .RandomHorizontalFlip (p = 1.0 )
1063
1062
train_transforms = transforms .Compose (
1064
1063
[
@@ -1075,11 +1074,11 @@ def __init__(
1075
1074
# flip
1076
1075
image = train_flip (image )
1077
1076
if args .center_crop :
1078
- y1 = max (0 , int (round ((image .height - args . resolution ) / 2.0 )))
1079
- x1 = max (0 , int (round ((image .width - args . resolution ) / 2.0 )))
1077
+ y1 = max (0 , int (round ((image .height - self . size ) / 2.0 )))
1078
+ x1 = max (0 , int (round ((image .width - self . size ) / 2.0 )))
1080
1079
image = train_crop (image )
1081
1080
else :
1082
- y1 , x1 , h , w = train_crop .get_params (image , (args . resolution , args . resolution ))
1081
+ y1 , x1 , h , w = train_crop .get_params (image , (self . size , self . size ))
1083
1082
image = crop (image , y1 , x1 , h , w )
1084
1083
image = train_transforms (image )
1085
1084
self .pixel_values .append (image )
@@ -1102,7 +1101,7 @@ def __init__(
1102
1101
self .image_transforms = transforms .Compose (
1103
1102
[
1104
1103
transforms .Resize (size , interpolation = interpolation ),
1105
- transforms .CenterCrop (size ) if center_crop else transforms .RandomCrop (size ),
1104
+ transforms .CenterCrop (size ) if args . center_crop else transforms .RandomCrop (size ),
1106
1105
transforms .ToTensor (),
1107
1106
transforms .Normalize ([0.5 ], [0.5 ]),
1108
1107
]
@@ -1827,6 +1826,7 @@ def load_model_hook(models, input_dir):
1827
1826
1828
1827
# Dataset and DataLoaders creation:
1829
1828
train_dataset = DreamBoothDataset (
1829
+ args = args ,
1830
1830
instance_data_root = args .instance_data_dir ,
1831
1831
instance_prompt = args .instance_prompt ,
1832
1832
train_text_encoder_ti = args .train_text_encoder_ti ,
@@ -1836,7 +1836,6 @@ def load_model_hook(models, input_dir):
1836
1836
class_num = args .num_class_images ,
1837
1837
size = args .resolution ,
1838
1838
repeats = args .repeats ,
1839
- center_crop = args .center_crop ,
1840
1839
)
1841
1840
1842
1841
train_dataloader = torch .utils .data .DataLoader (
0 commit comments