Skip to content

Commit c09e888

Browse files
authored
remove invalid ut path name (#1099)
1 parent 34c7054 commit c09e888

File tree

2 files changed

+12
-13
lines changed

2 files changed

+12
-13
lines changed

test/pruning/test_tensorflow_distributed_pruning.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ def test_tensorflow_pruning(self):
299299
prune.evaluation_distributed = True
300300
prune.train_dataloader = common.DataLoader(TrainDataset(), batch_size=16)
301301
prune.eval_dataloader = common.DataLoader(EvalDataset(), batch_size=32)
302-
prune.model = './baseline_model'
302+
prune.model = '/tmp/.neural_compressor/inc_ut/resnet_v2/baseline_model'
303303
pruned_model = prune()
304304
stats, sparsity = pruned_model.report_sparsity()
305305
logger.info(stats)
@@ -315,8 +315,6 @@ def test_tensorflow_pruning(self):
315315
with open('fake_ut.py', 'w', encoding="utf-8") as f:
316316
f.write(fake_ut)
317317
build_fake_yaml()
318-
cmd = 'cp -r /home/tensorflow/inc_ut/resnet_v2/baseline_model ./'
319-
os.popen(cmd).readlines()
320318

321319

322320
def build_fake_yaml():
@@ -361,14 +359,14 @@ class TestDistributed(unittest.TestCase):
361359
def setUpClass(cls):
362360
build_fake_ut()
363361
build_fake_yaml()
364-
cmd = 'cp -r /home/tensorflow/inc_ut/resnet_v2/baseline_model ./'
365-
os.popen(cmd).readlines()
362+
if not os.path.exists(r'/tmp/.neural_compressor/inc_ut/resnet_v2/baseline_model'):
363+
print("resnet_v2 baseline_model doesn't exist")
364+
return unittest.skip("resnet_v2 baseline_model doesn't exist")(TestDistributed)
366365

367366
@classmethod
368367
def tearDownClass(cls):
369368
os.remove('fake_ut.py')
370369
os.remove('fake_yaml.yaml')
371-
shutil.rmtree('baseline_model', ignore_errors=True)
372370
shutil.rmtree('nc_workspace', ignore_errors=True)
373371

374372
def test_tf_distributed_pruning(self):

test/pruning/test_tensorflow_pruning.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def resnet_v2(input_shape, depth, num_classes=10):
223223
n = 1
224224
depth = n * 9 + 2
225225

226-
def train():
226+
def train(dst_path):
227227
# Load the CIFAR10 data.
228228
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
229229

@@ -269,7 +269,7 @@ def train():
269269
scores = model.evaluate(x_test, y_test, verbose=1)
270270
print('Test loss:', scores[0])
271271
print('Test accuracy:', scores[1])
272-
model.save("baseline_model")
272+
model.save(dst_path)
273273

274274
class TrainDataset(object):
275275
def __init__(self):
@@ -322,16 +322,17 @@ def __getitem__(self, idx):
322322
return self.test_images[idx], self.test_labels[idx]
323323

324324
class TestTensorflowPruning(unittest.TestCase):
325+
dst_path = '/tmp/.neural_compressor/inc_ut/resnet_v2/baseline_model'
325326
@classmethod
326327
def setUpClass(self):
327-
build_fake_yaml()
328-
cmd = 'cp -r /home/tensorflow/inc_ut/resnet_v2/baseline_model ./'
329-
os.popen(cmd).readlines()
328+
build_fake_yaml()
329+
if not os.path.exists(self.dst_path):
330+
print("resnet_v2 baseline_model doesn't exist")
331+
return unittest.skip("resnet_v2 baseline_model doesn't exist")(TestTensorflowPruning)
330332

331333
@classmethod
332334
def tearDownClass(self):
333335
os.remove('fake_yaml.yaml')
334-
shutil.rmtree('baseline_model',ignore_errors=True)
335336
shutil.rmtree('nc_workspace',ignore_errors=True)
336337

337338
@unittest.skipIf(tensorflow.version.VERSION < '2.3.0', "Keras model need tensorflow version >= 2.3.0, so the case is skipped")
@@ -387,7 +388,7 @@ def test_tensorflow_pruning(self):
387388
prune = Pruning("./fake_yaml.yaml")
388389
prune.train_dataloader = common.DataLoader(TrainDataset(), batch_size=32)
389390
prune.eval_dataloader = common.DataLoader(EvalDataset(), batch_size=32)
390-
prune.model = './baseline_model'
391+
prune.model = self.dst_path
391392
pruned_model = prune()
392393
stats, sparsity = pruned_model.report_sparsity()
393394
logger.info(stats)

0 commit comments

Comments
 (0)