Skip to content

Commit 4ec35e5

Browse files
authored
Add GatherND operator (#1089)
Add ref and gpu implementations for ONNX op GatherND Resolves #1032
1 parent 4c72cc9 commit 4ec35e5

17 files changed

+671
-3
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ register_migraphx_ops(
109109
flatten
110110
floor
111111
gather
112+
gathernd
112113
get_tuple_elem
113114
greater
114115
gru

src/include/migraphx/op/gathernd.hpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#ifndef MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
2+
#define MIGRAPHX_GUARD_OPERATORS_GATHERND_HPP
3+
4+
#include <migraphx/check_shapes.hpp>
5+
#include <migraphx/shape_for_each.hpp>
6+
#include <migraphx/par_for.hpp>
7+
8+
namespace migraphx {
9+
inline namespace MIGRAPHX_INLINE_NS {
10+
namespace op {
11+
12+
struct gathernd
13+
{
14+
int batch_dims = 0;
15+
16+
template <class Self, class F>
17+
static auto reflect(Self& self, F f)
18+
{
19+
return pack(f(self.batch_dims, "batch_dims"));
20+
}
21+
22+
std::string name() const { return "gathernd"; }
23+
24+
shape compute_shape(std::vector<shape> inputs) const
25+
{
26+
check_shapes{inputs, *this}.has(2);
27+
auto r = inputs.front().lens().size();
28+
auto q = inputs.back().lens().size();
29+
auto k = inputs.back().lens().back();
30+
if(k > r - batch_dims)
31+
{
32+
MIGRAPHX_THROW("GATHERND: Indices of length " + std::to_string(k) +
33+
" cannot be used to access data of rank " +
34+
std::to_string(r - batch_dims));
35+
}
36+
auto indices_lens_iter = inputs.back().lens().begin();
37+
auto output_lens_size = q + r - k - batch_dims - 1;
38+
std::vector<std::size_t> output_lens(output_lens_size);
39+
std::copy(indices_lens_iter, indices_lens_iter + (q - 1), output_lens.begin());
40+
if(k < r - batch_dims)
41+
{
42+
auto data_lens = inputs.front().lens();
43+
std::copy(
44+
data_lens.begin() + batch_dims + k, data_lens.end(), output_lens.begin() + q - 1);
45+
}
46+
shape output_shape{inputs.front().type(), output_lens};
47+
return output_shape;
48+
}
49+
50+
argument compute(const shape& output_shape, std::vector<argument> args) const
51+
{
52+
argument result{output_shape};
53+
visit_all(result, args[0])([&](auto output, auto data) {
54+
args[1].visit([&](auto indices) {
55+
auto indices_shape = indices.get_shape();
56+
auto indices_shape_lens = indices_shape.lens();
57+
auto data_shape = data.get_shape();
58+
auto data_shape_lens = data_shape.lens();
59+
auto k = indices_shape.lens().back();
60+
const auto num_slice_dims = k;
61+
std::size_t num_slices = std::accumulate(indices_shape_lens.begin(),
62+
indices_shape_lens.end() - 1,
63+
1,
64+
std::multiplies<std::size_t>());
65+
std::size_t slice_size = std::accumulate(data_shape_lens.begin() + k + batch_dims,
66+
data_shape_lens.end(),
67+
1,
68+
std::multiplies<std::size_t>());
69+
std::size_t num_batches = std::accumulate(data_shape_lens.begin(),
70+
data_shape_lens.begin() + batch_dims,
71+
1,
72+
std::multiplies<std::size_t>());
73+
std::size_t data_batch_stride =
74+
std::accumulate(data_shape_lens.begin() + batch_dims,
75+
data_shape_lens.end(),
76+
1,
77+
std::multiplies<std::size_t>());
78+
auto num_slices_per_batch = num_slices / num_batches;
79+
80+
std::vector<std::size_t> sizes_from_slice_dims(num_slice_dims);
81+
{
82+
auto running_product = slice_size;
83+
for(std::size_t i = 0; i < num_slice_dims; ++i)
84+
{
85+
sizes_from_slice_dims[num_slice_dims - 1 - i] = running_product;
86+
running_product *= data_shape_lens[batch_dims + num_slice_dims - 1 - i];
87+
}
88+
}
89+
90+
std::vector<std::size_t> input_slice_offsets(num_slices);
91+
par_for(num_slices, [&](const auto i) {
92+
std::size_t batch_idx = i / num_slices_per_batch;
93+
94+
auto slice_indices = indices.begin() + (i * num_slice_dims);
95+
std::size_t relative_slice_offset = 0;
96+
for(size_t dim_idx = 0; dim_idx < num_slice_dims; ++dim_idx)
97+
{
98+
int64_t index = *(slice_indices + dim_idx);
99+
const std::size_t input_dim_idx = batch_dims + dim_idx;
100+
const auto input_dim = data_shape_lens[input_dim_idx];
101+
if(index < -static_cast<int64_t>(input_dim) or
102+
index >= static_cast<int64_t>(input_dim))
103+
MIGRAPHX_THROW("GatherND: index " + std::to_string(index) +
104+
" is out of bounds for dim of len " +
105+
std::to_string(input_dim));
106+
if(index < 0)
107+
index += input_dim;
108+
109+
relative_slice_offset += index * sizes_from_slice_dims[dim_idx];
110+
}
111+
112+
input_slice_offsets[i] =
113+
(batch_idx * data_batch_stride) + relative_slice_offset;
114+
});
115+
116+
par_for(num_slices * slice_size, [&](const auto i) {
117+
auto slice_offset = input_slice_offsets[i / slice_size];
118+
output[i] = data[slice_offset + i % slice_size];
119+
});
120+
});
121+
});
122+
123+
return result;
124+
}
125+
};
126+
127+
} // namespace op
128+
} // namespace MIGRAPHX_INLINE_NS
129+
} // namespace migraphx
130+
131+
#endif

src/include/migraphx/operators.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <migraphx/op/flatten.hpp>
3636
#include <migraphx/op/floor.hpp>
3737
#include <migraphx/op/gather.hpp>
38+
#include <migraphx/op/gathernd.hpp>
3839
#include <migraphx/op/get_tuple_elem.hpp>
3940
#include <migraphx/op/greater.hpp>
4041
#include <migraphx/op/gru.hpp>

src/onnx/parse_generic_op.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ struct parse_generic_op : op_parser<parse_generic_op>
2828
{"Flatten", "flatten"},
2929
{"Floor", "floor"},
3030
{"Gather", "gather"},
31+
{"GatherND", "gathernd"},
3132
{"Identity", "identity"},
3233
{"IsNaN", "isnan"},
3334
{"LeakyRelu", "leaky_relu"},

src/targets/gpu/jit/gathernd.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#include <migraphx/gpu/compiler.hpp>
2+
#include <migraphx/make_op.hpp>
3+
#include <migraphx/gpu/context.hpp>
4+
5+
#include <migraphx/gpu/compile_hip_code_object.hpp>
6+
#include <migraphx/gpu/compile_hip.hpp>
7+
#include <migraphx/ranges.hpp>
8+
#include <migraphx/reduce_dims.hpp>
9+
#include <migraphx/stringutils.hpp>
10+
#include <migraphx/dead_code_elimination.hpp>
11+
#include <migraphx/eliminate_common_subexpression.hpp>
12+
#include <migraphx/module.hpp>
13+
#include <migraphx/pass_manager.hpp>
14+
15+
namespace migraphx {
16+
inline namespace MIGRAPHX_INLINE_NS {
17+
namespace gpu {
18+
19+
// NOLINTNEXTLINE
20+
static const char* const gathernd_kernel = R"__migraphx__(
21+
#include <migraphx/kernels/gathernd.hpp>
22+
#include <migraphx/kernels/basic_ops.hpp>
23+
#include <migraphx/kernels/integral_constant.hpp>
24+
#include <migraphx/kernels/generic_constant.hpp>
25+
#include <args.hpp>
26+
27+
namespace migraphx {
28+
29+
extern "C" {
30+
31+
__global__ void gathernd_kernel(void* in_data, void* in_indices, void* output)
32+
{
33+
make_tensors()(in_data, in_indices, output)([](auto&&... xs) {
34+
auto settings = make_gathernd_settings(MIGRAPHX_MAKE_CONSTANT(int64_t{BATCH_DIMS}));
35+
gathernd(xs..., settings);
36+
});
37+
}
38+
39+
}
40+
41+
} // namespace migraphx
42+
43+
)__migraphx__";
44+
45+
struct gathernd_compiler : compiler<gathernd_compiler>
46+
{
47+
std::vector<std::string> names() const { return {"gathernd"}; }
48+
49+
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
50+
{
51+
hip_compile_options options;
52+
auto out_s = inputs.back();
53+
options.set_launch_params(v, compute_global_for(ctx, out_s.elements()));
54+
options.inputs = inputs;
55+
options.output = out_s;
56+
options.kernel_name = "gathernd_kernel";
57+
options.virtual_inputs = inputs;
58+
59+
// batch_dims
60+
assert(v.contains("batch_dims"));
61+
auto batch_dims = v.at("batch_dims").to<int64_t>();
62+
options.params += " -DBATCH_DIMS=" + std::to_string(batch_dims);
63+
64+
return compile_hip_code_object(gathernd_kernel, options);
65+
}
66+
67+
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
68+
{
69+
return replace(compile_op(ctx, to_shapes(ins->inputs()), op.to_value()));
70+
}
71+
};
72+
73+
} // namespace gpu
74+
} // namespace MIGRAPHX_INLINE_NS
75+
} // namespace migraphx

src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@ struct greater
2121
}
2222
};
2323

24+
template <class InputIt, class T, class BinaryOperation>
25+
constexpr T accumulate(InputIt first, InputIt last, T init, BinaryOperation op)
26+
{
27+
for(; first != last; ++first)
28+
{
29+
init = op(std::move(init), *first);
30+
}
31+
return init;
32+
}
33+
2434
template <class InputIt, class OutputIt>
2535
constexpr OutputIt copy(InputIt first, InputIt last, OutputIt d_first)
2636
{
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
#ifndef MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
2+
#define MIGRAPHX_GUARD_KERNELS_GATHERND_HPP
3+
4+
#include <migraphx/kernels/index.hpp>
5+
#include <migraphx/kernels/algorithm.hpp>
6+
7+
namespace migraphx {
8+
9+
template <class T>
10+
struct gathernd_settings
11+
{
12+
T batch_dims{};
13+
};
14+
15+
template <class... Ts>
16+
constexpr gathernd_settings<Ts...> make_gathernd_settings(Ts... xs)
17+
{
18+
return {xs...};
19+
}
20+
21+
template <class T, class U, class V, class Settings>
22+
__device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t, Settings s)
23+
{
24+
auto ind = make_index();
25+
auto batch_dims = s.batch_dims;
26+
auto output_shape = output_t.get_shape();
27+
auto indices_shape = indices_t.get_shape();
28+
auto data_shape = data_t.get_shape();
29+
30+
auto indices_shape_lens = indices_shape.lens;
31+
auto data_shape_lens = data_shape.lens;
32+
auto num_slice_dims = indices_shape_lens.back();
33+
std::size_t num_slices = accumulate(indices_shape_lens.begin(),
34+
indices_shape_lens.end() - 1,
35+
1,
36+
std::multiplies<std::size_t>());
37+
std::size_t slice_size = accumulate(data_shape_lens.begin() + num_slice_dims + batch_dims,
38+
data_shape_lens.end(),
39+
1,
40+
std::multiplies<std::size_t>());
41+
const std::size_t num_batches = accumulate(data_shape_lens.begin(),
42+
data_shape_lens.begin() + batch_dims,
43+
1,
44+
std::multiplies<std::size_t>());
45+
const std::size_t data_batch_stride = accumulate(data_shape_lens.begin() + batch_dims,
46+
data_shape_lens.end(),
47+
1,
48+
std::multiplies<std::size_t>());
49+
const auto num_slices_per_batch = num_slices / num_batches;
50+
51+
ind.global_stride(output_shape.elements(), [&](auto i) {
52+
const auto* indices_ptr = indices_t.data();
53+
const std::size_t j = i / slice_size;
54+
const std::size_t batch_idx = j / num_slices_per_batch;
55+
56+
auto* slice_indices = indices_ptr + (j * num_slice_dims);
57+
std::size_t relative_slice_offset = 0;
58+
for(std::size_t idx = 0; idx < num_slice_dims; ++idx)
59+
{
60+
int64_t index = slice_indices[idx];
61+
const std::size_t input_dim_idx = batch_dims + idx;
62+
const auto input_dim = data_shape_lens[input_dim_idx];
63+
assert(index >= -static_cast<int64_t>(input_dim) and
64+
index < static_cast<int64_t>(input_dim));
65+
if(index < 0)
66+
index += input_dim;
67+
std::size_t size_from_slice_dims =
68+
accumulate(data_shape_lens.begin() + batch_dims + idx + 1,
69+
data_shape_lens.begin() + batch_dims + num_slice_dims,
70+
slice_size,
71+
std::multiplies<std::size_t>());
72+
relative_slice_offset += index * size_from_slice_dims;
73+
}
74+
75+
auto slice_offset = (batch_idx * data_batch_stride) + relative_slice_offset;
76+
output_t[i] = data_t[slice_offset + i % slice_size];
77+
});
78+
}
79+
80+
} // namespace migraphx
81+
#endif
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
gathernd_batch_dims_test:�
2+
/
3+
data
4+
indicesy"GatherND*
5+
6+
batch_dims�gathernd_batch_dims_testZ
7+
data
8+

9+

10+

11+
Z
12+
indices
13+

14+

15+
b
16+
y
17+

18+

19+
B

test/onnx/gathernd_test.onnx

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
gathernd_test:q
2+

3+
data
4+
indicesy"GatherNDgathernd_testZ
5+
data
6+

7+

8+
Z
9+
indices
10+

11+

12+
b
13+
y
14+
15+

16+
B

0 commit comments

Comments
 (0)