# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

+import tempfile
import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
    run_tests,
)

+from torchao.float8.config import e4m3_dtype
from torchao.quantization import (
    FbgemmConfig,
    quantize_,
...
    is_sm_at_least_90,
)

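+# Configs are defined at module scope so they can be passed to @parametrize below;
+# they are only constructed on torch 2.8+ (the same condition the tests are
+# skipped on) and are left as None otherwise.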
+if TORCH_VERSION_AT_LEAST_2_8:
+    BF16_ACT_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    BF16_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=torch.bfloat16,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 128],
+        preshuffle=True,
+    )
+
+    FP8_ACT_BMM_CONFIG = FbgemmConfig(
+        input_dtype=e4m3_dtype,
+        weight_dtype=torch.int4,
+        output_dtype=torch.bfloat16,
+        block_size=[1, 1, 128],
+        preshuffle=True,
+    )
+
+else:
+    BF16_ACT_CONFIG = None
+    BF16_ACT_BMM_CONFIG = None
+    FP8_ACT_CONFIG = None
+    FP8_ACT_BMM_CONFIG = None
+

@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_8, "Need pytorch 2.8+")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
@unittest.skipIf(
    not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
)
-class TestInt4GroupwisePreshuffleTensor(TestCase):
+class TestInt4PreshuffledTensor(TestCase):
    def setUp(self):
-        self.config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 128],
-            preshuffle=True,
-        )
-        self.bmm_config = FbgemmConfig(
-            input_dtype=torch.bfloat16,
-            weight_dtype=torch.int4,
-            output_dtype=torch.bfloat16,
-            block_size=[1, 1, 128],
-            preshuffle=True,
-        )
        self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

-    def test_linear(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_linear(self, config):
        dtype = torch.bfloat16
        device = "cuda"
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        original = linear(input)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        quantized = linear(input)
        self.assertTrue(compute_error(original, quantized) > 20)

-    def test_bmm(self):
+    # Note: this order will error out: `Got bad cuda status: an illegal memory access was encountered at line: 449`
+    # @parametrize("bmm_config", [BF16_ACT_BMM_CONFIG, FP8_ACT_BMM_CONFIG])
+    @parametrize("bmm_config", [FP8_ACT_BMM_CONFIG, BF16_ACT_BMM_CONFIG])
+    def test_bmm(self, bmm_config):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
@@ -74,32 +107,46 @@ def forward(self, x):
        m = M(weight).eval()
        original = m(input)
        m.weight = torch.nn.Parameter(m.weight.transpose(1, 2).contiguous())
-        quantize_(m, self.bmm_config, filter_fn=lambda x, fqn: True)
+        quantize_(m, bmm_config, filter_fn=lambda x, fqn: True)
        quantized = m(input)
        self.assertTrue(compute_error(original, quantized) > 18)

-    def test_to_device(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_to_device(self, config):
        for device in self.GPU_DEVICES:
            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device=device)

            linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-            quantize_(linear, self.config)
+            quantize_(linear, config)
            linear.to(device)

-    def test_module_path(self):
+    @parametrize("config", [BF16_ACT_CONFIG, FP8_ACT_CONFIG])
+    def test_module_path(self, config):
        linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
-        quantize_(linear, self.config)
+        quantize_(linear, config)
        self.assertEqual(
            str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4GroupwisePreshuffleTensor'>",
+            "<class 'torchao.quantization.Int4PreshuffledTensor'>",
        )

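+        # Round-trip the state_dict through torch.save/torch.load and check
+        # that the weight is still an Int4PreshuffledTensor after loading.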
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PreshuffledTensor'>",
+            )
+
+
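+# Needed so the @parametrize decorators above actually generate test cases.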
+instantiate_parametrized_tests(TestInt4PreshuffledTensor)
+

if __name__ == "__main__":
    run_tests()