@@ -205,7 +205,6 @@ const auto& reshaper_names()
 {
     // clang-format off
    static const std::unordered_set<std::string> names = {
-        "slice",
         "transpose",
         "multibroadcast",
         "broadcast",
@@ -220,12 +219,17 @@ const auto& reshaper_names()
     return names;
 }
 
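+// Reshape-like ops, plus slice, that can be absorbed into the input stream of an MLIR fusion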
+bool is_fusable_input_op(const std::string& name)
+{
+    return contains(reshaper_names(), name) or contains({"slice"}, name);
+}
+
 std::tuple<instruction_ref, std::vector<operation>>
 get_fusable_input_op_stream(instruction_ref lower_input)
 {
     instruction_ref upper_input = lower_input;
     std::vector<operation> op_stream;
-    while(contains(reshaper_names(), upper_input->name()))
+    while(is_fusable_input_op(upper_input->name()))
     {
         operation op = upper_input->get_operator();
         op_stream.push_back(op);
@@ -364,6 +368,18 @@ create_param_map_with_literals(module_ref mm, const module* pm, const shape& sha
     return ins_map;
 }
 
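+// Used as the insertion callback when fusing pointwise modules (see mm->fuse(..., &insert_pointwise) below)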
+instruction_ref insert_pointwise(module& m,
+                                 instruction_ref ins,
+                                 const operation& op,
+                                 const std::vector<instruction_ref>& inputs,
+                                 const std::vector<module_ref>& mod_args)
+{
+    // Only used in assert
+    (void)mod_args;
+    assert(mod_args.empty());
+    return insert_common_op(m, ins, op, inputs, {.common_type = false});
+}
+
 instruction_ref unroll_pointwise(module& main_mod,
                                  instruction_ref pos,
                                  const operation& op,
@@ -501,9 +517,7 @@ MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins)
 {
     if(ins->name() != "split_fused_reduce")
         return false;
-    auto* mod_arg           = ins->module_inputs().front();
-    auto supported_reshapes = reshaper_names();
-    supported_reshapes.erase("slice");
+    auto* mod_arg = ins->module_inputs().front();
     std::unordered_set<std::string> builtins = {"@param", "@literal", "@return"};
     for(const auto i : iterator_for(*mod_arg))
     {
@@ -627,12 +641,19 @@ struct find_mlir_fused_ops
 {
     mlir_mode conv_mode = mlir_mode::none;
     mlir_mode dot_mode = mlir_mode::none;
+
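+    // Reshapers that may sit between the gemm/conv and the pointwise op; broadcasts are excluded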
+    static auto make_conv_dot_reshaper_names()
+    {
+        auto names = reshaper_names();
+        names.erase("broadcast");
+        names.erase("multibroadcast");
+        return names;
+    }
+
     auto matcher() const
     {
-        auto reshapes = reshaper_names();
-        // slice is not supported
-        reshapes.erase("slice");
-        auto dot_or_conv = match::skip(match::name(reshapes))(
+        static const auto conv_dot_reshaper_names = make_conv_dot_reshaper_names();
+        auto dot_or_conv = match::skip(match::name(conv_dot_reshaper_names))(
            match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
         return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }
@@ -650,68 +671,62 @@ struct find_mlir_fused_ops
                return i != x_ins and reaches(gemm_based_op, i);
            }))
             return;
-        auto names = pm->get_parameter_names();
-        std::sort(names.begin(), names.end());
+
+        std::unordered_map<instruction_ref, instruction_ref> map_ins;
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
-        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
-            mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
-        std::unordered_map<instruction_ref, instruction_ref> param_map =
-            create_param_map_with_literals(mm, pm, pw_ins->get_shape());
-        auto [upper_input, op_stream] = get_fusable_input_op_stream(x_ins);
-        assert(upper_input == gemm_based_op);
-        auto prev_input = anchor_op;
-        for(const auto& op : reverse(op_stream))
-        {
-            prev_input = mm->add_instruction(op, {prev_input});
-        }
-        assert(prev_input->get_shape().lens() == x_ins->get_shape().lens());
-        param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped
-                                       // input to pointwise in new fused module
+        fuse_input_ops(mm, gemm_based_op->inputs(), &map_ins);
+
         bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1;
-        auto reshaped_gemm = x_ins;
-        std::vector<instruction_ref> reshapes_vec;
-        while(reshaped_gemm != gemm_based_op)
+        std::vector<instruction_ref> inss_to_insert;
+        auto reshape_ins = x_ins;
+        for(; reshape_ins != gemm_based_op; reshape_ins = reshape_ins->inputs().front())
         {
-            reshapes_vec.push_back(reshaped_gemm);
-            gemm_has_multi_outs = gemm_has_multi_outs or reshaped_gemm->outputs().size() > 1;
-            reshaped_gemm = reshaped_gemm->inputs().at(0);
+            inss_to_insert.push_back(reshape_ins);
+            gemm_has_multi_outs |= reshape_ins->outputs().size() > 1;
         }
-        reshapes_vec.push_back(reshaped_gemm);
+        inss_to_insert.push_back(gemm_based_op);
+        std::reverse(inss_to_insert.begin(), inss_to_insert.end());
+        mm->add_instructions(inss_to_insert, &map_ins);
 
-        auto return_vals = mm->fuse(*pm, pw_ins->inputs(), &param_map);
+        fuse_input_ops(mm, pw_ins->inputs(), &map_ins);
+        auto rins = mm->fuse(*pm, pw_ins->inputs(), &map_ins, &insert_pointwise);
         if(gemm_has_multi_outs)
         {
-            return_vals.insert(return_vals.begin(), anchor_op);
+            rins.push_back(map_ins.at(gemm_based_op));
         }
-        mm->add_return(return_vals);
+        mm->add_return(rins);
 
-        std::vector<instruction_ref> inputs;
-        std::copy_if(pw_ins->inputs().begin(),
-                     pw_ins->inputs().end(),
-                     std::back_inserter(inputs),
-                     [&](auto input) { return input != x_ins; });
-        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
+        auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);
+        auto fused_ins = mpm.get_module().insert_instruction(
+            pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
         if(gemm_has_multi_outs)
         {
-            auto fused_ins = mpm.get_module().insert_instruction(
-                pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
-            mpm.get_module().replace_instruction(
-                pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins);
             auto dot_ins = mpm.get_module().insert_instruction(
-                pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
-            // move all the reshape instructions and original GEMM instruction after the fused op to
-            // avoid generating invalid migraphx program
-            for(const auto& orig_i : reverse(reshapes_vec))
+                pw_ins,
+                migraphx::make_op("get_tuple_elem", {{"index", rins.size() - 1}}),
+                fused_ins);
+
+            // move all the reshape instructions after the fused op to avoid
+            // generating invalid migraphx program since the reshapes can be
+            // used by the replaced dot_ins
+            for(instruction_ref x : inss_to_insert)
             {
-                mpm.get_module().move_instruction(orig_i, pw_ins);
+                if(x == gemm_based_op)
+                    continue;
+                mpm.get_module().move_instruction(x, pw_ins);
             }
+
             mpm.get_module().replace_instruction(gemm_based_op, dot_ins);
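+            // with exactly two returns, the pointwise result is tuple element 0 and the
+            // gemm/conv output is the last element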
+            if(rins.size() == 2)
+            {
+                mpm.get_module().replace_instruction(
+                    pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
+            }
         }
         else
         {
-            mpm.get_module().replace_instruction(
-                pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
+            mpm.get_module().replace_instruction(pw_ins, fused_ins);
         }
     }
 };
@@ -851,9 +866,8 @@ struct find_mlir_standalone_attention_op
         map_main_to_mattn[fused_reduce] = softmax;
 
         // all preceeding ops should be fusable ops
-        if(not std::all_of(m_gemm1, softmax, [](auto i) {
-               return (is_pointwise_op_supported_by_mlir(i) or
-                       contains(reshaper_names(), i.name()));
+        if(not std::all_of(m_gemm1, softmax, [](const instruction& i) {
+               return (is_pointwise_op_supported_by_mlir(i) or is_fusable_input_op(i.name()));
            }))
             return;
 
@@ -938,18 +952,6 @@ struct find_pointwise_mlir
         return contains(op_names, op_ins->name());
     }
 
-    static instruction_ref insert_pointwise(module& m,
-                                            instruction_ref ins,
-                                            const operation& op,
-                                            const std::vector<instruction_ref>& inputs,
-                                            const std::vector<module_ref>& mod_args)
-    {
-        // Only used in assert
-        (void)mod_args;
-        assert(mod_args.empty());
-        return insert_common_op(m, ins, op, inputs, {.common_type = false});
-    }
-
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;