Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 48a0113

Browse files
authored
Add rgb-mean and rgb-std arguments (#1546)
1 parent 641c830 commit 48a0113

File tree

1 file changed

+24
-0
lines changed
  • src/sparseml/pytorch/torchvision

1 file changed

+24
-0
lines changed

src/sparseml/pytorch/torchvision/train.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def load_data(traindir, valdir, args):
267267
traindir,
268268
presets.ClassificationPresetTrain(
269269
crop_size=train_crop_size,
270+
mean=args.rgb_mean,
271+
std=args.rgb_std,
270272
interpolation=interpolation,
271273
auto_augment_policy=auto_augment_policy,
272274
random_erase_prob=random_erase_prob,
@@ -289,6 +291,8 @@ def load_data(traindir, valdir, args):
289291
else:
290292
preprocessing = presets.ClassificationPresetEval(
291293
crop_size=val_crop_size,
294+
mean=args.rgb_mean,
295+
std=args.rgb_std,
292296
resize_size=val_resize_size,
293297
interpolation=interpolation,
294298
)
@@ -1212,6 +1216,26 @@ def new_func(*args, **kwargs):
12121216
"Note: Will be read from the checkpoint if not specified"
12131217
),
12141218
)
1219+
@click.option(
1220+
"--rgb-mean",
1221+
nargs=3,
1222+
default=(0.485, 0.456, 0.406),
1223+
type=float,
1224+
help=(
1225+
"RGB mean values used to shift input RGB values; "
1226+
"Note: Will use ImageNet values if not specified."
1227+
),
1228+
)
1229+
@click.option(
1230+
"--rgb-std",
1231+
default=(0.229, 0.224, 0.225),
1232+
nargs=3,
1233+
type=float,
1234+
help=(
1235+
"RGB standard-deviation values used to normalize input RGB values; "
1236+
"Note: Will use ImageNet values if not specified."
1237+
),
1238+
)
12151239
@click.pass_context
12161240
def cli(ctx, **kwargs):
12171241
"""

0 commit comments

Comments
 (0)