Skip to content

Commit c741b89

Browse files
authored
fix peft issue in ut (#1450)
Signed-off-by: Xin He <xin3.he@intel.com>
1 parent 6c0ac59 commit c741b89

File tree

1 file changed

+57
-22
lines changed

1 file changed

+57
-22
lines changed

test/algorithm/test_smooth_quant.py

Lines changed: 57 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,13 +1323,22 @@ def calib_func(model):
13231323

13241324
sq = TorchSmoothQuant(model, example_inputs=example_input, q_func=calib_func)
13251325
sq.transform(alpha=0.5, folding=False)
1326-
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, SQLinearWrapper))
1327-
self.assertTrue(
1328-
isinstance(
1329-
model.base_model.model.model.decoder.layers[0].self_attn.v_proj.sq_linear.lora_A.default,
1330-
SQLinearWrapper,
1331-
)
1332-
) # Linear in Linear
1326+
decoder = model.base_model.model.model.decoder
1327+
if Version(peft.__version__) < Version("0.7.0"):
1328+
self.assertTrue(isinstance(decoder.layers[0].self_attn.v_proj, SQLinearWrapper))
1329+
self.assertTrue(
1330+
isinstance(
1331+
decoder.layers[0].self_attn.v_proj.sq_linear.lora_A.default,
1332+
SQLinearWrapper,
1333+
)
1334+
) # Linear in Linear
1335+
else:
1336+
self.assertTrue(
1337+
isinstance(
1338+
decoder.layers[0].self_attn.v_proj.lora_A.default,
1339+
SQLinearWrapper,
1340+
)
1341+
) # Linear in Linear
13331342
self.assertTrue(
13341343
isinstance(model.base_model.model.score.original_module, torch.nn.Linear)
13351344
) # Linear that is not called in calibration
@@ -1348,13 +1357,22 @@ def calib_func(model):
13481357
# folding=False
13491358
sq = TorchSmoothQuant(model, example_inputs=example_input, q_func=calib_func)
13501359
sq.transform(alpha="auto", folding=False)
1351-
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, SQLinearWrapper))
1352-
self.assertTrue(
1353-
isinstance(
1354-
model.base_model.model.model.decoder.layers[0].self_attn.v_proj.sq_linear.lora_A.default,
1355-
SQLinearWrapper,
1356-
)
1357-
) # Linear in Linear
1360+
decoder = model.base_model.model.model.decoder
1361+
if Version(peft.__version__) < Version("0.7.0"):
1362+
self.assertTrue(isinstance(decoder.layers[0].self_attn.v_proj, SQLinearWrapper))
1363+
self.assertTrue(
1364+
isinstance(
1365+
decoder.layers[0].self_attn.v_proj.sq_linear.lora_A.default,
1366+
SQLinearWrapper,
1367+
)
1368+
) # Linear in Linear
1369+
else:
1370+
self.assertTrue(
1371+
isinstance(
1372+
decoder.layers[0].self_attn.v_proj.lora_A.default,
1373+
SQLinearWrapper,
1374+
)
1375+
) # Linear in Linear
13581376
self.assertTrue(
13591377
isinstance(model.base_model.model.score.original_module, torch.nn.Linear)
13601378
) # Linear that is not called in calibration
@@ -1369,7 +1387,16 @@ def calib_func(model):
13691387

13701388
sq = TorchSmoothQuant(model, example_inputs=example_input, q_func=calib_func)
13711389
sq.transform(alpha="auto", folding=True)
1372-
self.assertTrue(isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))
1390+
if Version(peft.__version__) < Version("0.7.0"):
1391+
self.assertTrue(
1392+
isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear)
1393+
)
1394+
else:
1395+
self.assertTrue(
1396+
isinstance(
1397+
model.base_model.model.model.decoder.layers[0].self_attn.v_proj, peft.tuners.lora.layer.Linear
1398+
)
1399+
)
13731400
self.assertTrue(
13741401
isinstance(model.base_model.model.model.decoder.layers[0].self_attn.v_proj.lora_A.default, torch.nn.Linear)
13751402
) # Linear in Linear
@@ -1401,13 +1428,21 @@ def calib_func(model):
14011428
calib_func=calib_func,
14021429
)
14031430
decoder = q_model.model.base_model.model.model.decoder
1404-
self.assertTrue(isinstance(decoder.layers[0].self_attn.v_proj, SQLinearWrapper))
1405-
self.assertTrue(
1406-
isinstance(
1407-
decoder.layers[0].self_attn.v_proj.sq_linear.module.lora_A.default,
1408-
SQLinearWrapper,
1409-
)
1410-
) # Linear in Linear
1431+
if Version(peft.__version__) < Version("0.7.0"):
1432+
self.assertTrue(isinstance(decoder.layers[0].self_attn.v_proj, SQLinearWrapper))
1433+
self.assertTrue(
1434+
isinstance(
1435+
decoder.layers[0].self_attn.v_proj.sq_linear.lora_A.default,
1436+
SQLinearWrapper,
1437+
)
1438+
) # Linear in Linear
1439+
else:
1440+
self.assertTrue(
1441+
isinstance(
1442+
decoder.layers[0].self_attn.v_proj.lora_A.default,
1443+
SQLinearWrapper,
1444+
)
1445+
) # Linear in Linear
14111446
self.assertTrue(
14121447
isinstance(q_model.model.base_model.model.score.original_module, torch.nn.Linear)
14131448
) # Linear that is not called in calibration

0 commit comments

Comments
 (0)