Skip to content

Commit 2917c94

Browse files
authored
Add propagate_precision pass (#2853)
1 parent 7382a62 commit 2917c94

File tree

6 files changed

+494
-3
lines changed

6 files changed

+494
-3
lines changed

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ add_library(migraphx
9191
process.cpp
9292
program.cpp
9393
propagate_constant.cpp
94+
propagate_precision.cpp
9495
promote_literals.cpp
9596
quantization.cpp
9697
quantize_int4.cpp
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*
24+
*/
25+
#ifndef MIGRAPHX_GUARD_MIGRAPHX_PROPAGATE_PRECISION_HPP
26+
#define MIGRAPHX_GUARD_MIGRAPHX_PROPAGATE_PRECISION_HPP
27+
28+
#include <migraphx/config.hpp>
29+
#include <string>
30+
31+
namespace migraphx {
32+
inline namespace MIGRAPHX_INLINE_NS {
33+
34+
struct module_pass_manager;
35+
36+
/// This pass will propagate higher precision through more adjacent operators.
37+
struct MIGRAPHX_EXPORT propagate_precision
38+
{
39+
std::string name() const { return "propagate_precision"; }
40+
void apply(module_pass_manager& mpm) const;
41+
};
42+
43+
} // namespace MIGRAPHX_INLINE_NS
44+
} // namespace migraphx
45+
#endif // MIGRAPHX_GUARD_MIGRAPHX_PROPAGATE_PRECISION_HPP

src/include/migraphx/shape.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,9 +392,9 @@ struct MIGRAPHX_EXPORT shape
392392

393393
std::size_t size(std::size_t n = 1) const { return sizeof(type) * n; }
394394

395-
auto is_integral() const { return std::is_integral<type>{}; }
396-
auto is_signed() const { return std::is_signed<type>{}; }
397-
auto is_unsigned() const { return std::is_unsigned<type>{}; }
395+
bool is_integral() const { return std::is_integral<type>{}; }
396+
bool is_signed() const { return std::is_signed<type>{}; }
397+
bool is_unsigned() const { return std::is_unsigned<type>{}; }
398398

399399
template <class U>
400400
type* from(U* buffer, std::size_t n = 0) const

src/propagate_precision.cpp

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,233 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*
24+
*/
25+
#include <migraphx/propagate_precision.hpp>
26+
#include <migraphx/module.hpp>
27+
#include <migraphx/iterator_for.hpp>
28+
#include <migraphx/instruction.hpp>
29+
#include <migraphx/pass_manager.hpp>
30+
#include <migraphx/make_op.hpp>
31+
#include <migraphx/functional.hpp>
32+
#include <migraphx/ranges.hpp>
33+
#include <migraphx/eliminate_convert.hpp>
34+
#include <unordered_set>
35+
#include <unordered_map>
36+
37+
namespace migraphx {
38+
inline namespace MIGRAPHX_INLINE_NS {
39+
40+
namespace {
41+
#ifdef __clang__
42+
#pragma clang diagnostic push
43+
#pragma clang diagnostic ignored "-Wunused-function"
44+
#endif
45+
// Class wrappper so we can compare precision using comparison operators
46+
struct precision
47+
{
48+
shape::type_t type;
49+
50+
friend bool operator==(const precision& xp, const precision& yp) { return xp.type == yp.type; }
51+
friend bool operator<(const precision& xp, const precision& yp)
52+
{
53+
bool is_less = false;
54+
shape::visit(xp.type, [&](auto x) {
55+
shape::visit(yp.type, [&](auto y) {
56+
if(x.is_integral() != y.is_integral())
57+
return;
58+
if(x.is_integral())
59+
{
60+
if(x.is_unsigned() != y.is_unsigned() and x.size() == y.size())
61+
is_less = y.is_unsigned();
62+
else
63+
is_less = x.size() < y.size();
64+
}
65+
else
66+
{
67+
is_less = x.size() < y.size();
68+
}
69+
});
70+
});
71+
return is_less;
72+
}
73+
friend bool operator!=(const precision& xp, const precision& yp) { return not(xp == yp); }
74+
friend bool operator>(const precision& xp, const precision& yp) { return yp < xp; }
75+
// This is not totally ordered
76+
friend bool operator<=(const precision& xp, const precision& yp)
77+
{
78+
return (xp < yp) or (xp == yp);
79+
}
80+
friend bool operator>=(const precision& xp, const precision& yp)
81+
{
82+
return (xp > yp) or (xp == yp);
83+
}
84+
};
85+
#ifdef __clang__
86+
#pragma clang diagnostic pop
87+
#endif
88+
} // namespace
89+
90+
static bool is_pointwise_or_reduce(instruction_ref ins)
91+
{
92+
return contains(ins->name(), "reduce") or
93+
ins->get_operator().attributes().get("pointwise", false);
94+
}
95+
// Check if its not a scalar constant
96+
static bool is_non_scalar_const(instruction_ref ins)
97+
{
98+
return not(ins->get_shape().scalar() and ins->can_eval());
99+
}
100+
// Get the next input instruction otherwise return a nullopt
101+
static std::optional<instruction_ref> get_next_input(instruction_ref ins)
102+
{
103+
if(ins->inputs().size() == 1)
104+
return ins->inputs().front();
105+
if(ins->inputs().size() > 1)
106+
{
107+
std::unordered_set<instruction_ref> non_scalars;
108+
std::copy_if(ins->inputs().begin(),
109+
ins->inputs().end(),
110+
std::inserter(non_scalars, non_scalars.end()),
111+
&is_non_scalar_const);
112+
if(non_scalars.size() == 1)
113+
return *non_scalars.begin();
114+
}
115+
return nullopt;
116+
}
117+
118+
// Find all adjacent instructions that could be upgraded with higher precision
119+
// by traversing the inputs from a convert
120+
121+
static std::unordered_set<instruction_ref> find_adjacent_inputs(instruction_ref start)
122+
{
123+
std::unordered_set<instruction_ref> result;
124+
// Promote inputs
125+
fix([&](auto self, instruction_ref ins) {
126+
if(not is_pointwise_or_reduce(ins))
127+
return;
128+
if(contains(result, ins))
129+
return;
130+
auto next = get_next_input(ins);
131+
if(not next.has_value())
132+
return;
133+
result.insert(ins);
134+
self(*next);
135+
})(start->inputs().front());
136+
return result;
137+
}
138+
139+
// Find all adjacent instructions that could be upgraded with higher precision
140+
// by traversing the outputs from a convert
141+
static std::unordered_set<instruction_ref> find_adjacent_outputs(instruction_ref start)
142+
{
143+
std::unordered_set<instruction_ref> result;
144+
// Promote outputs
145+
fix([&](auto self, instruction_ref ins) {
146+
for(auto output : ins->outputs())
147+
{
148+
if(not is_pointwise_or_reduce(output))
149+
continue;
150+
if(contains(result, output))
151+
continue;
152+
auto next = get_next_input(output);
153+
if(not next.has_value())
154+
continue;
155+
if(*next != ins)
156+
continue;
157+
result.insert(output);
158+
self(output);
159+
}
160+
})(start);
161+
return result;
162+
}
163+
164+
// Insert the instructions to upgrade into the map. If the map already has the
165+
// instruction then choose the highest precision
166+
template <class Map, class Instructions>
167+
static void
168+
insert_instructions_to_upgrade(Map& m, const Instructions& instructions, shape::type_t t)
169+
{
170+
for(auto ins : instructions)
171+
{
172+
auto it = m.find(ins);
173+
if(it == m.end())
174+
{
175+
m[ins] = t;
176+
}
177+
else
178+
{
179+
it->second = std::max(precision{t}, precision{it->second}).type;
180+
}
181+
}
182+
}
183+
184+
// Find adjacent instructions from a convert to upgrade to use a higher
185+
// precision
186+
static std::unordered_map<instruction_ref, shape::type_t> find_instruction_to_upgrade(module& m)
187+
{
188+
std::unordered_map<instruction_ref, shape::type_t> result;
189+
for(auto ins : iterator_for(m))
190+
{
191+
if(ins->name() != "convert")
192+
continue;
193+
auto output = precision{ins->get_shape().type()};
194+
auto input = precision{ins->inputs().front()->get_shape().type()};
195+
if(output.type == shape::type_t::bool_type)
196+
continue;
197+
if(input < output)
198+
{
199+
insert_instructions_to_upgrade(result, find_adjacent_inputs(ins), output.type);
200+
}
201+
else if(input > output)
202+
{
203+
insert_instructions_to_upgrade(result, find_adjacent_outputs(ins), input.type);
204+
}
205+
}
206+
return result;
207+
}
208+
209+
void propagate_precision::apply(module_pass_manager& mpm) const
210+
{
211+
auto upgrade = find_instruction_to_upgrade(mpm.get_module());
212+
for(const auto& p : upgrade)
213+
{
214+
auto ins = p.first;
215+
auto t = p.second;
216+
auto convert1 = mpm.get_module().insert_instruction(
217+
std::next(ins), make_op("convert", {{"target_type", ins->get_shape().type()}}), ins);
218+
mpm.get_module().replace_instruction(ins, convert1);
219+
std::vector<instruction_ref> inputs;
220+
std::transform(ins->inputs().begin(),
221+
ins->inputs().end(),
222+
std::back_inserter(inputs),
223+
[&](auto input) {
224+
return mpm.get_module().insert_instruction(
225+
ins, make_op("convert", {{"target_type", t}}), input);
226+
});
227+
mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs);
228+
}
229+
mpm.run_pass(eliminate_convert{});
230+
}
231+
232+
} // namespace MIGRAPHX_INLINE_NS
233+
} // namespace migraphx

src/targets/gpu/target.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <migraphx/optimize_module.hpp>
4343
#include <migraphx/preallocate_param.hpp>
4444
#include <migraphx/promote_literals.hpp>
45+
#include <migraphx/propagate_precision.hpp>
4546
#include <migraphx/register_target.hpp>
4647
#include <migraphx/replace_allocate.hpp>
4748
#include <migraphx/rewrite_dot.hpp>
@@ -217,6 +218,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
217218
rewrite_low_precision{},
218219
enable_pass(enabled(MIGRAPHX_ENABLE_REWRITE_DOT{}), rewrite_dot{}),
219220
dead_code_elimination{},
221+
propagate_precision{},
222+
dead_code_elimination{},
220223
optimize_module{},
221224
fuse_pointwise_reduce{},
222225
dead_code_elimination{},

0 commit comments

Comments
 (0)