Skip to content

Commit e9f3605

Browse files
Fix Linear4bit warnings/test for compute dtype
1 parent 812ef06 commit e9f3605

File tree

2 files changed

+6
-14
lines changed

2 files changed

+6
-14
lines changed

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -455,14 +455,14 @@ def set_compute_type(self, x):
455455
self.compute_dtype = x.dtype
456456
elif x.dtype == torch.float16:
457457
# we take the compoute dtype passed into the layer
458-
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
458+
if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
459459
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
460460
# warn the user about this
461461
warnings.warn(
462462
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.",
463463
)
464464
warnings.filterwarnings("ignore", message=".*inference.")
465-
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
465+
if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
466466
warnings.warn(
467467
"Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.",
468468
)

tests/test_modules.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -440,31 +440,23 @@ def test_4bit_linear_warnings(device):
440440
dim1 = 64
441441

442442
with pytest.warns(UserWarning, match=r"inference or training"):
443-
net = nn.Sequential(
444-
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
445-
)
443+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
446444
net = net.to(device)
447445
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
448446
net(inp)
449447
with pytest.warns(UserWarning, match=r"inference."):
450-
net = nn.Sequential(
451-
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
452-
)
448+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
453449
net = net.to(device)
454450
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
455451
net(inp)
456452

457453
with pytest.warns(UserWarning) as record:
458-
net = nn.Sequential(
459-
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
460-
)
454+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
461455
net = net.to(device)
462456
inp = torch.rand(10, dim1, device=device, dtype=torch.float16)
463457
net(inp)
464458

465-
net = nn.Sequential(
466-
*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4", compute_dtype=torch.float32) for i in range(10)]
467-
)
459+
net = nn.Sequential(*[bnb.nn.Linear4bit(dim1, dim1, quant_type="nf4") for i in range(10)])
468460
net = net.to(device)
469461
inp = torch.rand(1, dim1, device=device, dtype=torch.float16)
470462
net(inp)

0 commit comments

Comments
 (0)