
Commit 5b4dad3

fix mllama ut (#1735)
1 parent 5b72c93 commit 5b4dad3


4 files changed: 46 additions & 80 deletions

mindnlp/transformers/generation/candidate_generator.py

Lines changed: 8 additions & 1 deletion
@@ -412,9 +412,16 @@ def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_en
         model_kwargs[mask_key] = mask[:, :mask_length_diff]
     elif mask_length_diff > 0:
         model_kwargs[mask_key] = ops.cat([mask, ops.ones((mask.shape[0], mask_length_diff), dtype=mask.dtype)], dim=-1)
+    if "cross_attention_mask" in model_kwargs:
+        # Mllama case is special and has another mask for cross attention model
+        cross_mask = model_kwargs["cross_attention_mask"]
+        if mask_length_diff < 0:
+            model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
+        elif mask_length_diff > 0:
+            new_mask = cross_mask[:, -1:, :, :].tile((1, mask_length_diff, 1, 1))
+            model_kwargs["cross_attention_mask"] = ops.cat([cross_mask, new_mask], dim=1)
     return model_kwargs

-
 def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
     """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
     if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:

mindnlp/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -2388,7 +2388,7 @@ def _dola_decoding(
         this_peer_finished = False

         # prepare layers for DoLa decoding
-        final_layer = self.config.num_hidden_layers
+        final_layer = self.config.get_text_config().num_hidden_layers
         # if the model has tied word embeddings, we skip the word embeddings (0-th) layer and start from the 2nd layer,
         # as the early exit from word embeddings will become identity function
         # if the model is really shallow (<=2 layers), we use the 1st layer if it's not the final layer and the 0-th
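
MllamaConfig is a composite config, so num_hidden_layers lives on the nested text config rather than on the top-level object; reading it through get_text_config() keeps DoLa decoding working for both plain text models and multimodal ones. A rough standalone sketch of the lookup difference, using made-up stand-in classes rather than the real config API:

class StubTextConfig:
    def __init__(self, num_hidden_layers):
        self.num_hidden_layers = num_hidden_layers

class StubMultimodalConfig:
    # composite config: the layer count only exists on the nested text config
    def __init__(self, text_config):
        self.text_config = text_config

    def get_text_config(self):
        return self.text_config

config = StubMultimodalConfig(StubTextConfig(num_hidden_layers=40))
# config.num_hidden_layers would raise AttributeError on this stub,
# while the indirection works for composite and text-only configs alike:
final_layer = config.get_text_config().num_hidden_layers
print(final_layer)  # 40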

mindnlp/transformers/models/mllama/modeling_mllama.py

Lines changed: 3 additions & 1 deletion
@@ -883,7 +883,7 @@ class MllamaPreTrainedModel(PreTrainedModel):
     _supports_cache_class = True
     _supports_static_cache = False
     # _supports_sdpa = True
-    # _supports_quantized_cache = True
+    _supports_quantized_cache = False

     def _init_weights(self, module):
         std = self.config.get_text_config().initializer_range
@@ -1515,6 +1515,8 @@ def prepare_inputs_for_generation(


 class MllamaForConditionalGeneration(MllamaPreTrainedModel, GenerationMixin):
+    _supports_quantized_cache = False  # quant cache not supported in encoder-decoder setting
+
     def __init__(self, config: MllamaConfig):
         super().__init__(config)
         self.vocab_size = config.text_config.vocab_size
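
Setting _supports_quantized_cache = False on both the base class and MllamaForConditionalGeneration relies on ordinary class-attribute lookup: any generation code that checks the flag on a Mllama class or instance now sees False and skips the quantized-cache path. A tiny hypothetical illustration of that lookup, not the actual GenerationMixin check:

class Base:
    _supports_quantized_cache = True   # hypothetical default on a base class

class Model(Base):
    _supports_quantized_cache = False  # subclass override takes precedence

def may_use_quantized_cache(model):
    # instance -> class -> MRO lookup; a missing attribute defaults to False
    return getattr(model, "_supports_quantized_cache", False)

print(may_use_quantized_cache(Model()))  # False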

tests/ut/transformers/models/mllama/test_modeling_mllama.py

Lines changed: 34 additions & 77 deletions
@@ -131,14 +131,6 @@ def setUp(self):
         self.model_tester = MllamaText2TextModelTester(self)
         self.config_tester = ConfigTester(self, config_class=MllamaTextConfig, has_text_modality=True)

-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass
-
-    @unittest.skip(reason="Quanto test is borken")
-    def test_generate_with_quant_cache(self):
-        pass
-

 class MllamaVisionText2TextModelTester:
     def __init__(
@@ -201,6 +193,7 @@ def __init__(
         self.image_size = 224
         self.max_num_images = 1
         self.max_image_tiles = 4
+        self.image_length = 904

     def get_config(self):
         return MllamaConfig(
@@ -319,86 +312,50 @@ def test_inputs_embeds_matches_input_ids(self):
         out_embeds = model(inputs_embeds=inputs_embeds, **inputs)[0]
         self.assertTrue(ops.allclose(out_embeds, out_ids))

-    @unittest.skip(reason="Static cache not supported")
-    def test_static_cache_matches_dynamic(self):
-        # TypeError: list indices must be integers or slices, not tuple
-        # TODO: @raushan, please look into this for new cache format
-        pass
-
-    @unittest.skip(reason="Mllama has dynamic control flow which is not yet supported by compile")
-    def test_generate_compile_fullgraph(self):
-        pass
+    def _check_attentions_for_generate(
+        self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
+    ):
+        # Mllama has cross attention layers and those have a different shape than normal attention layers
+        self.assertIsInstance(attentions, tuple)
+        self.assertListEqual(
+            [isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
+        )
+        self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
+        cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
+        for idx, iter_attentions in enumerate(attentions):
+            tgt_len = min_length + idx if not use_cache else 1
+            src_len = min_length + idx
+            expected_shape = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                src_len,
+            )
+            expected_shape_cross = (
+                batch_size * num_beam_groups,
+                config.num_attention_heads,
+                tgt_len,
+                self.model_tester.image_length,
+            )
+            expected_shapes = [
+                expected_shape if layer_idx not in cross_attention_layers else expected_shape_cross
+                for layer_idx in range(len(iter_attentions))
+            ]
+            self.assertListEqual([layer_attention.shape for layer_attention in iter_attentions], expected_shapes)

-    @unittest.skip(reason="The outputs don't match, no idea why")
-    def test_beam_search_low_memory(self):
-        pass

-    @unittest.skip(reason="Mllama is not yet supported by compile")
-    def test_sdpa_can_compile_dynamic(self):
-        # TODO: look into this, AttributeError("'tensor' object has no attribute '__pow__'")
-        # relevant issue: https://github.com/pytorch/pytorch/issues/133166
+    @unittest.skip(reason="The test itself is broken") # TODO @zucchini-nlp
+    def test_generate_with_quant_cache(self):
         pass

     @unittest.skip(reason="The test itself is broken") # TODO @zucchini-nlp
-    def test_generate_with_quant_cache(self):
+    def test_beam_search_low_memory(self):
         pass

     @unittest.skip(reason="AssertionError: Items in the second set but not the first: might be a setting issue")
     def test_model_parallelism(self):
         pass

-    @unittest.skip(reason="Failing test, need to fix")
-    def test_compile_cuda_graph_time(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_torch_compile_fullgraph(self):
-        pass
-
-    @unittest.skip(reason="Device side assert triggered")
-    def test_assisted_decoding_with_num_logits_to_keep(self):
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_sample_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_constrained_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_dola_decoding_sample():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_generate_methods_with_num_logits_to_keep():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_greedy_generate_dict_outputs():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_group_beam_search_generate_dict_output():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_model_parallel_beam_search():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_new_cache_format_2():
-        pass
-
-    @unittest.skip(reason="Failing test, need to fix")
-    def test_sample_generate_dict_output():
-        pass
-

 @require_mindspore
 class MllamaForConditionalGenerationIntegrationTest(unittest.TestCase):
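
The new _check_attentions_for_generate override accounts for Mllama's cross-attention layers, whose attention maps attend over the flattened vision sequence (image_length = 904 in the tester) instead of the growing text sequence. A short sketch of how the expected shapes are assembled per generation step, with made-up sizes for illustration:

# Illustrative sizes only; the real values come from the model tester and config.
batch_size, num_heads = 2, 4
min_length, use_cache = 5, True
cross_attention_layers = [3, 8]   # layer indices that use cross attention
image_length = 904                # flattened vision sequence length
num_layers = 10

for idx in range(3):              # three decoding steps
    tgt_len = 1 if use_cache else min_length + idx
    src_len = min_length + idx
    self_attn_shape = (batch_size, num_heads, tgt_len, src_len)
    cross_attn_shape = (batch_size, num_heads, tgt_len, image_length)
    expected = [
        cross_attn_shape if layer_idx in cross_attention_layers else self_attn_shape
        for layer_idx in range(num_layers)
    ]
    print(idx, expected[0], expected[3])  # self-attention vs cross-attention shape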
