Skip to content

Commit 768b4dd

Browse files
authored
Layout convolution as NHWC or NCHW only (#3729)
1 parent b2cc2fb commit 768b4dd

File tree

2 files changed

+22
-5
lines changed

2 files changed

+22
-5
lines changed

src/layout_convolution.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,17 @@ inline namespace MIGRAPHX_INLINE_NS {
4040
namespace {
4141
std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
4242
{
43+
std::vector<int64_t> perm(ins->get_shape().ndim());
4344
if(lc.channels_last)
4445
{
45-
std::vector<int64_t> perm(ins->get_shape().ndim());
4646
std::iota(perm.begin() + 1, perm.end() - 1, 2);
4747
perm.back() = 1;
48-
return perm;
4948
}
50-
return find_permutation(ins->inputs().front()->get_shape());
49+
else
50+
{
51+
std::iota(perm.begin(), perm.end(), 0);
52+
}
53+
return perm;
5154
}
5255

5356
std::vector<int64_t> get_default_permutation(instruction_ref ins)

test/layout_convolution.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruc
4949

5050
migraphx::instruction_ref add_layout_nchw(migraphx::module& m, migraphx::instruction_ref ins)
5151
{
52-
return m.add_instruction(layout(), ins);
52+
return m.add_instruction(layout({0, 1, 2, 3}), ins);
5353
}
5454

5555
TEST_CASE(auto_conv_nchw)
@@ -90,8 +90,22 @@ TEST_CASE(auto_conv_nhwc)
9090
auto relu = m1.add_instruction(migraphx::make_op("relu"), conv);
9191
m1.add_return({relu});
9292
}
93-
migraphx::module m2 = m1;
9493
run_pass(m1);
94+
migraphx::module m2;
95+
{
96+
auto x = m2.add_parameter("x", {migraphx::shape::float_type, {1, 16, 16, 8}});
97+
auto xtranspose = add_layout_nchw(m2, m2.add_instruction(transpose, x));
98+
auto w = m2.add_literal(
99+
migraphx::generate_literal({migraphx::shape::float_type, {16, 3, 3, 8}}));
100+
auto wtranspose = add_layout_nchw(m2, m2.add_instruction(transpose, w));
101+
auto conv = m2.add_instruction(
102+
migraphx::make_op("convolution",
103+
{{"padding", {1, 1}}, {"stride", {2, 2}}, {"dilation", {1, 1}}}),
104+
xtranspose,
105+
wtranspose);
106+
auto relu = add_layout_nhwc(m2, m2.add_instruction(migraphx::make_op("relu"), conv));
107+
m2.add_return({relu});
108+
}
95109
EXPECT(m1.sort() == m2.sort());
96110
}
97111

0 commit comments

Comments
 (0)