@@ -57,11 +57,55 @@ def tearDown(self):
5757        shutil .rmtree (self .temp_dir )
5858
5959    def  test_get_shapes_for_config (self ):
60+         # Test custom shapes 
6061        shapes  =  get_shapes_for_config (
6162            self .test_config ["model_params" ][0 ]["matrix_shapes" ]
6263        )
6364        self .assertEqual (len (shapes ), 1 )
6465        self .assertEqual (shapes [0 ], ("custom" , [1024 , 1024 , 1024 ]))
66+         
67+         # Test llama shapes 
68+         llama_shapes  =  get_shapes_for_config ([
69+             {"name" : "llama" }
70+         ])
71+         self .assertEqual (len (llama_shapes ), 4 )  # 4 LLaMa shapes 
72+         self .assertTrue (any (name .startswith ("llama_attn.wqkv" ) for  name , _  in  llama_shapes ))
73+         self .assertTrue (any (name .startswith ("llama_attn.w0" ) for  name , _  in  llama_shapes ))
74+         self .assertTrue (any (name .startswith ("llama_ffn.w13" ) for  name , _  in  llama_shapes ))
75+         self .assertTrue (any (name .startswith ("llama_ffn.w2" ) for  name , _  in  llama_shapes ))
76+         
77+         # Test pow2 shapes 
78+         pow2_shapes  =  get_shapes_for_config ([
79+             {"name" : "pow2" , "min_power" : 10 , "max_power" : 12 }
80+         ])
81+         self .assertEqual (len (pow2_shapes ), 3 )  # 3 powers of 2 (10, 11, 12) 
82+         self .assertEqual (pow2_shapes [0 ], ("pow2_0" , [1024 , 1024 , 1024 ]))  # 2^10 
83+         self .assertEqual (pow2_shapes [1 ], ("pow2_1" , [2048 , 2048 , 2048 ]))  # 2^11 
84+         self .assertEqual (pow2_shapes [2 ], ("pow2_2" , [4096 , 4096 , 4096 ]))  # 2^12 
85+         
86+         # Test pow2_extended shapes 
87+         pow2_extended_shapes  =  get_shapes_for_config ([
88+             {"name" : "pow2_extended" , "min_power" : 10 , "max_power" : 11 }
89+         ])
90+         self .assertEqual (len (pow2_extended_shapes ), 4 )  # 2 powers of 2, each with 2 variants 
91+         self .assertEqual (pow2_extended_shapes [0 ], ("pow2_extended_0" , [1024 , 1024 , 1024 ]))  # 2^10 
92+         self .assertEqual (pow2_extended_shapes [1 ], ("pow2_extended_1" , [1536 , 1536 , 1536 ]))  # 2^10 + 2^9 
93+         self .assertEqual (pow2_extended_shapes [2 ], ("pow2_extended_2" , [2048 , 2048 , 2048 ]))  # 2^11 
94+         self .assertEqual (pow2_extended_shapes [3 ], ("pow2_extended_3" , [3072 , 3072 , 3072 ]))  # 2^11 + 2^10 
95+         
96+         # Test sweep shapes (limited to a small range for testing) 
97+         sweep_shapes  =  get_shapes_for_config ([
98+             {"name" : "sweep" , "min_power" : 8 , "max_power" : 9 }
99+         ])
100+         # For min_power=8, max_power=9, we should have 8 shapes (2^3 = 8 combinations) 
101+         self .assertEqual (len (sweep_shapes ), 8 )
102+         # Check that all shapes have the expected format 
103+         for  name , shape  in  sweep_shapes :
104+             self .assertTrue (name .startswith ("sweep_" ))
105+             self .assertEqual (len (shape ), 3 )  # [M, K, N] 
106+             # Check that all dimensions are powers of 2 between 2^8 and 2^9 
107+             for  dim  in  shape :
108+                 self .assertTrue (dim  in  [256 , 512 ])  # 2^8, 2^9 
65109
66110    def  test_get_param_combinations (self ):
67111        model_param  =  self .test_config ["model_params" ][0 ]
0 commit comments