@@ -73,15 +73,12 @@ def test_cuda_kernels_vs_native(self):
 
         for quant_type in test_quant_types:
             qtype = getattr(gguf.GGMLQuantizationType, quant_type)
-            block_size, type_size = gguf.GGML_QUANT_SIZES[qtype]
-
             in_features, out_features = 512, 512
-            total_elements = in_features * out_features
-            n_blocks = total_elements // block_size
-            weight_bytes = n_blocks * type_size
 
             torch.manual_seed(42)
-            weight_data = torch.randint(0, 256, (weight_bytes,), dtype=torch.uint8, device=torch_device)
+            float_weight = torch.randn(out_features, in_features, dtype=torch.float32)
+            quantized_data = gguf.quants.quantize(float_weight.numpy(), qtype)
+            weight_data = torch.from_numpy(quantized_data).to(device=torch_device)
             weight = GGUFParameter(weight_data, quant_type=qtype)
 
             x = torch.randn(test_shape, dtype=compute_dtype, device=torch_device)
@@ -95,9 +92,9 @@ def test_cuda_kernels_vs_native(self):
             output_native = linear.forward_native(x)
             output_cuda = linear.forward_cuda(x)
 
-            # Compare outputs
-            max_diff = torch.abs(output_cuda - output_native).max()
-            assert max_diff < 1e-4, "GGUF CUDA Kernel Output is different from Native Output"
+            assert torch.allclose(output_native, output_cuda, 1e-2), (
+                f"GGUF CUDA Kernel Output is different from Native Output for {quant_type}"
+            )
 
 
 @nightly
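For context, a minimal standalone sketch of the approach the updated test takes: instead of filling the weight buffer with random bytes, real float weights are packed with gguf-py's quantizer (the same `gguf.quants.quantize` call that appears in the diff), so both the native and CUDA paths operate on well-formed quant blocks. The quant type, shapes, and the `dequantize` round-trip check below are illustrative assumptions, not part of the test, and assume a gguf-py version that exposes `gguf.quants.quantize`/`dequantize`. Note also that the `1e-2` passed positionally to `torch.allclose` in the new assertion is its `rtol` argument.

```python
# Illustrative sketch only (not part of the PR): quantize real float weights
# with gguf-py, wrap the packed bytes as a tensor, and sanity-check the
# round trip by dequantizing back to float32.
import gguf
import numpy as np
import torch

qtype = gguf.GGMLQuantizationType.Q4_0  # assumed quant type for this example
float_weight = torch.randn(512, 512, dtype=torch.float32)

packed = gguf.quants.quantize(float_weight.numpy(), qtype)  # uint8 quant blocks
weight_data = torch.from_numpy(packed)  # tensor a GGUFParameter-style wrapper could hold

restored = gguf.quants.dequantize(packed, qtype)  # back to float32
print(packed.shape, packed.dtype, restored.shape)
print(float(np.abs(float_weight.numpy() - restored).max()))  # quantization error (lossy)
```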