     Float8DynamicActivationFloat8SemiSparseWeightConfig,
     Int4WeightOnlyConfig,
     LNLinearSigmoid,
+    RMSNorm,
+    RMSNormLinearActivation,
     SemiSparseWeightConfig,
     ToyLinearModel,
+    TransformerBlock,
     clean_caches,
     create_model_and_input,
     generate_results_csv,
@@ -162,6 +165,61 @@ def test_ln_linear_sigmoid(self):
             torch.all((out >= 0) & (out <= 1))
         )  # Check sigmoid output range
 
+    def test_rms_norm(self):
+        # Test RMSNorm
+        rms_norm = RMSNorm(dim=64)
+        x = torch.randn(16, 64)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+        # Test with different eps
+        rms_norm = RMSNorm(dim=64, eps=1e-5)
+        out = rms_norm(x)
+        self.assertEqual(out.shape, (16, 64))
+
+    def test_rms_norm_linear_activation(self):
+        # Test with default GELU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with ReLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+        self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
+
+        # Test with SiLU activation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        out = model(x)
+        self.assertEqual(out.shape, (16, 32))
+
+        # Test with invalid activation
+        with self.assertRaises(ValueError):
+            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+
+    def test_transformer_block(self):
+        # Test with default parameters
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
+        out = model(x)
+        self.assertEqual(out.shape, (16, 16, 64))
+        self.assertEqual(out.dtype, torch.float32)
+
+        # Test with different parameters
+        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        x = torch.randn(8, 32, 128)
+        out = model(x)
+        self.assertEqual(out.shape, (8, 32, 128))
+
+        # Test with different head dimensions
+        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        x = torch.randn(4, 8, 96)
+        out = model(x)
+        self.assertEqual(out.shape, (4, 8, 96))
+
     def test_create_model_and_input(self):
         m, k, n = 16, 64, 32
         model, input_data = create_model_and_input(
@@ -186,6 +244,63 @@ def test_create_model_and_input(self):
         self.assertIsInstance(model, LNLinearSigmoid)
         self.assertEqual(input_data.shape, (m, k))
 
+        # Test RMSNormLinearActivation
+        model, input_data = create_model_and_input(
+            model_type="rms_norm_linear_activation",
+            m=m,
+            k=k,
+            n=n,
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, RMSNormLinearActivation)
+        self.assertEqual(input_data.shape, (m, k))
+
+        # Test TransformerBlock
+        model, input_data = create_model_and_input(
+            model_type="transformer_block",
+            m=m,
+            k=k,
+            n=n,  # n is not used for transformer_block
+            high_precision_dtype=torch.float32,
+            device="cpu",
+        )
+        self.assertIsInstance(model, TransformerBlock)
+        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+
+    def test_quantization_on_models(self):
+        # Test quantization on RMSNormLinearActivation
+        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
+        x = torch.randn(16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip the quantization test if torchao.quantization.quantize_ is not available
+            try:
+                from torchao.quantization import quantize_
+                quantize_(model, config)  # quantize_ applies the config to the model in place
+                out = model(x)
+                self.assertEqual(out.shape, (16, 32))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize_ not available")
+
+        # Test quantization on TransformerBlock
+        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        x = torch.randn(16, 16, 64)
+
+        # Test with Int8WeightOnlyConfig
+        config = string_to_config(quantization="int8wo", sparsity=None)
+        if config is not None:
+            # Skip the quantization test if torchao.quantization.quantize_ is not available
+            try:
+                from torchao.quantization import quantize_
+                quantize_(model, config)  # quantize_ applies the config to the model in place
+                out = model(x)
+                self.assertEqual(out.shape, (16, 16, 64))
+            except ImportError:
+                print("Skipping quantization test: torchao.quantization.quantize_ not available")
+
     def test_generate_results_csv(self):
         results = [
             BenchmarkResult(