 
 # please check model card for how to generate these models
 
-_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
+# high precision model, used for testing config deprecation warning
+_HIGH_PRECISION_MODEL = "facebook/opt-125m"
+
+_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
+        1,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
+    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 _DEPRECATED_MODEL_INFO = [
         1,
         "Float8DynamicActivationFloat8WeightConfig",
     ),
+    # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
-_SINGLE_LINEAR_MODEL_NAMES = [
+_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+        2,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 
     "Skipping the test in fbcode for now, not sure how to download from transformers",
 )
 class TestLoadAndRunCheckpoint(TestCase):
-    def _test_single_linear_helper(self, model_name):
+    def _test_single_linear_helper(
+        self, model_name, version, config_name, is_deprecated
+    ):
         from huggingface_hub import hf_hub_download
 
         downloaded_model = hf_hub_download(model_name, filename="model.pt")
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
             model = torch.nn.Sequential(
                 torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
             )
-        with open(downloaded_model, "rb") as f:
+
+        with (
+            open(downloaded_model, "rb") as f,
+            warnings.catch_warnings(record=True) as caught_warnings,
+        ):
             model.load_state_dict(torch.load(f), assign=True)
+            if is_deprecated:
+                assert any(
+                    f"Models quantized with version {version} of {config_name}"
+                    in str(w.message)
+                    for w in caught_warnings
+                ), (
+                    f"Didn't get expected warning message for deprecation for model: {model_name}"
+                )
 
         downloaded_example_inputs = hf_hub_download(
             model_name, filename="model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
         output = model(*example_inputs)
         self.assertTrue(torch.equal(output, ref_output))
 
-    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
-    def test_deprecated_single_linear(self, model_name):
-        self._test_single_linear_helper(model_name)
+    @common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
+    def test_deprecated_single_linear(self, model_info):
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=True
+        )
 
-    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
-    def test_single_linear(self, model_name):
+    @common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
+    def test_single_linear(self, model_info):
         """Test that we can load and run the quantized linear checkpoint with saved sample input
         and match the saved output, to make sure there are no BC-breaking changes
         when we make changes to tensor subclass implementations
         """
-        self._test_single_linear_helper(model_name)
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=False
+        )
 
     @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
     def test_deprecated_hf_models(self, model_info):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
                 torch_dtype="bfloat16",
                 device_map="cuda:0",
             )
+            # version mismatch check in config.py
             assert any(
                 "Stored version is not the same as current default version of the config"
                 in str(w.message)
                 for w in caught_warnings
-            ), "Didn't get expected warning message for version mismatch"
+            ), (
+                f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}"
+            )
 
+            # checkpoint deprecation
             assert any(
-                f"Models quantized with version 1 of {config_name}"
+                f"Models quantized with version {version} of {config_name}"
                 in str(w.message)
                 for w in caught_warnings
-            ), "Didn't get expected warning message for deprecation"
+            ), (
+                f"Didn't get expected warning message for deprecation for model {model_name}"
+            )
             assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
             assert (
                 quantized_model.config.quantization_config.quant_type.version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
             return_tensors="pt",
         ).to("cuda")
         generated_ids = quantized_model.generate(
-            **inputs, max_new_tokens=128, temperature=0
+            **inputs,
+            max_new_tokens=128,
         )
 
         downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
+        # make sure we throw a warning for config deprecation
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            _ = AutoModelForCausalLM.from_pretrained(
+                _HIGH_PRECISION_MODEL,
+                torch_dtype="bfloat16",
+                device_map="cuda:0",
+                quantization_config=quantized_model.config.quantization_config,
+            )
+            # config version deprecation in quant_api.py
+            assert any(
+                f"Config Deprecation: version {version} of {config_name}"
+                in str(w.message)
+                for w in caught_warnings
+            ), (
+                f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}"
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
 
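The comment at the top of the constants points to the model cards for how these test checkpoints are generated; the cards remain the authoritative recipe. For orientation only, here is a minimal sketch of what producing one of the single-linear checkpoints could look like, assuming torchao's quantize_ API and the file names the test downloads (the "model_output.pt" name for the reference outputs and the Hub upload step are assumptions and are omitted):

# Hypothetical generation sketch, not the actual script: quantize a single
# bfloat16 linear with a torchao config, then save the state dict, sample
# inputs, and reference outputs that the test later replays and compares.
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, quantize_

model = torch.nn.Sequential(
    torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, Float8DynamicActivationFloat8WeightConfig())

example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),)
ref_output = model(*example_inputs)

torch.save(model.state_dict(), "model.pt")     # loaded with load_state_dict(..., assign=True)
torch.save(example_inputs, "model_inputs.pt")  # replayed by the test
torch.save(ref_output, "model_output.pt")      # compared with torch.equal (assumed file name)

Saving the reference outputs alongside the inputs is what lets test_single_linear assert exact equality with torch.equal rather than a tolerance-based comparison.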
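All three deprecation checks added in this diff rely on the same standard-library pattern: record warnings while the checkpoint or config is loaded, then scan the recorded messages for an expected substring. A self-contained illustration of that pattern (the loader and warning text below are stand-ins, not torchao's actual messages):

import warnings

def load_checkpoint():
    # Stand-in for torch.load / from_pretrained emitting a deprecation warning.
    warnings.warn("Models quantized with version 1 of SomeConfig are deprecated")

with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")  # make sure no warning filter suppresses the message
    load_checkpoint()

assert any(
    "version 1 of SomeConfig" in str(w.message) for w in caught_warnings
), "Didn't get expected deprecation warning"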