Skip to content

Commit 1568dfe

Browse files
authored
Have fp8 GEMMs go to hipblaslt only (#3990)
Will always use hipblaslt for fp8 GEMMs with this PR. FP8 will be eliminated if hipblaslt is not avaliable.
1 parent aaf3256 commit 1568dfe

File tree

6 files changed

+37
-19
lines changed

6 files changed

+37
-19
lines changed

src/targets/gpu/device_name.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,20 @@ bool gfx_default_rocblas()
8787
}
8888
#endif
8989

90+
bool hipblaslt_supported()
91+
{
92+
#if !MIGRAPHX_USE_HIPBLASLT
93+
return false;
94+
#else
95+
const auto device_name = trim(split_string(get_device_name(), ':').front());
96+
// hipblaslt is supported for MI200 and above, and Navi3x and above.
97+
return (device_name == "gfx90a" or
98+
(starts_with(device_name, "gfx94") and device_name >= "gfx942") or
99+
(starts_with(device_name, "gfx95") and device_name >= "gfx950") or
100+
starts_with(device_name, "gfx110") or starts_with(device_name, "gfx120"));
101+
#endif
102+
}
103+
90104
} // namespace gpu
91105
} // namespace MIGRAPHX_INLINE_NS
92106
} // namespace migraphx

src/targets/gpu/hipblaslt.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,6 @@ hipblaslt_preference_ptr create_hipblaslt_preference_ptr()
5353
return hipblaslt_preference_ptr{preference};
5454
}
5555

56-
bool hipblaslt_supported()
57-
{
58-
const auto device_name = trim(split_string(get_device_name(), ':').front());
59-
// hipblaslt is supported for MI200 and above, and Navi3x and above.
60-
return (device_name == "gfx90a" or
61-
(starts_with(device_name, "gfx94") and device_name >= "gfx942") or
62-
(starts_with(device_name, "gfx95") and device_name >= "gfx950") or
63-
starts_with(device_name, "gfx110") or starts_with(device_name, "gfx120"));
64-
}
65-
6656
#endif // MIGRAPHX_USE_HIPBLASLT
6757

6858
} // namespace gpu

src/targets/gpu/include/migraphx/gpu/device_name.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ MIGRAPHX_GPU_EXPORT bool gfx_has_fp8fnuz_support();
4747

4848
MIGRAPHX_GPU_EXPORT bool gfx_default_rocblas();
4949

50+
MIGRAPHX_GPU_EXPORT bool hipblaslt_supported();
51+
5052
} // namespace gpu
5153
} // namespace MIGRAPHX_INLINE_NS
5254
} // namespace migraphx

src/targets/gpu/include/migraphx/gpu/hipblaslt.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ using hipblaslt_preference_ptr = MIGRAPHX_MANAGE_PTR(hipblasLtMatmulPreference_t
100100

101101
hipblaslt_handle_ptr create_hipblaslt_handle_ptr();
102102
hipblaslt_preference_ptr create_hipblaslt_preference_ptr();
103-
MIGRAPHX_GPU_EXPORT bool hipblaslt_supported();
104103
const size_t hipblaslt_workspace_size = 2 * 128 * 1024 * 1024;
105104
} // namespace gpu
106105
} // namespace MIGRAPHX_INLINE_NS

src/targets/gpu/lowering.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include <migraphx/pass_manager.hpp>
3434
#include <migraphx/iterator_for.hpp>
3535
#include <migraphx/program.hpp>
36+
#include <migraphx/fp8_types.hpp>
3637

3738
#include <migraphx/op/common.hpp>
3839
#include <migraphx/op/dot.hpp>
@@ -253,17 +254,20 @@ struct miopen_apply
253254
auto output = insert_allocation(ins, ins->get_shape());
254255
refs.push_back(output);
255256

256-
#if MIGRAPHX_USE_HIPBLASLT
257+
bool has_fp8_inputs =
258+
std::any_of(ins->inputs().begin(), ins->inputs().end(), [](auto i_input) {
259+
return contains(fp8_types{}.get(), i_input->get_shape().type());
260+
});
261+
257262
// Check if user explicitly sets rocBLAS as GEMM provider, or
258263
// if the hardware cannot support hipblaslt, or
259264
// if the hardware is defaulted to use rocBLAS (such as gfx90).
260-
if((string_value_of(MIGRAPHX_SET_GEMM_PROVIDER{}) == "rocblas") or
261-
not hipblaslt_supported() or gpu::gfx_default_rocblas())
265+
if(not has_fp8_inputs and
266+
((string_value_of(MIGRAPHX_SET_GEMM_PROVIDER{}) == "rocblas") or
267+
not hipblaslt_supported() or gpu::gfx_default_rocblas()))
262268
{
263-
#endif
264269
return mod->replace_instruction(
265270
ins, rocblas_gemm<Op>{Op{}, 1, 0, compute_fp32}, refs);
266-
#if MIGRAPHX_USE_HIPBLASLT
267271
}
268272
std::string op_name = "gpu::hip_gemm";
269273
if(contains(name, "quant_"))
@@ -277,7 +281,6 @@ struct miopen_apply
277281
ins->inputs().at(0),
278282
ins->inputs().at(1),
279283
output);
280-
#endif
281284
});
282285
}
283286
#endif

src/targets/gpu/target.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
112112
}
113113

114114
// whiltelist supported Ops for the FP8 types
115-
// rocBLAS does not support any FP8 types
116115
std::set<std::string> unsupported_fp8fnuz_ops = {};
117-
if(string_value_of(MIGRAPHX_SET_GEMM_PROVIDER{}) == "rocblas" or gpu::gfx_default_rocblas())
116+
117+
// disable dot & quant_dot if no hipblaslt
118+
if(not hipblaslt_supported())
118119
{
119120
unsupported_fp8fnuz_ops.insert("dot");
120121
unsupported_fp8fnuz_ops.insert("quant_dot");
121122
}
123+
122124
#if MIGRAPHX_USE_MIOPEN // MIOpen doesn't have support for fp8 pooling yet.
123125
unsupported_fp8fnuz_ops.insert("pooling");
124126
#endif
@@ -141,6 +143,14 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
141143
unsupported_fp8fnuz_ops.insert("argmin");
142144

143145
std::set<std::string> unsupported_fp8ocp_ops = {};
146+
147+
// disable dot & quant_dot if no hipblaslt
148+
if(not hipblaslt_supported())
149+
{
150+
unsupported_fp8ocp_ops.insert("dot");
151+
unsupported_fp8ocp_ops.insert("quant_dot");
152+
}
153+
144154
#if MIGRAPHX_USE_MIOPEN
145155
// MIOpen doesn't have support for fp8 pooling yet.
146156
unsupported_fp8ocp_ops.insert("pooling");

0 commit comments

Comments
 (0)