
Commit 0143736

feat: add device_target config and set default to Ascend. (#796)
1 parent 6e2a4c7 commit 0143736

File tree

3 files changed: +12 −4 lines changed

config.py

Lines changed: 5 additions & 1 deletion

@@ -43,6 +43,8 @@ def create_parser():
                        help='Interval for print training log. Unit: step (default=100)')
     group.add_argument('--seed', type=int, default=42,
                        help='Seed value for determining randomness in numpy, random, and mindspore (default=42)')
+    group.add_argument('--device_target', type=str, default='Ascend',
+                       help='Device target for computing, which can be Ascend, GPU or CPU. (default=Ascend)')

     # Dataset parameters
     group = parser.add_argument_group('Dataset parameters')

@@ -94,7 +96,7 @@ def create_parser():
                        'Example: "randaug-m10-n2-w0-mstd0.5-mmax10-inc0", "autoaug-mstd0.5" or autoaugr-mstd0.5.')
     group.add_argument('--aug_splits', type=int, default=0,
                        help='Number of augmentation splits (default: 0, valid: 3 (currently, only support 3 splits))'
-                       'it should be set with one auto_augment')
+                            'it should be set with one auto_augment')
     group.add_argument('--re_prob', type=float, default=0.0,
                        help='Probability of performing erasing (default=0.0)')
     group.add_argument('--re_scale', type=tuple, default=(0.02, 0.33),

@@ -267,6 +269,8 @@ def create_parser():
                        help='Whether to shuffle the evaluation data (default=False)')

     return parser_config, parser
+
+
 # fmt: on
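The new flag is parsed like any other system option and, as validate.py below shows, handed straight to MindSpore's context. A minimal sketch of that flow, assuming MindSpore is installed; the group title and the hard-coded argv here are illustrative, not taken from the repo:

import argparse

import mindspore as ms

# Build a tiny parser with just the new option (group title is illustrative).
parser = argparse.ArgumentParser()
group = parser.add_argument_group('System parameters')
group.add_argument('--device_target', type=str, default='Ascend',
                   help='Device target for computing, which can be Ascend, GPU or CPU. (default=Ascend)')
args = parser.parse_args(['--device_target', 'CPU'])

# Pass the parsed value to the MindSpore context, as validate.py now does.
ms.set_context(device_target=args.device_target)
print(ms.get_context('device_target'))  # -> 'CPU'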

tests/tasks/test_train_val_imagenet_subset.py

Lines changed: 6 additions & 3 deletions

@@ -30,6 +30,7 @@ def test_train(mode, val_while_train, model="resnet18"):
     DownLoad().download_and_extract_archive(dataset_url, root_dir)

     # ---------------- test running train.py using the toy data ---------
+    device_target = "CPU"
     dataset = "imagenet"
     num_classes = 2
     ckpt_dir = "./tests/ckpt_tmp"

@@ -48,7 +49,8 @@ def test_train(mode, val_while_train, model="resnet18"):
         f"python {train_file} --dataset={dataset} --num_classes={num_classes} --model={model} "
         f"--epoch_size={num_epochs} --ckpt_save_interval=2 --lr=0.0001 --num_samples={num_samples} --loss=CE "
         f"--weight_decay=1e-6 --ckpt_save_dir={ckpt_dir} {download_str} --train_split=train --batch_size={batch_size} "
-        f"--pretrained --num_parallel_workers=2 --val_while_train={val_while_train} --val_split=val --val_interval=1"
+        f"--pretrained --num_parallel_workers=2 --val_while_train={val_while_train} --val_split=val --val_interval=1 "
+        f"--device_target={device_target}"
     )

     print(f"Running command: \n{cmd}")

@@ -57,10 +59,11 @@ def test_train(mode, val_while_train, model="resnet18"):

     # --------- Test running validate.py using the trained model ------------- #
     # begin_ckpt = os.path.join(ckpt_dir, f'{model}-1_1.ckpt')
-    end_ckpt = os.path.join(ckpt_dir, f"{model}-{num_epochs}_{num_samples//batch_size}.ckpt")
+    end_ckpt = os.path.join(ckpt_dir, f"{model}-{num_epochs}_{num_samples // batch_size}.ckpt")
     cmd = (
         f"python validate.py --model={model} --dataset={dataset} --val_split=val --data_dir={data_dir} "
-        f"--num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2"
+        f"--num_classes={num_classes} --ckpt_path={end_ckpt} --batch_size=40 --num_parallel_workers=2 "
+        f"--device_target={device_target}"
     )
     # ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
     print(f"Running command: \n{cmd}")
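Pinning device_target to CPU keeps the toy run off Ascend/GPU hardware. A trimmed sketch of the same subprocess pattern the test uses; the paths ./data and ./ckpt/resnet18-2_10.ckpt are hypothetical stand-ins for the fixtures the test prepares:

import subprocess
import sys

device_target = "CPU"  # keep the toy run off Ascend/GPU hardware
# Hypothetical paths; the real test derives these from its downloaded fixtures.
cmd = (
    f"python validate.py --model=resnet18 --dataset=imagenet --val_split=val --data_dir=./data "
    f"--num_classes=2 --ckpt_path=./ckpt/resnet18-2_10.ckpt --batch_size=40 --num_parallel_workers=2 "
    f"--device_target={device_target}"
)
print(f"Running command: \n{cmd}")
ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
assert ret == 0, "validate.py exited with a non-zero status"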

validate.py

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ def check_batch_size(num_samples, ori_batch_size=32, refine=True):


 def validate(args):
+    ms.set_context(device_target=args.device_target)
     ms.set_context(mode=args.mode)
     if args.mode == ms.GRAPH_MODE:
         ms.set_context(jit_config={"jit_level": "O2"})
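With the added line, validate() selects the backend before setting the execution mode and JIT level; successive ms.set_context calls add settings rather than replacing earlier ones. A minimal check, assuming MindSpore is installed:

import mindspore as ms

ms.set_context(device_target="CPU")  # what validate() now does with args.device_target
ms.set_context(mode=ms.GRAPH_MODE)   # later calls extend the context, they do not reset it
print(ms.get_context("device_target"), ms.get_context("mode"))  # -> CPU 0 (GRAPH_MODE)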
