@@ -165,26 +165,26 @@ struct grid_sampler
165165    static  instruction_ref concat_on_first_dim (const  onnx_parser::node_info& info,
166166                                               std::vector<instruction_ref> instructions)
167167    {
168-         return  std::accumulate (
169-             std::next ( instructions.begin () ),
170-             instructions.end (),
171-             instructions. front (), 
172-             [&info]( auto & ret,  auto & ins) { 
173-                 return  info. add_instruction ( make_op (" concat" " axis" 0 }}), ret, ins);
174-             });
168+         return  std::accumulate (std::next (instructions. begin ()), 
169+                                 instructions.end ( ),
170+                                 instructions.front (),
171+                                [&info]( auto & ret,  auto & ins) { 
172+                                     return  info. add_instruction ( 
173+                                         make_op (" concat" " axis" 0 }}), ret, ins);
174+                                 });
175175    }
176176
177177    static  instruction_ref concat_on_dim (const  onnx_parser::node_info& info,
178178                                         std::array<instruction_ref, 4 > instructions,
179179                                         int64_t  dim)
180180    {
181-         return  std::accumulate (
182-             std::next ( instructions.begin () ),
183-             instructions.end (),
184-             instructions. front (), 
185-             [&info, &dim]( auto & ret,  auto & ins) { 
186-                 return  info. add_instruction ( make_op (" concat" " axis" 
187-             });
181+         return  std::accumulate (std::next (instructions. begin ()), 
182+                                 instructions.end ( ),
183+                                 instructions.front (),
184+                                [&info, &dim]( auto & ret,  auto & ins) { 
185+                                     return  info. add_instruction ( 
186+                                         make_op (" concat" " axis" 
187+                                 });
188188    }
189189
190190    bool  has_border_padding () const  { return  m_padding == " border" 
@@ -296,26 +296,31 @@ struct linear_sampler : grid_sampler
296296
297297    instruction_ref sample (const  onnx_parser::node_info& info)
298298    {
299-         std::vector<instruction_ref> weight_indices;
300-         std::vector<instruction_ref> xy_indices;
301-         std::vector<instruction_ref> nc_values;
302- 
303-         const  static  auto  nhw_shape = migraphx::shape{migraphx::shape::int64_type, {1 , 3 }};
304-         dfor (m_batch, m_out_height, m_out_width)([&](auto  n, auto  h, auto  w) {
305-             auto  nhw = info.add_literal (migraphx::literal{nhw_shape, {n, h, w}});
306-             weight_indices.push_back (nhw);
307-             for (size_t  c = 0 ; c < m_channel; c++)
308-             {
309-                 xy_indices.push_back (nhw);
310-                 nc_values.push_back (info.add_literal (migraphx::literal{m_nc_shape, {n, c}}));
311-             }
299+         std::vector<float > xy_indices_data;
300+         std::vector<float > weight_indices_data;
301+         std::vector<float > nc_values_data;
302+         dfor (m_batch, m_out_height, m_out_width, m_channel)([&](auto  n, auto  h, auto  w, auto  c) {
303+             xy_indices_data.push_back (n);
304+             xy_indices_data.push_back (h);
305+             xy_indices_data.push_back (w);
306+             weight_indices_data.push_back (n);
307+             weight_indices_data.push_back (h);
308+             weight_indices_data.push_back (w);
309+             nc_values_data.push_back (n);
310+             nc_values_data.push_back (c);
312311        });
313- 
314-         auto  xy_indices_t  = concat_on_first_dim (info, xy_indices);
315-         auto  y0_samples   = info.add_instruction (make_op (" gathernd" xy_indices_t );
316-         auto  x0_samples   = info.add_instruction (make_op (" gathernd" xy_indices_t );
317-         auto  y1_samples   = info.add_instruction (make_op (" gathernd" xy_indices_t );
318-         auto  x1_samples   = info.add_instruction (make_op (" gathernd" xy_indices_t );
312+         size_t  num_indices  = m_batch * m_out_height * m_out_width * m_channel;
313+         auto  xy_indices_t    = info.add_literal (migraphx::literal{
314+             migraphx::shape{migraphx::shape::float_type, {num_indices, 3 }}, xy_indices_data});
315+         auto  weight_index_t  = info.add_literal (migraphx::literal{
316+             migraphx::shape{migraphx::shape::float_type, {num_indices, 3 }}, weight_indices_data});
317+         auto  nc             = info.add_literal (migraphx::literal{
318+             migraphx::shape{migraphx::shape::float_type, {num_indices, 2 }}, nc_values_data});
319+ 
320+         auto  y0_samples = info.add_instruction (make_op (" gathernd" xy_indices_t );
321+         auto  x0_samples = info.add_instruction (make_op (" gathernd" xy_indices_t );
322+         auto  y1_samples = info.add_instruction (make_op (" gathernd" xy_indices_t );
323+         auto  x1_samples = info.add_instruction (make_op (" gathernd" xy_indices_t );
319324
320325        auto  validate_samples = [&](auto & samples, auto & max) {
321326            auto  clip       = info.add_common_op (" clip" 
@@ -338,8 +343,6 @@ struct linear_sampler : grid_sampler
338343        x1_samples = info.add_instruction (
339344            make_op (" reshape" " dims" get_shape ().elements (), 1 }}}), x1_samples);
340345
341-         auto  nc = concat_on_first_dim (info, nc_values);
342- 
343346        auto  make_corner_indices = [&](auto & x, auto & y) {
344347            auto  hw = info.add_instruction (make_op (" concat" " axis" 1 }}), y, x);
345348            return  info.add_instruction (make_op (" concat" " axis" 1 }}), nc, hw);
@@ -356,10 +359,6 @@ struct linear_sampler : grid_sampler
356359            info.add_common_op (" logical_and" 
357360
358361        std::array<instruction_ref, 4 > corner_samples;
359-         auto  weight_index_t  = concat_on_first_dim (info, weight_indices);
360-         weight_index_t       = info.add_instruction (
361-             make_op (" reshape" " dims" size (), 3 }}}), weight_index_t );
362- 
363362        std::transform (corner_indices.begin (),
364363                       corner_indices.end (),
365364                       corner_validations.begin (),
@@ -528,13 +527,13 @@ struct bicubic_sampler : grid_sampler
528527                make_op (" reshape" " dims" get_shape ().elements (), 1 }}}),
529528                corner_weight);
530529        });
531-         auto  weights_t  = std::accumulate (
532-             std::next ( corner_weights.begin () ),
533-             corner_weights.end (),
534-             corner_weights. front (), 
535-             [&info]( auto & acc,  auto & ins) { 
536-                 return  info. add_instruction ( make_op (" concat" " axis" 1 }}), acc, ins);
537-             });
530+         auto  weights_t  = std::accumulate (std::next (corner_weights. begin ()), 
531+                                           corner_weights.end ( ),
532+                                           corner_weights.front (),
533+                                          [&info]( auto & acc,  auto & ins) { 
534+                                               return  info. add_instruction ( 
535+                                                   make_op (" concat" " axis" 1 }}), acc, ins);
536+                                           });
538537        return  info.add_instruction (make_op (" reshape" " dims" weights_t );
539538    }
540539
@@ -664,7 +663,7 @@ struct parse_gridsample : op_parser<parse_gridsample>
664663                          std::vector<instruction_ref> args) const 
665664    {
666665        bool  align_corners = false ;
667-         //  Note: defult  mode can be linear or bilinear depending on the onnx version
666+         //  Note: default  mode can be linear or bilinear depending on the onnx version
668667        std::string mode         = " linear" 
669668        std::string padding_mode = " zeros" 
670669
0 commit comments