@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
440
440
dim1 = 64
441
441
442
442
with pytest .warns (UserWarning , match = r"inference or training" ):
443
- net = nn .Sequential (
444
- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
445
- )
443
+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
446
444
net = net .to (device )
447
445
inp = torch .rand (10 , dim1 , device = device , dtype = torch .float16 )
448
446
net (inp )
449
447
with pytest .warns (UserWarning , match = r"inference." ):
450
- net = nn .Sequential (
451
- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
452
- )
448
+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
453
449
net = net .to (device )
454
450
inp = torch .rand (1 , dim1 , device = device , dtype = torch .float16 )
455
451
net (inp )
456
452
457
453
with pytest .warns (UserWarning ) as record :
458
- net = nn .Sequential (
459
- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
460
- )
454
+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
461
455
net = net .to (device )
462
456
inp = torch .rand (10 , dim1 , device = device , dtype = torch .float16 )
463
457
net (inp )
464
458
465
- net = nn .Sequential (
466
- * [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" , compute_dtype = torch .float32 ) for i in range (10 )]
467
- )
459
+ net = nn .Sequential (* [bnb .nn .Linear4bit (dim1 , dim1 , quant_type = "nf4" ) for i in range (10 )])
468
460
net = net .to (device )
469
461
inp = torch .rand (1 , dim1 , device = device , dtype = torch .float16 )
470
462
net (inp )
0 commit comments