21
21
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22
22
* THE SOFTWARE.
23
23
*/
24
- #include < migraphx/layout_nhwc .hpp>
24
+ #include < migraphx/layout_convolution .hpp>
25
25
#include < migraphx/module.hpp>
26
26
#include < migraphx/instruction.hpp>
27
27
#include < migraphx/iterator_for.hpp>
32
32
#include < migraphx/eliminate_contiguous.hpp>
33
33
#include < migraphx/dead_code_elimination.hpp>
34
34
#include < migraphx/pass_manager.hpp>
35
+ #include < migraphx/stringutils.hpp>
35
36
36
37
namespace migraphx {
37
38
inline namespace MIGRAPHX_INLINE_NS {
38
39
39
- template < class Predicate >
40
- std::vector<instruction_ref> find_lasts ( const module& m, Predicate pred )
40
+ namespace {
41
+ std::vector<int64_t > get_permutation (instruction_ref ins, const layout_convolution& lc )
41
42
{
42
- std::vector<instruction_ref> result;
43
- fix ([&](auto self, auto ins) {
44
- if (pred (ins))
45
- {
46
- result.push_back (ins);
47
- return ;
48
- }
49
- for (auto input : ins->inputs ())
50
- self (input);
51
- })(std::prev (m.end ()));
52
- return result;
43
+ if (lc.channels_last )
44
+ {
45
+ std::vector<int64_t > perm (ins->get_shape ().ndim ());
46
+ std::iota (perm.begin () + 1 , perm.end () - 1 , 2 );
47
+ perm.back () = 1 ;
48
+ return perm;
49
+ }
50
+ return find_permutation (ins->inputs ().front ()->get_shape ());
51
+ }
52
+
53
+ bool skip_layout (const shape& s)
54
+ {
55
+ return s.ndim () == 1 or s.dynamic () or s.type () == shape::tuple_type;
53
56
}
54
57
55
58
void preserve_output_layout (module& m)
56
59
{
57
60
auto last = std::prev (m.end ());
58
- std::vector<instruction_ref> outputs;
59
61
if (last->name () == " @return" )
60
- outputs = last->inputs ();
61
- else
62
- outputs = {last};
63
-
64
- for (auto output : outputs)
65
62
{
66
- auto permutation = find_permutation (output->get_shape ());
67
- auto layout = m.insert_instruction (
68
- std::next (output), make_op (" layout" , {{" permutation" , permutation}}), output);
69
- m.replace_instruction (output, layout);
63
+ std::vector<instruction_ref> outputs;
64
+ std::transform (last->inputs ().begin (),
65
+ last->inputs ().end (),
66
+ std::back_inserter (outputs),
67
+ [&](instruction_ref ins) {
68
+ if (skip_layout (ins->get_shape ()))
69
+ return ins;
70
+ auto permutation = find_permutation (ins->get_shape ());
71
+ return m.insert_instruction (
72
+ last, make_op (" layout" , {{" permutation" , permutation}}), ins);
73
+ });
74
+ m.replace_return (outputs);
75
+ }
76
+ else if (not skip_layout (last->get_shape ()))
77
+ {
78
+ auto permutation = find_permutation (last->get_shape ());
79
+ m.add_instruction (make_op (" layout" , {{" permutation" , permutation}}), last);
70
80
}
71
81
}
72
82
73
- void transform_convolutions (module& m)
83
+ void transform_convolutions (module& m, const layout_convolution& lc )
74
84
{
75
85
for (auto ins : iterator_for (m))
76
86
{
77
- if (ins->name () != " convolution" )
87
+ if (not contains ({" convolution" , " quant_convolution" }, ins->name ()))
88
+ continue ;
89
+ if (ins->get_shape ().dynamic ())
78
90
continue ;
79
91
if (ins->get_shape ().lens ().size () != 4 )
80
92
continue ;
81
93
auto v = ins->get_operator ().to_value ();
82
94
if (v.at (" group" ).to <int >() > 1 )
83
95
continue ;
84
96
auto args = ins->inputs ();
97
+ auto perm = get_permutation (ins, lc);
85
98
std::transform (args.begin (), args.end (), args.begin (), [&](const auto & i) {
86
- return m.insert_instruction (ins, make_op (" layout" , {{" permutation" , { 0 , 2 , 3 , 1 } }}), i);
99
+ return m.insert_instruction (ins, make_op (" layout" , {{" permutation" , perm }}), i);
87
100
});
88
101
auto conv = m.insert_instruction (ins, ins->get_operator (), args);
89
102
auto c = m.insert_instruction (ins, make_op (" contiguous" ), conv);
@@ -102,11 +115,12 @@ void remove_layout(module& m)
102
115
m.replace_instruction (ins, ins->inputs ().front ());
103
116
}
104
117
}
118
+ } // namespace
105
119
106
- void layout_nhwc ::apply (module_pass_manager& mpm) const
120
+ void layout_convolution ::apply (module_pass_manager& mpm) const
107
121
{
108
122
preserve_output_layout (mpm.get_module ());
109
- transform_convolutions (mpm.get_module ());
123
+ transform_convolutions (mpm.get_module (), * this );
110
124
mpm.run_pass (dead_code_elimination{});
111
125
mpm.run_pass (eliminate_contiguous{" contiguous" });
112
126
mpm.run_pass (dead_code_elimination{});
0 commit comments