Skip to content

Commit f0a8cf7

Browse files
authored
Support additional types for linear Resize op (#4115)
Adds a convert inside of the resize parser for when the data is not float type.
1 parent 3de2b9e commit f0a8cf7

File tree

5 files changed

+165
-1
lines changed

5 files changed

+165
-1
lines changed

src/onnx/parse_resize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ struct parse_resize : op_parser<parse_resize>
456456
for(auto idx = resized_ct; idx != 0u; --idx)
457457
{
458458
dim_lens[0] /= 2; // halved for 2 slices of data (hi & low below)
459-
shape dim_s{shape::float_type, dim_lens};
459+
shape dim_s{in_s.type(), dim_lens};
460460
const auto& dim_delta = delta[idx - 1];
461461
std::vector<float> delta_data;
462462
for(std::size_t j = 0; j < dim_lens[0] / out_lens[0]; ++j)

test/onnx/gen_onnx.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12305,6 +12305,24 @@ def resize_downsample_linear_test():
1230512305

1230612306
return ([node], [X], [Y], [scale_tensor])
1230712307

12308+
@onnx_test()
12309+
def resize_downsample_linear_half_test():
12310+
scales = np.array([1.0, 1.0, 0.6, 0.5], dtype=np.float16)
12311+
scale_tensor = helper.make_tensor(name='scales',
12312+
data_type=TensorProto.FLOAT16,
12313+
dims=scales.shape,
12314+
vals=scales.flatten().astype(np.float16))
12315+
12316+
X = helper.make_tensor_value_info('X', TensorProto.FLOAT16, [1, 1, 2, 4])
12317+
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT16, [])
12318+
12319+
node = onnx.helper.make_node('Resize',
12320+
inputs=['X', '', 'scales'],
12321+
outputs=['Y'],
12322+
mode='linear')
12323+
12324+
return ([node], [X], [Y], [scale_tensor])
12325+
1230812326

1230912327
@onnx_test()
1231012328
def resize_linear_non_const_test():
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*/
24+
25+
#include <onnx_test.hpp>
26+
27+
/* IR for the test case below:
28+
module: "main"
29+
@0 = @literal{0.333008, 0.333008} -> half_type, {1, 1, 1, 2}, {2, 2, 2, 1}
30+
@1 = @literal{0.5, 0.5, 0.5, 0.5} -> half_type, {2, 1, 1, 2}, {2, 2, 2, 1}
31+
@2 = @literal{0, 2, 4, 6, 1, 3, 5, 7} -> int32_type, {4, 1, 1, 2}, {2, 2, 2, 1}
32+
X = @param:X -> half_type, {1, 1, 2, 4}, {8, 8, 4, 1}
33+
@4 = @literal{1, 1, 0.600098, 0.5} -> half_type, {4}, {1}
34+
@5 = undefined -> float_type, {}, {}
35+
@6 = reshape[dims={8}](X) -> half_type, {8}, {1}
36+
@7 = gather[axis=0](@6,@2) -> half_type, {4, 1, 1, 2}, {2, 2, 2, 1}
37+
@8 = slice[axes={0},starts={0},ends={2}](@7) -> half_type, {2, 1, 1, 2}, {2, 2, 2, 1}
38+
@9 = slice[axes={0},starts={2},ends={4}](@7) -> half_type, {2, 1, 1, 2}, {2, 2, 2, 1}
39+
@10 = sub(@9,@8) -> half_type, {2, 1, 1, 2}, {2, 2, 2, 1}
40+
@11 = mul(@10,@1) -> half_type, {2, 1, 1, 2}, {2, 2, 2, 1}
41+
@12 = add(@11,@8) -> half_type, {2, 1, 1, 2}, {2, 2, 2, 1}
42+
@13 = slice[axes={0},starts={0},ends={1}](@12) -> half_type, {1, 1, 1, 2}, {2, 2, 2, 1}
43+
@14 = slice[axes={0},starts={1},ends={2}](@12) -> half_type, {1, 1, 1, 2}, {2, 2, 2, 1}
44+
@15 = sub(@14,@13) -> half_type, {1, 1, 1, 2}, {2, 2, 2, 1}
45+
@16 = mul(@15,@0) -> half_type, {1, 1, 1, 2}, {2, 2, 2, 1}
46+
@17 = add(@16,@13) -> half_type, {1, 1, 1, 2}, {2, 2, 2, 1}
47+
@18 = @return(@17)
48+
*/
49+
50+
TEST_CASE(resize_downsample_linear_half_test)
51+
{
52+
using migraphx::half;
53+
migraphx::program p;
54+
auto* mm = p.get_main_module();
55+
migraphx::shape ss{migraphx::shape::half_type, {4}};
56+
std::vector<half> ds = {half{1}, half{1}, half{0.60009765625}, half{0.5}};
57+
mm->add_literal(migraphx::literal(ss, ds));
58+
59+
migraphx::shape sx{migraphx::shape::half_type, {1, 1, 2, 4}};
60+
auto x = mm->add_parameter("X", sx);
61+
62+
migraphx::shape s_ind{migraphx::shape::int32_type, {4, 1, 1, 2}};
63+
std::vector<int> d_ind = {0, 2, 4, 6, 1, 3, 5, 7};
64+
auto l_ind = mm->add_literal(migraphx::literal(s_ind, d_ind));
65+
66+
migraphx::shape s2{migraphx::shape::half_type, {2, 1, 1, 2}};
67+
std::vector<float> d2(4, 0.5f);
68+
auto l2 = mm->add_literal(migraphx::literal(s2, d2));
69+
70+
migraphx::shape s1{migraphx::shape::half_type, {1, 1, 1, 2}};
71+
std::vector<float> d1(2, 0.5 / 0.60009765625 - 0.5);
72+
auto l1 = mm->add_literal(migraphx::literal(s1, d1));
73+
74+
mm->add_instruction(migraphx::make_op("undefined"));
75+
76+
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), x);
77+
auto data = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, l_ind);
78+
auto slc20 = mm->add_instruction(
79+
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {2}}}), data);
80+
auto slc21 = mm->add_instruction(
81+
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {4}}}), data);
82+
auto diff2 = mm->add_instruction(migraphx::make_op("sub"), slc21, slc20);
83+
auto mul2 = mm->add_instruction(migraphx::make_op("mul"), diff2, l2);
84+
auto add2 = mm->add_instruction(migraphx::make_op("add"), mul2, slc20);
85+
86+
auto slc10 = mm->add_instruction(
87+
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), add2);
88+
auto slc11 = mm->add_instruction(
89+
migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), add2);
90+
auto diff1 = mm->add_instruction(migraphx::make_op("sub"), slc11, slc10);
91+
auto mul1 = mm->add_instruction(migraphx::make_op("mul"), diff1, l1);
92+
auto add1 = mm->add_instruction(migraphx::make_op("add"), mul1, slc10);
93+
mm->add_return({add1});
94+
95+
auto prog = read_onnx("resize_downsample_linear_half_test.onnx");
96+
EXPECT(p == prog);
97+
}
192 Bytes
Binary file not shown.
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*/
24+
25+
#include <migraphx/register_target.hpp>
26+
#include <migraphx/verify.hpp>
27+
#include <onnx_test.hpp>
28+
29+
TEST_CASE(resize_downsample_linear_half_test)
30+
{
31+
using migraphx::half;
32+
migraphx::program p = read_onnx("resize_downsample_linear_half_test.onnx");
33+
p.compile(migraphx::make_target("ref"));
34+
35+
migraphx::shape sx{migraphx::shape::half_type, {1, 1, 2, 4}};
36+
std::vector<half> dx = {half{1}, half{2}, half{3}, half{4}, half{5}, half{6}, half{7}, half{8}};
37+
38+
migraphx::parameter_map pp;
39+
pp["X"] = migraphx::argument(sx, dx.data());
40+
41+
auto result = p.eval(pp).back();
42+
std::vector<half> result_vector;
43+
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });
44+
45+
// Expected output was calculated without any quantization
46+
std::vector<half> gold = {half{2.8333333}, half{4.833333}};
47+
48+
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
49+
}

0 commit comments

Comments
 (0)