@@ -223,7 +223,7 @@ def resnet_v2(input_shape, depth, num_classes=10):
223
223
n = 1
224
224
depth = n * 9 + 2
225
225
226
- def train ():
226
+ def train (dst_path ):
227
227
# Load the CIFAR10 data.
228
228
(x_train , y_train ), (x_test , y_test ) = cifar10 .load_data ()
229
229
@@ -269,7 +269,7 @@ def train():
269
269
scores = model .evaluate (x_test , y_test , verbose = 1 )
270
270
print ('Test loss:' , scores [0 ])
271
271
print ('Test accuracy:' , scores [1 ])
272
- model .save ("baseline_model" )
272
+ model .save (dst_path )
273
273
274
274
class TrainDataset (object ):
275
275
def __init__ (self ):
@@ -322,16 +322,17 @@ def __getitem__(self, idx):
322
322
return self .test_images [idx ], self .test_labels [idx ]
323
323
324
324
class TestTensorflowPruning (unittest .TestCase ):
325
+ dst_path = '/tmp/.neural_compressor/inc_ut/resnet_v2/baseline_model'
325
326
@classmethod
326
327
def setUpClass (self ):
327
- build_fake_yaml ()
328
- cmd = 'cp -r /home/tensorflow/inc_ut/resnet_v2/baseline_model ./'
329
- os .popen (cmd ).readlines ()
328
+ build_fake_yaml ()
329
+ if not os .path .exists (self .dst_path ):
330
+ print ("resnet_v2 baseline_model doesn't exist" )
331
+ return unittest .skip ("resnet_v2 baseline_model doesn't exist" )(TestTensorflowPruning )
330
332
331
333
@classmethod
332
334
def tearDownClass (self ):
333
335
os .remove ('fake_yaml.yaml' )
334
- shutil .rmtree ('baseline_model' ,ignore_errors = True )
335
336
shutil .rmtree ('nc_workspace' ,ignore_errors = True )
336
337
337
338
@unittest .skipIf (tensorflow .version .VERSION < '2.3.0' , "Keras model need tensorflow version >= 2.3.0, so the case is skipped" )
@@ -387,7 +388,7 @@ def test_tensorflow_pruning(self):
387
388
prune = Pruning ("./fake_yaml.yaml" )
388
389
prune .train_dataloader = common .DataLoader (TrainDataset (), batch_size = 32 )
389
390
prune .eval_dataloader = common .DataLoader (EvalDataset (), batch_size = 32 )
390
- prune .model = './baseline_model'
391
+ prune .model = self . dst_path
391
392
pruned_model = prune ()
392
393
stats , sparsity = pruned_model .report_sparsity ()
393
394
logger .info (stats )
0 commit comments