Skip to content

Commit 04ac9fc

Browse files
authored
Enable attention by default on mi300 (#3463)
1 parent c967696 commit 04ac9fc

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/targets/gpu/fuse_mlir.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,14 @@ static bool specific_op(std::string_view option, bool fallback = false)
125125
return contains(options, option);
126126
}
127127

128-
bool mlir_attention_enabled()
128+
bool mlir_attention_enabled(context* ctx)
129129
{
130130
#ifdef MIGRAPHX_MLIR
131131
if(not mlir_enabled())
132132
return false;
133+
// Enable attention by default for mi300
134+
if(ctx != nullptr and starts_with(ctx->get_current_device().get_gfx_name(), "gfx94"))
135+
return true;
133136
return specific_op<requested>("attention");
134137
#else
135138
return false;
@@ -996,7 +999,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
996999
};
9971000

9981001
// Attention offloads; default disabled
999-
if(mlir_attention_enabled() or enable_extra)
1002+
if(mlir_attention_enabled(ctx) or enable_extra)
10001003
{
10011004
match::find_matches(mpm, find_mlir_attention_fused_ops{mlir_mode::all});
10021005
mpm.run_pass(dead_code_elimination{});

src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ struct module_pass_manager;
3434
namespace gpu {
3535

3636
MIGRAPHX_GPU_EXPORT bool mlir_enabled();
37-
MIGRAPHX_GPU_EXPORT bool mlir_attention_enabled();
3837

3938
struct MIGRAPHX_GPU_EXPORT fuse_mlir
4039
{

0 commit comments

Comments
 (0)