@@ -30,6 +30,7 @@ def test_train(mode, val_while_train, model="resnet18"):
30
30
DownLoad ().download_and_extract_archive (dataset_url , root_dir )
31
31
32
32
# ---------------- test running train.py using the toy data ---------
33
+ device_target = "CPU"
33
34
dataset = "imagenet"
34
35
num_classes = 2
35
36
ckpt_dir = "./tests/ckpt_tmp"
@@ -48,7 +49,8 @@ def test_train(mode, val_while_train, model="resnet18"):
48
49
f"python { train_file } --dataset={ dataset } --num_classes={ num_classes } --model={ model } "
49
50
f"--epoch_size={ num_epochs } --ckpt_save_interval=2 --lr=0.0001 --num_samples={ num_samples } --loss=CE "
50
51
f"--weight_decay=1e-6 --ckpt_save_dir={ ckpt_dir } { download_str } --train_split=train --batch_size={ batch_size } "
51
- f"--pretrained --num_parallel_workers=2 --val_while_train={ val_while_train } --val_split=val --val_interval=1"
52
+ f"--pretrained --num_parallel_workers=2 --val_while_train={ val_while_train } --val_split=val --val_interval=1 "
53
+ f"--device_target={ device_target } "
52
54
)
53
55
54
56
print (f"Running command: \n { cmd } " )
@@ -57,10 +59,11 @@ def test_train(mode, val_while_train, model="resnet18"):
57
59
58
60
# --------- Test running validate.py using the trained model ------------- #
59
61
# begin_ckpt = os.path.join(ckpt_dir, f'{model}-1_1.ckpt')
60
- end_ckpt = os .path .join (ckpt_dir , f"{ model } -{ num_epochs } _{ num_samples // batch_size } .ckpt" )
62
+ end_ckpt = os .path .join (ckpt_dir , f"{ model } -{ num_epochs } _{ num_samples // batch_size } .ckpt" )
61
63
cmd = (
62
64
f"python validate.py --model={ model } --dataset={ dataset } --val_split=val --data_dir={ data_dir } "
63
- f"--num_classes={ num_classes } --ckpt_path={ end_ckpt } --batch_size=40 --num_parallel_workers=2"
65
+ f"--num_classes={ num_classes } --ckpt_path={ end_ckpt } --batch_size=40 --num_parallel_workers=2 "
66
+ f"--device_target={ device_target } "
64
67
)
65
68
# ret = subprocess.call(cmd.split(), stdout=sys.stdout, stderr=sys.stderr)
66
69
print (f"Running command: \n { cmd } " )
0 commit comments