@@ -40,7 +40,7 @@ def linear_8bit_lt_layer():
40
40
def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda (linear_8bit_lt_layer : InvokeLinear8bitLt ):
41
41
"""Test CustomInvokeLinear8bitLt inference with all weights on the GPU."""
42
42
# Run inference on the original layer.
43
- x = torch .randn (10 , 32 ).to ("cuda" )
43
+ x = torch .randn (1 , 32 ).to ("cuda" )
44
44
y_quantized = linear_8bit_lt_layer (x )
45
45
46
46
# Wrap the InvokeLinear8bitLt layer in a CustomInvokeLinear8bitLt layer, and run inference on it.
@@ -54,7 +54,7 @@ def test_custom_invoke_linear_8bit_lt_all_weights_on_cuda(linear_8bit_lt_layer:
54
54
def test_custom_invoke_linear_8bit_lt_all_weights_on_cpu (linear_8bit_lt_layer : InvokeLinear8bitLt ):
55
55
"""Test CustomInvokeLinear8bitLt inference with all weights on the CPU (streaming to the GPU)."""
56
56
# Run inference on the original layer.
57
- x = torch .randn (10 , 32 ).to ("cuda" )
57
+ x = torch .randn (1 , 32 ).to ("cuda" )
58
58
y_quantized = linear_8bit_lt_layer (x )
59
59
60
60
# Copy the state dict to the CPU and reload it.
@@ -98,7 +98,7 @@ def linear_nf4_layer():
98
98
def test_custom_invoke_linear_nf4_all_weights_on_cuda (linear_nf4_layer : InvokeLinearNF4 ):
99
99
"""Test CustomInvokeLinearNF4 inference with all weights on the GPU."""
100
100
# Run inference on the original layer.
101
- x = torch .randn (10 , 32 ).to ("cuda" )
101
+ x = torch .randn (1 , 32 ).to ("cuda" )
102
102
y_quantized = linear_nf4_layer (x )
103
103
104
104
# Wrap the InvokeLinearNF4 layer in a CustomInvokeLinearNF4 layer, and run inference on it.
@@ -112,7 +112,7 @@ def test_custom_invoke_linear_nf4_all_weights_on_cuda(linear_nf4_layer: InvokeLi
112
112
def test_custom_invoke_linear_nf4_all_weights_on_cpu (linear_nf4_layer : InvokeLinearNF4 ):
113
113
"""Test CustomInvokeLinearNF4 inference with all weights on the CPU (streaming to the GPU)."""
114
114
# Run inference on the original layer.
115
- x = torch .randn (10 , 32 ).to (device = "cuda" )
115
+ x = torch .randn (1 , 32 ).to (device = "cuda" )
116
116
y_quantized = linear_nf4_layer (x )
117
117
118
118
# Copy the state dict to the CPU and reload it.
0 commit comments