@@ -49,7 +49,7 @@ migraphx::instruction_ref add_layout_nhwc(migraphx::module& m, migraphx::instruc
4949
5050migraphx::instruction_ref add_layout_nchw (migraphx::module & m, migraphx::instruction_ref ins)
5151{
52-     return  m.add_instruction (layout (), ins);
52+     return  m.add_instruction (layout ({ 0 ,  1 ,  2 ,  3 } ), ins);
5353}
5454
5555TEST_CASE (auto_conv_nchw)
@@ -90,8 +90,22 @@ TEST_CASE(auto_conv_nhwc)
9090        auto  relu = m1.add_instruction (migraphx::make_op (" relu" 
9191        m1.add_return ({relu});
9292    }
93-     migraphx::module  m2 = m1;
9493    run_pass (m1);
94+     migraphx::module  m2;
95+     {
96+         auto  x          = m2.add_parameter (" x" 1 , 16 , 16 , 8 }});
97+         auto  xtranspose = add_layout_nchw (m2, m2.add_instruction (transpose, x));
98+         auto  w          = m2.add_literal (
99+             migraphx::generate_literal ({migraphx::shape::float_type, {16 , 3 , 3 , 8 }}));
100+         auto  wtranspose = add_layout_nchw (m2, m2.add_instruction (transpose, w));
101+         auto  conv       = m2.add_instruction (
102+             migraphx::make_op (" convolution" 
103+                                     {{" padding" 1 , 1 }}, {" stride" 2 , 2 }}, {" dilation" 1 , 1 }}}),
104+             xtranspose,
105+             wtranspose);
106+         auto  relu = add_layout_nhwc (m2, m2.add_instruction (migraphx::make_op (" relu" 
107+         m2.add_return ({relu});
108+     }
95109    EXPECT (m1.sort () == m2.sort ());
96110}
97111
0 commit comments