3333#include  < migraphx/pass_manager.hpp> 
3434#include  < migraphx/iterator_for.hpp> 
3535#include  < migraphx/program.hpp> 
36+ #include  < migraphx/fp8_types.hpp> 
3637
3738#include  < migraphx/op/common.hpp> 
3839#include  < migraphx/op/dot.hpp> 
@@ -253,17 +254,20 @@ struct miopen_apply
253254            auto  output = insert_allocation (ins, ins->get_shape ());
254255            refs.push_back (output);
255256
256- #if  MIGRAPHX_USE_HIPBLASLT
257+             bool  has_fp8_inputs =
258+                 std::any_of (ins->inputs ().begin (), ins->inputs ().end (), [](auto  i_input) {
259+                     return  contains (fp8_types{}.get (), i_input->get_shape ().type ());
260+                 });
261+ 
257262            //  Check if user explicitly sets rocBLAS as GEMM provider, or
258263            //  if the hardware cannot support hipblaslt, or
259264            //  if the hardware is defaulted to use rocBLAS (such as gfx90).
260-             if ((string_value_of (MIGRAPHX_SET_GEMM_PROVIDER{}) == " rocblas" or 
261-                not  hipblaslt_supported () or  gpu::gfx_default_rocblas ())
265+             if (not  has_fp8_inputs and 
266+                ((string_value_of (MIGRAPHX_SET_GEMM_PROVIDER{}) == " rocblas" or 
267+                 not  hipblaslt_supported () or  gpu::gfx_default_rocblas ()))
262268            {
263- #endif
264269                return  mod->replace_instruction (
265270                    ins, rocblas_gemm<Op>{Op{}, 1 , 0 , compute_fp32}, refs);
266- #if  MIGRAPHX_USE_HIPBLASLT
267271            }
268272            std::string op_name = " gpu::hip_gemm" 
269273            if (contains (name, " quant_" 
@@ -277,7 +281,6 @@ struct miopen_apply
277281                ins->inputs ().at (0 ),
278282                ins->inputs ().at (1 ),
279283                output);
280- #endif 
281284        });
282285    }
283286#endif 
0 commit comments