@@ -131,14 +131,6 @@ def setUp(self):
        self.model_tester = MllamaText2TextModelTester(self)
        self.config_tester = ConfigTester(self, config_class=MllamaTextConfig, has_text_modality=True)

-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass
-
-    @unittest.skip(reason="Quanto test is borken")
-    def test_generate_with_quant_cache(self):
-        pass
-

class MllamaVisionText2TextModelTester:
    def __init__(
@@ -201,6 +193,7 @@ def __init__(
        self.image_size = 224
        self.max_num_images = 1
        self.max_image_tiles = 4
+        self.image_length = 904

    def get_config(self):
        return MllamaConfig(
@@ -319,86 +312,56 @@ def test_inputs_embeds_matches_input_ids(self):
        out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
        self.assertTrue(ops.allclose(out_embeds, out_ids))

-    @unittest.skip(reason="Static cache not supported")
-    def test_static_cache_matches_dynamic(self):
-        # TypeError: list indices must be integers or slices, not tuple
-        # TODO: @raushan, please look into this for new cache format
-        pass
-
-    @unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
-    def test_generate_compile_fullgraph(self):
-        pass
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Mllama has cross attention layers and those have a different shape than normal attention layers
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
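+        # generate() emits one tuple of per-layer attentions for every decoding step,
+        # scaled by num_beam_groups when group beam search is used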
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+        cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
+        for idx, iter_attentions in enumerate(attentions):
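+            # With use_cache, only the newest token acts as the query (tgt_len == 1),
+            # while the keys still span the whole sequence generated so far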
+            tgt_len = min_length + idx if not use_cache else 1
+            src_len = min_length + idx
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
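+            # Cross-attention layers attend to the image tokens, so their source
+            # length equals the fixed number of image token positions (image_length)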
+            expected_shape_cross = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                self.model_tester.image_length,
+            )
+            expected_shapes = [
+                expected_shape if layer_idx not in cross_attention_layers else expected_shape_cross
+                for layer_idx in range(len(iter_attentions))
+            ]
+            self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)

-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass

-    @unittest.skip(reason="Mllama is not yet supported by compile")
-    def test_sdpa_can_compile_dynamic(self):
-        # TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
-        # relevant issue: https://github.com/pytorch/pytorch/issues/133166
+    @unittest.skip(reason="The test itself is broken")  # TODO @zucchini-nlp
+    def test_generate_with_quant_cache(self):
        pass

    @unittest.skip(reason="The test itself is broken")  # TODO @zucchini-nlp
-    def test_generate_with_quant_cache(self):
+    def test_beam_search_low_memory(self):
        pass

    @unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
    def test_model_parallelism(self):
        pass

-    @unittest.skip(reason="Failing test, need to fix")
-    def test_compile_cuda_graph_time(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_torch_compile_fullgraph(self):
-        pass
-
-    @unittest.skip(reason="Device side assert triggered")
-    def test_assisted_decoding_with_num_logits_to_keep(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_sample_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_constrained_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_dola_decoding_sample():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_generate_methods_with_num_logits_to_keep():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_greedy_generate_dict_outputs():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_group_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_model_parallel_beam_search():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_new_cache_format_2():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_sample_generate_dict_output():
-        pass
-

@require_mindspore
class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):