Commit 7214d49

Work around a weird quirk of QuantState.to() and add a unit test to exercise it.
1 parent a83a999 · commit 7214d49

File tree: 2 files changed (+12, -5 lines)


invokeai/backend/model_manager/load/model_cache/torch_module_autocast/custom_invoke_linear_nf4.py

Lines changed: 4 additions & 0 deletions

@@ -37,5 +37,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         weight = cast_to_device(self.weight, x.device)
         self.weight.quant_state = old_quant_state
 
+        # For some reason, the quant_state.to(...) implementation fails to cast the quant_state.code field. We do this
+        # manually here.
+        weight.quant_state.code = cast_to_device(weight.quant_state.code, x.device)
+
         bias = cast_to_device(self.bias, x.device)
         return bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state).to(inp_dtype)
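
Note: the quirk is easy to observe directly. The following is a minimal, standalone sketch (not InvokeAI code): it quantizes a tensor to NF4 with bitsandbytes, calls QuantState.to(...), and inspects whether the `code` lookup table actually moved. Whether `code` is left behind depends on the bitsandbytes version.

    # Minimal sketch of the quirk; assumes a CUDA device and bitsandbytes installed.
    import torch
    from bitsandbytes.functional import quantize_4bit

    weight = torch.randn(16, 64, dtype=torch.float16, device="cuda")
    _, quant_state = quantize_4bit(weight, quant_type="nf4")

    quant_state.to(torch.device("cpu"))
    print(quant_state.absmax.device)  # cpu
    print(quant_state.code.device)    # on affected versions: still cuda:0

    # The workaround from the commit, applied directly: move the code table explicitly.
    quant_state.code = quant_state.code.to("cpu")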

tests/backend/model_manager/load/model_cache/torch_module_autocast/test_autocast_modules.py

Lines changed: 8 additions & 5 deletions

@@ -81,11 +81,11 @@ def linear_nf4_layer():
 
     torch.manual_seed(1)
 
-    orig_layer = torch.nn.Linear(32, 64)
+    orig_layer = torch.nn.Linear(64, 16)
     orig_layer_state_dict = orig_layer.state_dict()
 
     # Prepare a quantized InvokeLinearNF4 layer.
-    quantized_layer = InvokeLinearNF4(input_features=32, output_features=64)
+    quantized_layer = InvokeLinearNF4(input_features=64, output_features=16)
     quantized_layer.load_state_dict(orig_layer_state_dict)
     quantized_layer.to("cuda")
 
@@ -98,7 +98,7 @@ def linear_nf4_layer():
 def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLinearNF4):
     """Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
     # Run inference on the original layer.
-    x = torch.randn(1, 32).to("cuda")
+    x = torch.randn(1, 64).to("cuda")
     y_quantized = linear_nf4_layer(x)
 
     # Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
@@ -109,10 +109,13 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi
     assert torch.allclose(y_quantized, y_custom, atol=1e-5)
 
 
-def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4):
+# We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
+# input dimension, and this has caused issues in the past.
+@pytest.mark.parametrize("input_dim_0", [1, 2])
+def test_custom_invoke_linear_nf4_all_weights_on_cpu(linear_nf4_layer: InvokeLinearNF4, input_dim_0: int):
     """Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
     # Run inference on the original layer.
-    x = torch.randn(1, 32).to(device="cuda")
+    x = torch.randn(input_dim_0, 64).to(device="cuda")
     y_quantized = linear_nf4_layer(x)
 
     # Copy the state dict to the CPU and reload it.
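
Note on the parametrization: bitsandbytes' matmul_4bit can dispatch to a fused path for single-row inputs and to the general path for batched inputs (the exact condition varies by version), which is why the test exercises both input_dim_0 = 1 and input_dim_0 = 2. Below is a minimal, standalone sketch of the same two-shape exercise, using bitsandbytes' stock Linear4bit module in place of InvokeLinearNF4 for illustration.

    # Standalone sketch; assumes a CUDA device and bitsandbytes installed.
    import torch
    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(64, 16, quant_type="nf4").to("cuda")

    # Single-row and multi-row inputs can follow different matmul_4bit code
    # paths, so both shapes are worth exercising.
    for input_dim_0 in (1, 2):
        x = torch.randn(input_dim_0, 64, dtype=torch.float16, device="cuda")
        y = layer(x)
        print(input_dim_0, y.shape)  # torch.Size([1, 16]), then torch.Size([2, 16])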
