@@ -171,7 +171,7 @@ def test_rms_norm(self):
         x = torch.randn(16, 64)
         out = rms_norm(x)
         self.assertEqual(out.shape, (16, 64))
-
+
         # Test with different eps
         rms_norm = RMSNorm(dim=64, eps=1e-5)
         out = rms_norm(x)
@@ -184,38 +184,50 @@ def test_rms_norm_linear_activation(self):
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertEqual(out.dtype, torch.float32)
-
+
         # Test with ReLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="relu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
         self.assertTrue(torch.all(out >= 0))  # Check ReLU output range
-
+
         # Test with SiLU activation
-        model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu")
+        model = RMSNormLinearActivation(
+            fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="silu"
+        )
         out = model(x)
         self.assertEqual(out.shape, (16, 32))
-
+
         # Test with invalid activation
         with self.assertRaises(ValueError):
-            RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid")
+            RMSNormLinearActivation(
+                fc_dim1=64, fc_dim2=32, dtype=torch.float32, activation="invalid"
+            )

     def test_transformer_block(self):
         # Test with default parameters
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)  # [batch_size, seq_len, hidden_dim]
         out = model(x)
         self.assertEqual(out.shape, (16, 16, 64))
         self.assertEqual(out.dtype, torch.float32)
-
+
         # Test with different parameters
-        model = TransformerBlock(hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=128, num_heads=4, mlp_ratio=2, dtype=torch.float32
+        )
         x = torch.randn(8, 32, 128)
         out = model(x)
         self.assertEqual(out.shape, (8, 32, 128))
-
+
         # Test with different head dimensions
-        model = TransformerBlock(hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=96, num_heads=6, mlp_ratio=3, dtype=torch.float32
+        )
         x = torch.randn(4, 8, 96)
         out = model(x)
         self.assertEqual(out.shape, (4, 8, 96))
@@ -255,7 +267,7 @@ def test_create_model_and_input(self):
         )
         self.assertIsInstance(model, RMSNormLinearActivation)
         self.assertEqual(input_data.shape, (m, k))
-
+
         # Test TransformerBlock
         model, input_data = create_model_and_input(
             model_type="transformer_block",
@@ -266,40 +278,50 @@ def test_create_model_and_input(self):
             device="cpu",
         )
         self.assertIsInstance(model, TransformerBlock)
-        self.assertEqual(input_data.shape, (m, 16, k))  # [batch_size, seq_len, hidden_dim]
+        self.assertEqual(
+            input_data.shape, (m, 16, k)
+        )  # [batch_size, seq_len, hidden_dim]

     def test_quantization_on_models(self):
         # Test quantization on RMSNormLinearActivation
         model = RMSNormLinearActivation(fc_dim1=64, fc_dim2=32, dtype=torch.float32)
         x = torch.randn(16, 64)
-
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 32))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
-
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )
+
         # Test quantization on TransformerBlock
-        model = TransformerBlock(hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32)
+        model = TransformerBlock(
+            hidden_dim=64, num_heads=8, mlp_ratio=4, dtype=torch.float32
+        )
         x = torch.randn(16, 16, 64)
-
+
         # Test with Int8WeightOnlyConfig
         config = string_to_config(quantization="int8wo", sparsity=None)
         if config is not None:
             # Skip quantization test if torchao.quantization.quantize is not available
             try:
                 from torchao.quantization import quantize
+
                 quantized_model = quantize(model, config)
                 out = quantized_model(x)
                 self.assertEqual(out.shape, (16, 16, 64))
             except ImportError:
-                print("Skipping quantization test: torchao.quantization.quantize not available")
+                print(
+                    "Skipping quantization test: torchao.quantization.quantize not available"
+                )

     def test_generate_results_csv(self):
         results = [