Skip to content

Commit fcad24a

Browse files
authored
GridSample Linear Sampler Refactor (#4067)
This PR refactors the linear sampler used in the GridSample operator to use less literals. Also adds support and tests for when channels > 1. Note: This PR does not change the nearest and bicubic samplers. Although they are currently not used in any models we support, they should likely be refactored as well sometime in the future.
1 parent 6e202cc commit fcad24a

File tree

8 files changed

+787
-47
lines changed

8 files changed

+787
-47
lines changed

src/onnx/parse_gridsample.cpp

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -165,26 +165,26 @@ struct grid_sampler
165165
static instruction_ref concat_on_first_dim(const onnx_parser::node_info& info,
166166
std::vector<instruction_ref> instructions)
167167
{
168-
return std::accumulate(
169-
std::next(instructions.begin()),
170-
instructions.end(),
171-
instructions.front(),
172-
[&info](auto& ret, auto& ins) {
173-
return info.add_instruction(make_op("concat", {{"axis", 0}}), ret, ins);
174-
});
168+
return std::accumulate(std::next(instructions.begin()),
169+
instructions.end(),
170+
instructions.front(),
171+
[&info](auto& ret, auto& ins) {
172+
return info.add_instruction(
173+
make_op("concat", {{"axis", 0}}), ret, ins);
174+
});
175175
}
176176

177177
static instruction_ref concat_on_dim(const onnx_parser::node_info& info,
178178
std::array<instruction_ref, 4> instructions,
179179
int64_t dim)
180180
{
181-
return std::accumulate(
182-
std::next(instructions.begin()),
183-
instructions.end(),
184-
instructions.front(),
185-
[&info, &dim](auto& ret, auto& ins) {
186-
return info.add_instruction(make_op("concat", {{"axis", dim}}), ret, ins);
187-
});
181+
return std::accumulate(std::next(instructions.begin()),
182+
instructions.end(),
183+
instructions.front(),
184+
[&info, &dim](auto& ret, auto& ins) {
185+
return info.add_instruction(
186+
make_op("concat", {{"axis", dim}}), ret, ins);
187+
});
188188
}
189189

190190
bool has_border_padding() const { return m_padding == "border"; }
@@ -296,26 +296,31 @@ struct linear_sampler : grid_sampler
296296

297297
instruction_ref sample(const onnx_parser::node_info& info)
298298
{
299-
std::vector<instruction_ref> weight_indices;
300-
std::vector<instruction_ref> xy_indices;
301-
std::vector<instruction_ref> nc_values;
302-
303-
const static auto nhw_shape = migraphx::shape{migraphx::shape::int64_type, {1, 3}};
304-
dfor(m_batch, m_out_height, m_out_width)([&](auto n, auto h, auto w) {
305-
auto nhw = info.add_literal(migraphx::literal{nhw_shape, {n, h, w}});
306-
weight_indices.push_back(nhw);
307-
for(size_t c = 0; c < m_channel; c++)
308-
{
309-
xy_indices.push_back(nhw);
310-
nc_values.push_back(info.add_literal(migraphx::literal{m_nc_shape, {n, c}}));
311-
}
299+
std::vector<float> xy_indices_data;
300+
std::vector<float> weight_indices_data;
301+
std::vector<float> nc_values_data;
302+
dfor(m_batch, m_out_height, m_out_width, m_channel)([&](auto n, auto h, auto w, auto c) {
303+
xy_indices_data.push_back(n);
304+
xy_indices_data.push_back(h);
305+
xy_indices_data.push_back(w);
306+
weight_indices_data.push_back(n);
307+
weight_indices_data.push_back(h);
308+
weight_indices_data.push_back(w);
309+
nc_values_data.push_back(n);
310+
nc_values_data.push_back(c);
312311
});
313-
314-
auto xy_indices_t = concat_on_first_dim(info, xy_indices);
315-
auto y0_samples = info.add_instruction(make_op("gathernd"), m_floor_y, xy_indices_t);
316-
auto x0_samples = info.add_instruction(make_op("gathernd"), m_floor_x, xy_indices_t);
317-
auto y1_samples = info.add_instruction(make_op("gathernd"), m_ceil_y, xy_indices_t);
318-
auto x1_samples = info.add_instruction(make_op("gathernd"), m_ceil_x, xy_indices_t);
312+
size_t num_indices = m_batch * m_out_height * m_out_width * m_channel;
313+
auto xy_indices_t = info.add_literal(migraphx::literal{
314+
migraphx::shape{migraphx::shape::float_type, {num_indices, 3}}, xy_indices_data});
315+
auto weight_index_t = info.add_literal(migraphx::literal{
316+
migraphx::shape{migraphx::shape::float_type, {num_indices, 3}}, weight_indices_data});
317+
auto nc = info.add_literal(migraphx::literal{
318+
migraphx::shape{migraphx::shape::float_type, {num_indices, 2}}, nc_values_data});
319+
320+
auto y0_samples = info.add_instruction(make_op("gathernd"), m_floor_y, xy_indices_t);
321+
auto x0_samples = info.add_instruction(make_op("gathernd"), m_floor_x, xy_indices_t);
322+
auto y1_samples = info.add_instruction(make_op("gathernd"), m_ceil_y, xy_indices_t);
323+
auto x1_samples = info.add_instruction(make_op("gathernd"), m_ceil_x, xy_indices_t);
319324

320325
auto validate_samples = [&](auto& samples, auto& max) {
321326
auto clip = info.add_common_op("clip", samples, m_zero_l, max);
@@ -338,8 +343,6 @@ struct linear_sampler : grid_sampler
338343
x1_samples = info.add_instruction(
339344
make_op("reshape", {{"dims", {x1_samples->get_shape().elements(), 1}}}), x1_samples);
340345

341-
auto nc = concat_on_first_dim(info, nc_values);
342-
343346
auto make_corner_indices = [&](auto& x, auto& y) {
344347
auto hw = info.add_instruction(make_op("concat", {{"axis", 1}}), y, x);
345348
return info.add_instruction(make_op("concat", {{"axis", 1}}), nc, hw);
@@ -356,10 +359,6 @@ struct linear_sampler : grid_sampler
356359
info.add_common_op("logical_and", x1_validation, y1_validation)};
357360

358361
std::array<instruction_ref, 4> corner_samples;
359-
auto weight_index_t = concat_on_first_dim(info, weight_indices);
360-
weight_index_t = info.add_instruction(
361-
make_op("reshape", {{"dims", {weight_indices.size(), 3}}}), weight_index_t);
362-
363362
std::transform(corner_indices.begin(),
364363
corner_indices.end(),
365364
corner_validations.begin(),
@@ -528,13 +527,13 @@ struct bicubic_sampler : grid_sampler
528527
make_op("reshape", {{"dims", {corner_weight->get_shape().elements(), 1}}}),
529528
corner_weight);
530529
});
531-
auto weights_t = std::accumulate(
532-
std::next(corner_weights.begin()),
533-
corner_weights.end(),
534-
corner_weights.front(),
535-
[&info](auto& acc, auto& ins) {
536-
return info.add_instruction(make_op("concat", {{"axis", 1}}), acc, ins);
537-
});
530+
auto weights_t = std::accumulate(std::next(corner_weights.begin()),
531+
corner_weights.end(),
532+
corner_weights.front(),
533+
[&info](auto& acc, auto& ins) {
534+
return info.add_instruction(
535+
make_op("concat", {{"axis", 1}}), acc, ins);
536+
});
538537
return info.add_instruction(make_op("reshape", {{"dims", out_lens}}), weights_t);
539538
}
540539

@@ -664,7 +663,7 @@ struct parse_gridsample : op_parser<parse_gridsample>
664663
std::vector<instruction_ref> args) const
665664
{
666665
bool align_corners = false;
667-
// Note: defult mode can be linear or bilinear depending on the onnx version
666+
// Note: default mode can be linear or bilinear depending on the onnx version
668667
std::string mode = "linear";
669668
std::string padding_mode = "zeros";
670669

test/onnx/gen_onnx.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4462,6 +4462,41 @@ def gridsample_test():
44624462

44634463
return ([node], [x, grid], [y])
44644464

4465+
@onnx_test()
4466+
def gridsample_channel_test():
4467+
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 4, 4])
4468+
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
4469+
[1, 6, 6, 2])
4470+
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 6, 6])
4471+
4472+
node = onnx.helper.make_node(
4473+
"GridSample",
4474+
inputs=["x", "grid"],
4475+
outputs=["y"],
4476+
mode="bilinear",
4477+
padding_mode="border",
4478+
align_corners=1,
4479+
)
4480+
4481+
return ([node], [x, grid], [y])
4482+
4483+
@onnx_test()
4484+
def gridsample_512x512_test():
4485+
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 512, 512])
4486+
grid = helper.make_tensor_value_info('grid', TensorProto.FLOAT,
4487+
[1, 512, 512, 2])
4488+
y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [1, 3, 512, 512])
4489+
4490+
node = onnx.helper.make_node(
4491+
"GridSample",
4492+
inputs=["x", "grid"],
4493+
outputs=["y"],
4494+
mode="bilinear",
4495+
padding_mode="border",
4496+
align_corners=1,
4497+
)
4498+
4499+
return ([node], [x, grid], [y])
44654500

44664501
@onnx_test()
44674502
def gridsample_half_test():

test/onnx/gridsample_512x512_test.onnx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
 gridsample_512x512_test:�
2+
^
3+
x
4+
gridy"
5+
GridSample*
6+
align_corners�*
7+
mode"bilinear�*
8+
padding_mode"border�gridsample_512x512_testZ
9+
x
10+

11+

12+

13+
�
14+
�Z
15+
grid
16+

17+

18+
�
19+
�
20+
b
21+
y
22+

23+

24+

25+
�
26+
�B

test/onnx/gridsample_channel_test.onnx

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
 gridsample_channel_test:�
2+
^
3+
x
4+
gridy"
5+
GridSample*
6+
align_corners�*
7+
mode"bilinear�*
8+
padding_mode"border�gridsample_channel_testZ
9+
x
10+

11+

12+

13+

14+
Z
15+
grid
16+

17+

18+

19+

20+
b
21+
y
22+

23+

24+

25+

26+
B

0 commit comments

Comments
 (0)