
Commit f8a6acc

Fix bitsandbytes imports to avoid ImportErrors on MacOS.
1 parent f8ab414 commit f8a6acc

8 files changed: +121 -84 lines changed
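The common thread across these files is that anything importing bitsandbytes (directly, or through the InvokeLinear8bitLt / InvokeLinearNF4 wrappers) is now imported behind a guard, so the surrounding modules stay importable on MacOS, where bitsandbytes cannot be installed. A minimal, self-contained sketch of the pattern (illustrative flag name, not code from this commit):

QUANTIZED_LAYERS_AVAILABLE = False

try:
    # bitsandbytes is an optional dependency; on MacOS this import is expected to fail.
    import bitsandbytes as bnb  # noqa: F401

    QUANTIZED_LAYERS_AVAILABLE = True
except ImportError:
    # Leave the flag False and carry on; non-quantized code paths keep working.
    pass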
Lines changed: 1 addition & 73 deletions

@@ -1,29 +1,13 @@
-import copy
-from typing import TypeVar
-
-import bitsandbytes as bnb
 import torch

-from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
-from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
-
-T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device

 # This file contains custom torch.nn.Module classes that support streaming of weights to the target device.
 # Each class sub-classes the original module type that is is replacing, so the following properties are preserved:
 # - isinstance(m, torch.nn.OrginalModule) should still work.
 # - Patching the weights (e.g. for LoRA) should still work if non-quantized.


-def cast_to_device(t: T, to_device: torch.device) -> T:
-    if t is None:
-        return t
-
-    if t.device.type != to_device.type:
-        return t.to(to_device)
-    return t
-
-
 class CustomLinear(torch.nn.Linear):
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, input.device)
@@ -64,59 +48,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.scale_grad_by_freq,
             self.sparse,
         )
-
-
-class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        matmul_state = bnb.MatmulLtState()
-        matmul_state.threshold = self.state.threshold
-        matmul_state.has_fp16_weights = self.state.has_fp16_weights
-        matmul_state.use_pool = self.state.use_pool
-        matmul_state.is_training = self.training
-        # The underlying InvokeInt8Params weight must already be quantized.
-        assert self.weight.CB is not None
-        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
-        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually.
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
-        # it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
-        # on the wrong device.
-        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
-
-
-class CustomInvokeLinearNF4(InvokeLinearNF4):
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
-
-        # weights are cast automatically as Int8Params, but the bias has to be cast manually
-        if self.bias is not None and self.bias.dtype != x.dtype:
-            self.bias.data = self.bias.data.to(x.dtype)
-
-        if not self.compute_type_is_set:
-            self.set_compute_type(x)
-            self.compute_type_is_set = True
-
-        inp_dtype = x.dtype
-        if self.compute_dtype is not None:
-            x = x.to(self.compute_dtype)
-
-        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
-
-        # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it
-        # does not follow the tensor semantics of returning a new copy when converting to a different device). This
-        # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To
-        # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing
-        # this properly would require more invasive changes to the bitsandbytes library.
-
-        # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting
-        # to a new device.
-        old_quant_state = copy.copy(self.weight.quant_state)
-        weight = cast_to_device(self.weight, x.device)
-        self.weight.quant_state = old_quant_state
-
-        bias = cast_to_device(self.bias, x.device)
-        return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
+from typing import TypeVar
+
+import torch
+
+T = TypeVar("T", torch.Tensor, None, torch.Tensor | None)
+
+
+def cast_to_device(t: T, to_device: torch.device) -> T:
+    """Helper function to cast an optional tensor to a target device."""
+    if t is None:
+        return t
+
+    if t.device.type != to_device.type:
+        return t.to(to_device)
+    return t
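The file above gives cast_to_device a new, dependency-free home: it passes None through and only calls Tensor.to() when the device type differs. A small usage sketch, assuming the helper is importable under the path used in the diffs below:

import torch

from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device

weight = torch.randn(8, 8)             # CPU-resident parameter
x = torch.randn(2, 8)                  # input tensor, also on the CPU here
w = cast_to_device(weight, x.device)   # same device type, so the tensor is returned unchanged
b = cast_to_device(None, x.device)     # None is passed through untouched
assert w is weight and b is None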
Lines changed: 27 additions & 0 deletions

@@ -0,0 +1,27 @@
+import bitsandbytes as bnb
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+
+
+class CustomInvokeLinear8bitLt(InvokeLinear8bitLt):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        matmul_state = bnb.MatmulLtState()
+        matmul_state.threshold = self.state.threshold
+        matmul_state.has_fp16_weights = self.state.has_fp16_weights
+        matmul_state.use_pool = self.state.use_pool
+        matmul_state.is_training = self.training
+        # The underlying InvokeInt8Params weight must already be quantized.
+        assert self.weight.CB is not None
+        matmul_state.CB = cast_to_device(self.weight.CB, x.device)
+        matmul_state.SCB = cast_to_device(self.weight.SCB, x.device)
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually.
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        # NOTE(ryand): The second parameter should not be needed at all given our expected inference configuration, but
+        # it's dtype field must be accessible, even though it's not used. We pass in self.weight even though it could be
+        # on the wrong device.
+        return bnb.matmul(x, self.weight, bias=cast_to_device(self.bias, x.device), state=matmul_state)
Lines changed: 41 additions & 0 deletions

@@ -0,0 +1,41 @@
+import copy
+
+import bitsandbytes as bnb
+import torch
+
+from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.cast_to_device import cast_to_device
+from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
+
+
+class CustomInvokeLinearNF4(InvokeLinearNF4):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        bnb.nn.modules.fix_4bit_weight_quant_state_from_module(self)
+
+        # weights are cast automatically as Int8Params, but the bias has to be cast manually
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if not self.compute_type_is_set:
+            self.set_compute_type(x)
+            self.compute_type_is_set = True
+
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+
+        # HACK(ryand): Casting self.weight to the device also casts the self.weight.quant_state in-place (i.e. it
+        # does not follow the tensor semantics of returning a new copy when converting to a different device). This
+        # means that quant_state elements that started on the CPU would be left on the GPU, which we don't want. To
+        # avoid this side effect we make a shallow copy of the original quant_state so that we can restore it. Fixing
+        # this properly would require more invasive changes to the bitsandbytes library.
+
+        # Make a shallow copy of the quant_state so that we can undo the in-place modification that occurs when casting
+        # to a new device.
+        old_quant_state = copy.copy(self.weight.quant_state)
+        weight = cast_to_device(self.weight, x.device)
+        self.weight.quant_state = old_quant_state
+
+        bias = cast_to_device(self.bias, x.device)
+        return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
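The copy.copy save-and-restore around cast_to_device in the NF4 forward above is a general trick for undoing an unwanted in-place side effect on shared state. A self-contained illustration with a stand-in object (not bitsandbytes' actual quant_state):

import copy


class Holder:
    """Stand-in for a module whose state an API call mutates in place."""

    def __init__(self) -> None:
        self.state = {"device": "cpu"}


holder = Holder()
saved = copy.copy(holder.state)    # shallow copy of the original state
holder.state["device"] = "cuda"    # simulate the in-place side effect of a device cast
holder.state = saved               # restore the pre-cast state afterwards
assert holder.state["device"] == "cpu"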

invokeai/backend/model_manager/load/model_cache/torch_module_autocast/torch_module_autocast.py

Lines changed: 16 additions & 3 deletions

@@ -5,20 +5,33 @@
     CustomConv2d,
     CustomEmbedding,
     CustomGroupNorm,
-    CustomInvokeLinear8bitLt,
     CustomLinear,
 )
-from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt

 AUTOCAST_MODULE_TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {
     torch.nn.Linear: CustomLinear,
     torch.nn.Conv1d: CustomConv1d,
     torch.nn.Conv2d: CustomConv2d,
     torch.nn.GroupNorm: CustomGroupNorm,
     torch.nn.Embedding: CustomEmbedding,
-    InvokeLinear8bitLt: CustomInvokeLinear8bitLt,
 }

+try:
+    # These dependencies are not expected to be present on MacOS.
+    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
+        CustomInvokeLinear8bitLt,
+    )
+    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
+        CustomInvokeLinearNF4,
+    )
+    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+    from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
+
+    AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinear8bitLt] = CustomInvokeLinear8bitLt
+    AUTOCAST_MODULE_TYPE_MAPPING[InvokeLinearNF4] = CustomInvokeLinearNF4
+except ImportError:
+    pass
+

 def apply_custom_layers_to_model(model: torch.nn.Module):
     def apply_custom_layers(module: torch.nn.Module):
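Because AUTOCAST_MODULE_TYPE_MAPPING is just a dict from module type to replacement type, the quantized entries can be registered conditionally inside the try block above without touching the rest of the file. A self-contained sketch of how such a mapping can be applied, using plain torch modules and a hypothetical helper (the project's apply_custom_layers_to_model may work differently):

import torch


class CustomLinearSketch(torch.nn.Linear):
    """Illustrative replacement type; same behavior, different class."""


TYPE_MAPPING: dict[type[torch.nn.Module], type[torch.nn.Module]] = {torch.nn.Linear: CustomLinearSketch}


def swap_module_classes(model: torch.nn.Module) -> None:
    # Swap __class__ in place so parameters, buffers, and state_dict keys stay untouched.
    for module in model.modules():
        replacement = TYPE_MAPPING.get(type(module))
        if replacement is not None:
            module.__class__ = replacement


model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
swap_module_classes(model)
assert isinstance(model[0], CustomLinearSketch)
assert isinstance(model[0], torch.nn.Linear)  # isinstance checks against the base type still pass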

tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py

Lines changed: 11 additions & 6 deletions

@@ -1,12 +1,17 @@
 import pytest
 import torch

-from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import (
-    CustomInvokeLinear8bitLt,
-    CustomInvokeLinearNF4,
-)
-from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
-from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4
+if not torch.cuda.is_available():
+    pytest.skip("CUDA is not available", allow_module_level=True)
+else:
+    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_8_bit_lt import (
+        CustomInvokeLinear8bitLt,
+    )
+    from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.custom_invoke_linear_nf4 import (
+        CustomInvokeLinearNF4,
+    )
+    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+    from invokeai.backend.quantization.bnb_nf4 import InvokeLinearNF4


 @pytest.fixture
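In the test module above, the guard works at collection time: pytest.skip(..., allow_module_level=True) aborts the whole module before the bitsandbytes-dependent imports run, so machines without CUDA never attempt them. A generic sketch of the same guard in a hypothetical test file:

import pytest
import torch

if not torch.cuda.is_available():
    # Skips this entire test module during collection; nothing below executes.
    pytest.skip("CUDA is not available", allow_module_level=True)

import bitsandbytes as bnb  # only reached on machines where CUDA is available


def test_quantized_layer_smoke():
    assert bnb is not None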

tests/backend/model_manager/load/model_cache/torch_module_autocast/test_torch_module_autocast.py

Lines changed: 6 additions & 1 deletion

@@ -6,9 +6,14 @@
     apply_custom_layers_to_model,
     remove_custom_layers_from_model,
 )
-from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
 from tests.backend.quantization.gguf.test_ggml_tensor import quantize_tensor

+try:
+    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt, quantize_model_llm_int8
+except ImportError:
+    # This is expected to fail on MacOS
+    pass
+
 cuda_and_mps = pytest.mark.parametrize(
     "device",
     [

tests/backend/quantization/test_bnb_llm_int8.py

Lines changed: 4 additions & 1 deletion

@@ -1,7 +1,10 @@
 import pytest
 import torch

-from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+try:
+    from invokeai.backend.quantization.bnb_llm_int8 import InvokeLinear8bitLt
+except ImportError:
+    pass


 def test_invoke_linear_8bit_lt_quantization():
