Skip to content

Commit 4b20cbc

Browse files
authored
Don't use mixed layouts with convolution (#3587) (#3614)
1 parent f7eb605 commit 4b20cbc

File tree

7 files changed

+226
-53
lines changed

7 files changed

+226
-53
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ add_library(migraphx
6666
insert_pad.cpp
6767
instruction.cpp
6868
json.cpp
69-
layout_nhwc.cpp
69+
layout_convolution.cpp
7070
lexing.cpp
7171
load_save.cpp
7272
make_op.cpp

src/include/migraphx/layout_nhwc.hpp renamed to src/include/migraphx/layout_convolution.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2222
* THE SOFTWARE.
2323
*/
24-
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
25-
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
24+
#ifndef MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
25+
#define MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP
2626

2727
#include <string>
2828
#include <migraphx/instruction_ref.hpp>
@@ -34,14 +34,15 @@ inline namespace MIGRAPHX_INLINE_NS {
3434
struct module_pass_manager;
3535

3636
/**
37-
* Transform convolutions to nhwc
37+
* Transform convolutions layout
3838
*/
39-
struct MIGRAPHX_EXPORT layout_nhwc
39+
struct MIGRAPHX_EXPORT layout_convolution
4040
{
41-
std::string name() const { return "layout_nhwc"; }
41+
bool channels_last = false;
42+
std::string name() const { return "layout_convolution"; }
4243
void apply(module_pass_manager& mpm) const;
4344
};
4445

4546
} // namespace MIGRAPHX_INLINE_NS
4647
} // namespace migraphx
47-
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_NHWC_HPP
48+
#endif // MIGRAPHX_GUARD_MIGRAPHX_LAYOUT_CONVOLUTION_HPP

src/layout_nhwc.cpp renamed to src/layout_convolution.cpp

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
2222
* THE SOFTWARE.
2323
*/
24-
#include <migraphx/layout_nhwc.hpp>
24+
#include <migraphx/layout_convolution.hpp>
2525
#include <migraphx/module.hpp>
2626
#include <migraphx/instruction.hpp>
2727
#include <migraphx/iterator_for.hpp>
@@ -32,58 +32,71 @@
3232
#include <migraphx/eliminate_contiguous.hpp>
3333
#include <migraphx/dead_code_elimination.hpp>
3434
#include <migraphx/pass_manager.hpp>
35+
#include <migraphx/stringutils.hpp>
3536

3637
namespace migraphx {
3738
inline namespace MIGRAPHX_INLINE_NS {
3839

39-
template <class Predicate>
40-
std::vector<instruction_ref> find_lasts(const module& m, Predicate pred)
40+
namespace {
41+
std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
4142
{
42-
std::vector<instruction_ref> result;
43-
fix([&](auto self, auto ins) {
44-
if(pred(ins))
45-
{
46-
result.push_back(ins);
47-
return;
48-
}
49-
for(auto input : ins->inputs())
50-
self(input);
51-
})(std::prev(m.end()));
52-
return result;
43+
if(lc.channels_last)
44+
{
45+
std::vector<int64_t> perm(ins->get_shape().ndim());
46+
std::iota(perm.begin() + 1, perm.end() - 1, 2);
47+
perm.back() = 1;
48+
return perm;
49+
}
50+
return find_permutation(ins->inputs().front()->get_shape());
51+
}
52+
53+
bool skip_layout(const shape& s)
54+
{
55+
return s.ndim() == 1 or s.dynamic() or s.type() == shape::tuple_type;
5356
}
5457

5558
void preserve_output_layout(module& m)
5659
{
5760
auto last = std::prev(m.end());
58-
std::vector<instruction_ref> outputs;
5961
if(last->name() == "@return")
60-
outputs = last->inputs();
61-
else
62-
outputs = {last};
63-
64-
for(auto output : outputs)
6562
{
66-
auto permutation = find_permutation(output->get_shape());
67-
auto layout = m.insert_instruction(
68-
std::next(output), make_op("layout", {{"permutation", permutation}}), output);
69-
m.replace_instruction(output, layout);
63+
std::vector<instruction_ref> outputs;
64+
std::transform(last->inputs().begin(),
65+
last->inputs().end(),
66+
std::back_inserter(outputs),
67+
[&](instruction_ref ins) {
68+
if(skip_layout(ins->get_shape()))
69+
return ins;
70+
auto permutation = find_permutation(ins->get_shape());
71+
return m.insert_instruction(
72+
last, make_op("layout", {{"permutation", permutation}}), ins);
73+
});
74+
m.replace_return(outputs);
75+
}
76+
else if(not skip_layout(last->get_shape()))
77+
{
78+
auto permutation = find_permutation(last->get_shape());
79+
m.add_instruction(make_op("layout", {{"permutation", permutation}}), last);
7080
}
7181
}
7282

73-
void transform_convolutions(module& m)
83+
void transform_convolutions(module& m, const layout_convolution& lc)
7484
{
7585
for(auto ins : iterator_for(m))
7686
{
77-
if(ins->name() != "convolution")
87+
if(not contains({"convolution", "quant_convolution"}, ins->name()))
88+
continue;
89+
if(ins->get_shape().dynamic())
7890
continue;
7991
if(ins->get_shape().lens().size() != 4)
8092
continue;
8193
auto v = ins->get_operator().to_value();
8294
if(v.at("group").to<int>() > 1)
8395
continue;
8496
auto args = ins->inputs();
97+
auto perm = get_permutation(ins, lc);
8598
std::transform(args.begin(), args.end(), args.begin(), [&](const auto& i) {
86-
return m.insert_instruction(ins, make_op("layout", {{"permutation", {0, 2, 3, 1}}}), i);
99+
return m.insert_instruction(ins, make_op("layout", {{"permutation", perm}}), i);
87100
});
88101
auto conv = m.insert_instruction(ins, ins->get_operator(), args);
89102
auto c = m.insert_instruction(ins, make_op("contiguous"), conv);
@@ -102,11 +115,12 @@ void remove_layout(module& m)
102115
m.replace_instruction(ins, ins->inputs().front());
103116
}
104117
}
118+
} // namespace
105119

106-
void layout_nhwc::apply(module_pass_manager& mpm) const
120+
void layout_convolution::apply(module_pass_manager& mpm) const
107121
{
108122
preserve_output_layout(mpm.get_module());
109-
transform_convolutions(mpm.get_module());
123+
transform_convolutions(mpm.get_module(), *this);
110124
mpm.run_pass(dead_code_elimination{});
111125
mpm.run_pass(eliminate_contiguous{"contiguous"});
112126
mpm.run_pass(dead_code_elimination{});

src/module.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,6 @@ instruction_ref module::replace_instruction(instruction_ref ins, instruction_ref
355355
{
356356
impl->changed.notify();
357357
assert(has_instruction(ins));
358-
assert(has_instruction(rep));
359358
assert(ins != rep);
360359

361360
if(ins == std::prev(this->end()))
@@ -541,7 +540,6 @@ instruction_ref module::insert_parameter(instruction_ref ins, std::string name,
541540
instruction_ref module::replace_return(std::vector<instruction_ref> args)
542541
{
543542
impl->changed.notify();
544-
assert(std::all_of(args.begin(), args.end(), [&](auto ins) { return has_instruction(ins); }));
545543
auto last = std::prev(this->end());
546544
// If there is no return then add a return
547545
if(last->name() != "@return")

src/targets/cpu/target.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
#include <migraphx/eliminate_identity.hpp>
3434
#include <migraphx/eliminate_pad.hpp>
3535
#include <migraphx/eliminate_convert.hpp>
36-
#include <migraphx/layout_nhwc.hpp>
3736
#include <migraphx/memory_coloring.hpp>
3837
#include <migraphx/propagate_constant.hpp>
3938
#include <migraphx/register_target.hpp>

src/targets/gpu/target.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#include <migraphx/fuse_pointwise_reduce.hpp>
3636
#include <migraphx/inline_module.hpp>
3737
#include <migraphx/insert_pad.hpp>
38-
#include <migraphx/layout_nhwc.hpp>
38+
#include <migraphx/layout_convolution.hpp>
3939
#include <migraphx/memory_coloring.hpp>
4040
#include <migraphx/normalize_ops.hpp>
4141
#include <migraphx/optimize_module.hpp>
@@ -182,7 +182,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
182182
dead_code_elimination{},
183183
rewrite_gelu{options.fast_math},
184184
optimize_module{},
185-
enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), layout_nhwc{}),
185+
layout_convolution{.channels_last = enabled(MIGRAPHX_ENABLE_NHWC{})},
186186
dead_code_elimination{},
187187
prefuse_ops{},
188188
dead_code_elimination{},

0 commit comments

Comments
 (0)