Skip to content

Commit d9333aa

Browse files
Improvement for torch.compile support on Params4bit (#1673)
1 parent 11df723 commit d9333aa

File tree

2 files changed

+1
-11
lines changed

2 files changed

+1
-11
lines changed

bitsandbytes/nn/modules.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,6 @@ def from_prequantized(
291291

292292
return self
293293

294-
@classmethod
295-
def __torch_function__(cls, func, types, args=(), kwargs=None):
296-
if kwargs is None:
297-
kwargs = {}
298-
with torch._C.DisableTorchFunctionSubclass():
299-
return func(*args, **kwargs)
300-
301294
def _quantize(self, device):
302295
w = self.data.contiguous().to(device)
303296
w_4bit, quant_state = bnb.functional.quantize_4bit(

tests/test_linear4bit.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,7 @@ def test_params4bit_real_serialization(device, quant_type, blocksize, compress_s
270270
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
271271
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
272272
def test_linear4bit_torch_compile(device, quant_type, compute_dtype, compress_statistics, bias, fullgraph, mode):
273-
if device == "cpu" and quant_type == "fp4":
274-
pytest.skip("FP4 is not supported for CPU")
275-
276-
if fullgraph and torch.__version__ < (2, 8):
273+
if fullgraph and torch.__version__ < (2, 8, 0, "dev"):
277274
pytest.skip("fullgraph mode requires torch 2.8 or higher")
278275

279276
if device == "cuda" and platform.system() == "Windows":

0 commit comments

Comments (0)