@@ -81,11 +81,11 @@ def linear_nf4_layer():
81
81
82
82
torch .manual_seed (1 )
83
83
84
- orig_layer = torch .nn .Linear (32 , 64 )
84
+ orig_layer = torch .nn .Linear (64 , 16 )
85
85
orig_layer_state_dict = orig_layer .state_dict ()
86
86
87
87
# Prepare a quantized InvokeLinearNF4 layer.
88
- quantized_layer = InvokeLinearNF4 (input_features = 32 , output_features = 64 )
88
+ quantized_layer = InvokeLinearNF4 (input_features = 64 , output_features = 16 )
89
89
quantized_layer .load_state_dict (orig_layer_state_dict )
90
90
quantized_layer .to ("cuda" )
91
91
@@ -98,7 +98,7 @@ def linear_nf4_layer():
98
98
def test_custom_invoke_linear_nf4_all_weights_on_cuda (linear_nf4_layer : InvokeLinearNF4 ):
99
99
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
100
100
# Run inference on the original layer.
101
- x = torch .randn (1 , 32 ).to ("cuda" )
101
+ x = torch .randn (1 , 64 ).to ("cuda" )
102
102
y_quantized = linear_nf4_layer (x )
103
103
104
104
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
@@ -109,10 +109,13 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi
109
109
assert torch .allclose (y_quantized , y_custom , atol = 1e-5 )
110
110
111
111
112
- def test_custom_invoke_linear_nf4_all_weights_on_cpu (linear_nf4_layer : InvokeLinearNF4 ):
112
+ # We run with two different input dimensions, because the NF4 layer follows a different code path depending on the
113
+ # input dimension, and this has caused issues in the past.
114
+ @pytest .mark .parametrize ("input_dim_0" , [1 , 2 ])
115
+ def test_custom_invoke_linear_nf4_all_weights_on_cpu (linear_nf4_layer : InvokeLinearNF4 , input_dim_0 : int ):
113
116
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
114
117
# Run inference on the original layer.
115
- x = torch .randn (1 , 32 ).to (device = "cuda" )
118
+ x = torch .randn (input_dim_0 , 64 ).to (device = "cuda" )
116
119
y_quantized = linear_nf4_layer (x )
117
120
118
121
# Copy the state dict to the CPU and reload it.
0 commit comments