@@ -1323,13 +1323,22 @@ def calib_func(model):
1323
1323
1324
1324
sq = TorchSmoothQuant (model , example_inputs = example_input , q_func = calib_func )
1325
1325
sq .transform (alpha = 0.5 , folding = False )
1326
- self .assertTrue (isinstance (model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj , SQLinearWrapper ))
1327
- self .assertTrue (
1328
- isinstance (
1329
- model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj .sq_linear .lora_A .default ,
1330
- SQLinearWrapper ,
1331
- )
1332
- ) # Linear in Linear
1326
+ decoder = model .base_model .model .model .decoder
1327
+ if Version (peft .__version__ ) < Version ("0.7.0" ):
1328
+ self .assertTrue (isinstance (decoder .layers [0 ].self_attn .v_proj , SQLinearWrapper ))
1329
+ self .assertTrue (
1330
+ isinstance (
1331
+ decoder .layers [0 ].self_attn .v_proj .sq_linear .lora_A .default ,
1332
+ SQLinearWrapper ,
1333
+ )
1334
+ ) # Linear in Linear
1335
+ else :
1336
+ self .assertTrue (
1337
+ isinstance (
1338
+ decoder .layers [0 ].self_attn .v_proj .lora_A .default ,
1339
+ SQLinearWrapper ,
1340
+ )
1341
+ ) # Linear in Linear
1333
1342
self .assertTrue (
1334
1343
isinstance (model .base_model .model .score .original_module , torch .nn .Linear )
1335
1344
) # Linear that is not called in calibration
@@ -1348,13 +1357,22 @@ def calib_func(model):
1348
1357
# folding=False
1349
1358
sq = TorchSmoothQuant (model , example_inputs = example_input , q_func = calib_func )
1350
1359
sq .transform (alpha = "auto" , folding = False )
1351
- self .assertTrue (isinstance (model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj , SQLinearWrapper ))
1352
- self .assertTrue (
1353
- isinstance (
1354
- model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj .sq_linear .lora_A .default ,
1355
- SQLinearWrapper ,
1356
- )
1357
- ) # Linear in Linear
1360
+ decoder = model .base_model .model .model .decoder
1361
+ if Version (peft .__version__ ) < Version ("0.7.0" ):
1362
+ self .assertTrue (isinstance (decoder .layers [0 ].self_attn .v_proj , SQLinearWrapper ))
1363
+ self .assertTrue (
1364
+ isinstance (
1365
+ decoder .layers [0 ].self_attn .v_proj .sq_linear .lora_A .default ,
1366
+ SQLinearWrapper ,
1367
+ )
1368
+ ) # Linear in Linear
1369
+ else :
1370
+ self .assertTrue (
1371
+ isinstance (
1372
+ decoder .layers [0 ].self_attn .v_proj .lora_A .default ,
1373
+ SQLinearWrapper ,
1374
+ )
1375
+ ) # Linear in Linear
1358
1376
self .assertTrue (
1359
1377
isinstance (model .base_model .model .score .original_module , torch .nn .Linear )
1360
1378
) # Linear that is not called in calibration
@@ -1369,7 +1387,16 @@ def calib_func(model):
1369
1387
1370
1388
sq = TorchSmoothQuant (model , example_inputs = example_input , q_func = calib_func )
1371
1389
sq .transform (alpha = "auto" , folding = True )
1372
- self .assertTrue (isinstance (model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj , torch .nn .Linear ))
1390
+ if Version (peft .__version__ ) < Version ("0.7.0" ):
1391
+ self .assertTrue (
1392
+ isinstance (model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj , torch .nn .Linear )
1393
+ )
1394
+ else :
1395
+ self .assertTrue (
1396
+ isinstance (
1397
+ model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj , peft .tuners .lora .layer .Linear
1398
+ )
1399
+ )
1373
1400
self .assertTrue (
1374
1401
isinstance (model .base_model .model .model .decoder .layers [0 ].self_attn .v_proj .lora_A .default , torch .nn .Linear )
1375
1402
) # Linear in Linear
@@ -1401,13 +1428,21 @@ def calib_func(model):
1401
1428
calib_func = calib_func ,
1402
1429
)
1403
1430
decoder = q_model .model .base_model .model .model .decoder
1404
- self .assertTrue (isinstance (decoder .layers [0 ].self_attn .v_proj , SQLinearWrapper ))
1405
- self .assertTrue (
1406
- isinstance (
1407
- decoder .layers [0 ].self_attn .v_proj .sq_linear .module .lora_A .default ,
1408
- SQLinearWrapper ,
1409
- )
1410
- ) # Linear in Linear
1431
+ if Version (peft .__version__ ) < Version ("0.7.0" ):
1432
+ self .assertTrue (isinstance (decoder .layers [0 ].self_attn .v_proj , SQLinearWrapper ))
1433
+ self .assertTrue (
1434
+ isinstance (
1435
+ decoder .layers [0 ].self_attn .v_proj .sq_linear .lora_A .default ,
1436
+ SQLinearWrapper ,
1437
+ )
1438
+ ) # Linear in Linear
1439
+ else :
1440
+ self .assertTrue (
1441
+ isinstance (
1442
+ decoder .layers [0 ].self_attn .v_proj .lora_A .default ,
1443
+ SQLinearWrapper ,
1444
+ )
1445
+ ) # Linear in Linear
1411
1446
self .assertTrue (
1412
1447
isinstance (q_model .model .base_model .model .score .original_module , torch .nn .Linear )
1413
1448
) # Linear that is not called in calibration
0 commit comments