
Commit 695506a

Dont fuse broadcast after conv/gemm in mlir #3863 (#3867)
* Fuse reshapes on pointwise inputs for mlir output fusion (#3569)
* Dont fuse broadcast after conv/gemm in mlir (#3863)
* Layout convolution as NHWC or NCHW only (#3729)
1 parent: 656c594 · commit: 695506a

File tree

4 files changed: +240 −148 lines changed


src/layout_convolution.cpp

Lines changed: 6 additions & 3 deletions

```diff
@@ -40,14 +40,17 @@ inline namespace MIGRAPHX_INLINE_NS {
 namespace {
 std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
 {
+    std::vector<int64_t> perm(ins->get_shape().ndim());
     if(lc.channels_last)
     {
-        std::vector<int64_t> perm(ins->get_shape().ndim());
         std::iota(perm.begin() + 1, perm.end() - 1, 2);
         perm.back() = 1;
-        return perm;
     }
-    return find_permutation(ins->inputs().front()->get_shape());
+    else
+    {
+        std::iota(perm.begin(), perm.end(), 0);
+    }
+    return perm;
 }
 
 std::vector<int64_t> get_default_permutation(instruction_ref ins)
```
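This implements the merged "Layout convolution as NHWC or NCHW only (#3729)": the data-driven `find_permutation` fallback is gone, and the non-channels-last case now pins an identity (NCHW) permutation. A minimal standalone sketch of the resulting behavior (the `get_permutation_sketch` name is hypothetical, not the MIGraphX API):

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Sketch of the new control flow above; assumes ndim >= 3, as for a conv tensor.
std::vector<std::int64_t> get_permutation_sketch(std::size_t ndim, bool channels_last)
{
    std::vector<std::int64_t> perm(ndim); // zero-initialized, so perm[0] = 0 (batch stays first)
    if(channels_last)
    {
        // NHWC: spatial dims shift toward the front, channels move to the back
        std::iota(perm.begin() + 1, perm.end() - 1, 2);
        perm.back() = 1;
    }
    else
    {
        // NCHW: identity permutation
        std::iota(perm.begin(), perm.end(), 0);
    }
    return perm;
}
// e.g. ndim = 4: channels_last -> {0, 2, 3, 1} (NHWC); otherwise -> {0, 1, 2, 3} (NCHW)
```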

src/targets/gpu/fuse_mlir.cpp

Lines changed: 69 additions & 67 deletions

```diff
@@ -205,7 +205,6 @@ const auto& reshaper_names()
 {
     // clang-format off
     static const std::unordered_set<std::string> names = {
-        "slice",
         "transpose",
         "multibroadcast",
         "broadcast",
@@ -220,12 +219,17 @@ const auto& reshaper_names()
     return names;
 }
 
+bool is_fusable_input_op(const std::string& name)
+{
+    return contains(reshaper_names(), name) or contains({"slice"}, name);
+}
+
 std::tuple<instruction_ref, std::vector<operation>>
 get_fusable_input_op_stream(instruction_ref lower_input)
 {
     instruction_ref upper_input = lower_input;
     std::vector<operation> op_stream;
-    while(contains(reshaper_names(), upper_input->name()))
+    while(is_fusable_input_op(upper_input->name()))
     {
         operation op = upper_input->get_operator();
         op_stream.push_back(op);
```
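Net effect of these two hunks: "slice" leaves the generic `reshaper_names()` set (so downstream matchers no longer need to erase it, as the later hunks show), while `get_fusable_input_op_stream` still accepts it through the new predicate. A minimal standalone sketch, with hypothetical `*_sketch` names, an abridged name set, and `std::unordered_set::count` standing in for MIGraphX's `contains()` helper:

```cpp
#include <string>
#include <unordered_set>

const std::unordered_set<std::string>& reshaper_names_sketch()
{
    // "slice" was dropped from this set (abridged here)...
    static const std::unordered_set<std::string> names = {
        "transpose", "multibroadcast", "broadcast"};
    return names;
}

bool is_fusable_input_op_sketch(const std::string& name)
{
    // ...but stays fusable when streaming *input* ops into a fused module.
    return reshaper_names_sketch().count(name) > 0 or name == "slice";
}
```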
```diff
@@ -364,6 +368,18 @@ create_param_map_with_literals(module_ref mm, const module* pm, const shape& sha
     return ins_map;
 }
 
+instruction_ref insert_pointwise(module& m,
+                                 instruction_ref ins,
+                                 const operation& op,
+                                 const std::vector<instruction_ref>& inputs,
+                                 const std::vector<module_ref>& mod_args)
+{
+    // Only used in assert
+    (void)mod_args;
+    assert(mod_args.empty());
+    return insert_common_op(m, ins, op, inputs, {.common_type = false});
+}
+
 instruction_ref unroll_pointwise(module& main_mod,
                                  instruction_ref pos,
                                  const operation& op,
```
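This is the same `insert_pointwise` that was previously a static member of `find_pointwise_mlir` (deleted in the last hunk below). Hoisting it to namespace scope lets `find_mlir_fused_ops` also pass it as the insert callback to `module::fuse` (see the `rins` line in the rewrite below) while keeping the no-common-type insertion behavior.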
```diff
@@ -501,9 +517,7 @@ MIGRAPHX_PRED_MATCHER(mlir_split_reduce, instruction_ref ins)
 {
     if(ins->name() != "split_fused_reduce")
         return false;
-    auto* mod_arg = ins->module_inputs().front();
-    auto supported_reshapes = reshaper_names();
-    supported_reshapes.erase("slice");
+    auto* mod_arg = ins->module_inputs().front();
     std::unordered_set<std::string> builtins = {"@param", "@literal", "@return"};
     for(const auto i : iterator_for(*mod_arg))
     {
```
```diff
@@ -627,12 +641,19 @@ struct find_mlir_fused_ops
 {
     mlir_mode conv_mode = mlir_mode::none;
     mlir_mode dot_mode  = mlir_mode::none;
+
+    static auto make_conv_dot_reshaper_names()
+    {
+        auto names = reshaper_names();
+        names.erase("broadcast");
+        names.erase("multibroadcast");
+        return names;
+    }
+
     auto matcher() const
     {
-        auto reshapes = reshaper_names();
-        // slice is not supported
-        reshapes.erase("slice");
-        auto dot_or_conv = match::skip(match::name(reshapes))(
+        static const auto conv_dot_reshaper_names = make_conv_dot_reshaper_names();
+        auto dot_or_conv = match::skip(match::name(conv_dot_reshaper_names))(
             match::any_of(is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)).bind("gemm_based_op"));
         return mlir_pointwise()(match::any_of[match::inputs()](dot_or_conv.bind("x")));
     }
```
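This hunk carries the headline change (#3863): erasing the two broadcast ops from the skip set means `match::skip` no longer looks through a broadcast sitting between the conv/gemm and its pointwise user, so those patterns stop fusing. A standalone sketch of what `make_conv_dot_reshaper_names()` computes (hypothetical `*_sketch` name; the set is passed by value in place of calling `reshaper_names()`):

```cpp
#include <string>
#include <unordered_set>

std::unordered_set<std::string> make_conv_dot_reshaper_names_sketch(
    std::unordered_set<std::string> names) // pass a copy of reshaper_names()
{
    // A broadcast between the gemm-based op and the pointwise op now
    // blocks this fusion, since the matcher no longer skips over it.
    names.erase("broadcast");
    names.erase("multibroadcast");
    return names;
}
```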
```diff
@@ -650,68 +671,62 @@ struct find_mlir_fused_ops
                return i != x_ins and reaches(gemm_based_op, i);
            }))
             return;
-        auto names = pm->get_parameter_names();
-        std::sort(names.begin(), names.end());
+
+        std::unordered_map<instruction_ref, instruction_ref> map_ins;
         module_ref mm = mpm.create_module("mlir_" + pm->name());
         mm->set_bypass();
-        auto [anchor_op, top_inputs] = fuse_input_ops_and_gemm_based_op(
-            mm, gemm_based_op->inputs(), gemm_based_op->get_operator());
-        std::unordered_map<instruction_ref, instruction_ref> param_map =
-            create_param_map_with_literals(mm, pm, pw_ins->get_shape());
-        auto [upper_input, op_stream] = get_fusable_input_op_stream(x_ins);
-        assert(upper_input == gemm_based_op);
-        auto prev_input = anchor_op;
-        for(const auto& op : reverse(op_stream))
-        {
-            prev_input = mm->add_instruction(op, {prev_input});
-        }
-        assert(prev_input->get_shape().lens() == x_ins->get_shape().lens());
-        param_map[x_ins] = prev_input; // this is to avoid adding parameter for gemm/conv reshaped
-                                       // input to pointwise in new fused module
+        fuse_input_ops(mm, gemm_based_op->inputs(), &map_ins);
+
         bool gemm_has_multi_outs = gemm_based_op->outputs().size() > 1;
-        auto reshaped_gemm = x_ins;
-        std::vector<instruction_ref> reshapes_vec;
-        while(reshaped_gemm != gemm_based_op)
+        std::vector<instruction_ref> inss_to_insert;
+        auto reshape_ins = x_ins;
+        for(; reshape_ins != gemm_based_op; reshape_ins = reshape_ins->inputs().front())
         {
-            reshapes_vec.push_back(reshaped_gemm);
-            gemm_has_multi_outs = gemm_has_multi_outs or reshaped_gemm->outputs().size() > 1;
-            reshaped_gemm = reshaped_gemm->inputs().at(0);
+            inss_to_insert.push_back(reshape_ins);
+            gemm_has_multi_outs |= reshape_ins->outputs().size() > 1;
         }
-        reshapes_vec.push_back(reshaped_gemm);
+        inss_to_insert.push_back(gemm_based_op);
+        std::reverse(inss_to_insert.begin(), inss_to_insert.end());
+        mm->add_instructions(inss_to_insert, &map_ins);
 
-        auto return_vals = mm->fuse(*pm, pw_ins->inputs(), &param_map);
+        fuse_input_ops(mm, pw_ins->inputs(), &map_ins);
+        auto rins = mm->fuse(*pm, pw_ins->inputs(), &map_ins, &insert_pointwise);
         if(gemm_has_multi_outs)
         {
-            return_vals.insert(return_vals.begin(), anchor_op);
+            rins.push_back(map_ins.at(gemm_based_op));
         }
-        mm->add_return(return_vals);
+        mm->add_return(rins);
 
-        std::vector<instruction_ref> inputs;
-        std::copy_if(pw_ins->inputs().begin(),
-                     pw_ins->inputs().end(),
-                     std::back_inserter(inputs),
-                     [&](auto input) { return input != x_ins; });
-        inputs.insert(inputs.end(), top_inputs.begin(), top_inputs.end());
+        auto inputs = find_inputs(map_ins, &mpm.get_module(), mm);
+        auto fused_ins = mpm.get_module().insert_instruction(
+            pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
         if(gemm_has_multi_outs)
         {
-            auto fused_ins = mpm.get_module().insert_instruction(
-                pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
-            mpm.get_module().replace_instruction(
-                pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 1}}), fused_ins);
             auto dot_ins = mpm.get_module().insert_instruction(
-                pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
-            // move all the reshape instructions and original GEMM instruction after the fused op to
-            // avoid generating invalid migraphx program
-            for(const auto& orig_i : reverse(reshapes_vec))
+                pw_ins,
+                migraphx::make_op("get_tuple_elem", {{"index", rins.size() - 1}}),
+                fused_ins);
+
+            // move all the reshape instructions after the fused op to avoid
+            // generating invalid migraphx program since the reshapes can be
+            // used by the replaced dot_ins
+            for(instruction_ref x : inss_to_insert)
             {
-                mpm.get_module().move_instruction(orig_i, pw_ins);
+                if(x == gemm_based_op)
+                    continue;
+                mpm.get_module().move_instruction(x, pw_ins);
             }
+
             mpm.get_module().replace_instruction(gemm_based_op, dot_ins);
+            if(rins.size() == 2)
+            {
+                mpm.get_module().replace_instruction(
+                    pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
+            }
         }
         else
         {
-            mpm.get_module().replace_instruction(
-                pw_ins, mlir_op{gemm_based_op->get_operator()}, mlir_contiguous(mpm, inputs), {mm});
+            mpm.get_module().replace_instruction(pw_ins, fused_ins);
         }
     }
 };
```
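The rewrite replaces the hand-rolled replay of the reshape op stream with a simpler pattern: collect the instruction chain from `x_ins` back to `gemm_based_op`, reverse it into def-before-use order, and clone it into the fused module via `add_instructions` with a shared `map_ins`; the pointwise module is then fused with the file-scope `insert_pointwise` callback. When the gemm/conv has other users, its mapped result is appended last to `rins`, which is why the tuple index for `dot_ins` is `rins.size() - 1`. A generic sketch of the chain-collection pattern (`Node` and `input` are stand-ins for `instruction_ref` and `inputs().front()`):

```cpp
#include <algorithm>
#include <vector>

struct Node
{
    Node* input = nullptr; // single-input chain, like the reshape stream
};

std::vector<Node*> collect_chain(Node* x, Node* anchor)
{
    std::vector<Node*> chain;
    // walk from the pointwise input back to the gemm/conv anchor
    for(Node* n = x; n != anchor; n = n->input)
        chain.push_back(n);
    chain.push_back(anchor);                  // include the anchor itself
    std::reverse(chain.begin(), chain.end()); // def-before-use order
    return chain;
}
```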
```diff
@@ -851,9 +866,8 @@ struct find_mlir_standalone_attention_op
         map_main_to_mattn[fused_reduce] = softmax;
 
         // all preceeding ops should be fusable ops
-        if(not std::all_of(m_gemm1, softmax, [](auto i) {
-               return (is_pointwise_op_supported_by_mlir(i) or
-                       contains(reshaper_names(), i.name()));
+        if(not std::all_of(m_gemm1, softmax, [](const instruction& i) {
+               return (is_pointwise_op_supported_by_mlir(i) or is_fusable_input_op(i.name()));
            }))
             return;
 
```
```diff
@@ -938,18 +952,6 @@ struct find_pointwise_mlir
         return contains(op_names, op_ins->name());
     }
 
-    static instruction_ref insert_pointwise(module& m,
-                                            instruction_ref ins,
-                                            const operation& op,
-                                            const std::vector<instruction_ref>& inputs,
-                                            const std::vector<module_ref>& mod_args)
-    {
-        // Only used in assert
-        (void)mod_args;
-        assert(mod_args.empty());
-        return insert_common_op(m, ins, op, inputs, {.common_type = false});
-    }
-
     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
     {
         auto ins = r.result;
```
