
Update fp8 paged attention for MI308 #592

Draft · wants to merge 18 commits into base: main
7 changes: 7 additions & 0 deletions CMakeLists.txt
@@ -195,6 +195,13 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
#
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Wno-unused-result")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-result")

#
# Support fp8 instructions
#
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Xclang -target-feature -Xclang +fp8-conversion-insts")
set(CMAKE_${VLLM_GPU_LANG}_FLAGS "${CMAKE_${VLLM_GPU_LANG}_FLAGS} -Xclang -target-feature -Xclang +fp8-insts")
list(APPEND VLLM_GPU_FLAGS "-Rpass-analysis=kernel-resource-usage")
endif()

#
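For context, the two -Xclang -target-feature flags expose the fp8 conversion and fp8 MFMA instructions on MI300-series GPUs that the updated kernel relies on, and -Rpass-analysis=kernel-resource-usage emits per-kernel resource-usage remarks (registers, LDS, scratch) at compile time. A minimal sketch of the kind of device code these flags unlock (illustrative only, not part of this diff; the kernel name is hypothetical):

#include <hip/hip_runtime.h>

// pack_fp8_demo: pack pairs of fp32 values into fp8 bytes using the
// conversion builtin this PR enables (requires +fp8-conversion-insts).
__global__ void pack_fp8_demo(const float* __restrict__ in,
                              int* __restrict__ out, int n_pairs) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n_pairs) {
    // Two fp32 values become two fp8 bytes in the low half of the result.
    out[i] = __builtin_amdgcn_cvt_pk_fp8_f32(in[2 * i], in[2 * i + 1],
                                             /*old=*/0, /*high=*/false);
  }
}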
152 changes: 136 additions & 16 deletions csrc/rocm/attention.cu
@@ -112,6 +112,21 @@ __device__ __forceinline__ floatx4 gcn_mfma16x16x16_instr(const _B16x4& inpA,
}
}

template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma16x16x32_instr(const long& inpA,
const long& inpB,
const floatx4& inpC) {
if constexpr (std::is_same<T, __hip_fp8_e4m3>::value) {
return __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(inpA, inpB, inpC, absz, cbid,
blgp);
} else if constexpr (std::is_same<T, __hip_fp8_e5m2>::value) {
return __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(inpA, inpB, inpC, absz,
cbid, blgp);
} else {
static_assert(false, "unsupported 8b dtype");
}
}
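As a usage illustration (a sketch only, not part of this diff; the helper name is hypothetical), each 64-bit operand carries eight packed fp8 bytes, and one call accumulates a 16x16x32 fp8 MFMA step into a floatx4, mirroring how the wrapper is invoked in the QK and softmax*V loops further down:

// Illustrative helper: one fp8 e4m3 MFMA accumulation step.
__device__ __forceinline__ floatx4 fp8_mfma_step(const long a_fp8x8,
                                                 const long b_fp8x8,
                                                 const floatx4& acc) {
  return gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(a_fp8x8, b_fp8x8, acc);
}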

template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
if constexpr (std::is_same<T, _Float16>::value) {
@@ -256,6 +271,46 @@ __device__ __forceinline__ _B16x8 convert_b8x8_custom(const _B8x8 input) {
return ret;
}

#define _MI308 // enable on MI308 when the full fp8 attention path is required for better performance
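// _T8x8 lets the same 8 bytes be viewed as four fp16 values, four 16-bit
// lanes (each receiving two packed fp8 bytes from
// __builtin_amdgcn_cvt_pk_fp8_f32), eight fp8 bytes (_B8x8), a _B16x4, or a
// single 64-bit integer operand for the fp8 MFMA wrapper above.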
typedef union u64_cvt {
half f16x4[4];
int16_t b16x4[4];
_B8x8 b8x8;
_B16x4 b64;
int64_t i64;
} _T8x8;

__device__ __forceinline__ _B8x8 convert_b16x8(const _B16x8& input, _T8x8& Mtemp)
{

_T8x8 Qtmp8x8;

for (int i = 0; i < 2; i++) {
floatx4 q_out = {0,0,0,0};
q_out = gcn_mfma16x16x16_instr<_Float16, 0, 0, 0>(
Mtemp.b64,
input.xy[i], q_out);
Qtmp8x8.b16x4[i*2 ] = __builtin_amdgcn_cvt_pk_fp8_f32(q_out[0], q_out[1],0,false);
Qtmp8x8.b16x4[i*2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32(q_out[2], q_out[3],0,false);
}
return Qtmp8x8.b8x8;
}

#define hipWarpSize 64
__device__ float warpReduceMax(float val) {
for (int offset = hipWarpSize / 2; offset > 0; offset /= 2) {
val = max(val, __shfl_down(val, offset, hipWarpSize)); // Using max() for reduction
}
return val;
}

__device__ float warpReduceMin(float val) {
for (int offset = hipWarpSize / 2; offset > 0; offset /= 2) {
val = min(val, __shfl_down(val, offset, hipWarpSize)); // Using min() for reduction
}
return val;
}
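These wavefront-wide shuffle reductions feed the dynamic fp8 query scaling below; a minimal sketch of the intended pattern (illustrative only, not part of this diff; the helper name is hypothetical). Note that after the __shfl_down loop the fully reduced value sits in lane 0 of the wavefront.

// Illustrative helper: turn a per-lane maximum into an fp8 scale factor,
// mirroring the q_scale computation in the kernel below.
__device__ float fp8_scale_from_max(float lane_max) {
  float wave_max = warpReduceMax(lane_max);
  return wave_max > 0.f ? 224.0f / wave_max : 1.0f;
}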

// grid (num_seqs, num_partitions,num_kv_heads)
// block (256)
// clang-format off
@@ -410,12 +465,18 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
__syncthreads();
float q_max = 0;
for (int qkhe_depth = 0; qkhe_depth < QKHELOOP; qkhe_depth++) {
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
for (int i = 0; i < 2; i++) {
Qlocal[qkhe_depth][qkratio].xy[i] =
shared_logits[qkhe_depth][rowid][lane16id % GQA_RATIO]
[2 * qkratio + i];
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto){
scalar_t* qptr = reinterpret_cast<scalar_t*>(&Qlocal[qkhe_depth][qkratio].xy[i]);
for(int k = 0; k< 2; k++)
q_max = fmax(to_float<scalar_t>(qptr[k]), q_max);
}
}
}
}
@@ -512,9 +573,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(

// calculate post qk mfma scale
float scale2 = scale;
q_max = warpReduceMax(q_max);
float q_scale = q_max > 0 ? 224.0 / q_max : 1.0;
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
// multiply by k_scale if fp8 kv cache
scale2 *= *k_scale;
scale2 /= q_scale;
}
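// Worked example (illustrative note, not part of this diff): with q_max = 3.5
// the dynamic scale is q_scale = 224.0 / 3.5 = 64, so the largest query
// magnitude maps to 224 when the Q values are packed to fp8 below; dividing
// scale2 by q_scale undoes that amplification in the post-QK-MFMA scaling,
// while the earlier *k_scale factor accounts for the fp8 KV-cache values.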

floatx4 d_out[TLOOP];
@@ -534,13 +598,38 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
auto Ktmp = Klocal[token_depth][qkhe_depth];
_B8x16 Ktmp8x16 = *reinterpret_cast<_B8x16*>(&Ktmp);
for (int qkratio = 0; qkratio < QK_SIZE_RATIO; qkratio++) {
_B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
_B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
for (int i = 0; i < 2; i++) {
d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
d_out[token_depth]);
#ifndef _MI308
{
_B8x8 Ktmp8x8 = Ktmp8x16.xy[qkratio];
_B16x8 Klocaltmp = convert_b8x8_custom<scalar_t>(Ktmp8x8);
for (int i = 0; i < 2; i++) {
d_out[token_depth] = gcn_mfma16x16x16_instr<scalar_t, 0, 0, 0>(
Klocaltmp.xy[i], Qlocal[qkhe_depth][qkratio].xy[i],
d_out[token_depth]);
}
}
#else
{
_T8x8 Ktmp8x8, Qtmp8x8;
Ktmp8x8.b8x8 = Ktmp8x16.xy[qkratio];

for(int n = 0; n < 2; n++)
{
scalar_t* qptr = reinterpret_cast<scalar_t*>(&Qlocal[qkhe_depth][qkratio].xy[n]);

Qtmp8x8.b16x4[n*2] = __builtin_amdgcn_cvt_pk_fp8_f32(
to_float<scalar_t>(qptr[0])*q_scale,
to_float<scalar_t>(qptr[1])*q_scale, 0, false);
Qtmp8x8.b16x4[n*2+1] = __builtin_amdgcn_cvt_pk_fp8_f32(
to_float<scalar_t>(qptr[2])*q_scale,
to_float<scalar_t>(qptr[3])*q_scale, 0, false);
}

d_out[token_depth] = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
Ktmp8x8.i64, Qtmp8x8.i64,
d_out[token_depth]);
}
#endif
}
}
}
@@ -631,16 +720,31 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
constexpr bool LOGITS_RTZ_CONVERSION = false;

// write logits to shared mem
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
d_out[token_depth] *= inv_sum_scale;
if constexpr (LOGITS_RTZ_CONVERSION) {
// use rtz conversion for better performance, with negligible impact on
// accuracy
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4_rtz<scalar_t>(d_out[token_depth]);
} else {
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4<scalar_t>(d_out[token_depth]);
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto)
{
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
d_out[token_depth] *= inv_sum_scale;
if constexpr (LOGITS_RTZ_CONVERSION) {
// use rtz conversion for better performance, with negligible impact on
// accuracy
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4_rtz<scalar_t>(d_out[token_depth]);
} else {
shared_logits[warpid][token_depth][lane16id][rowid] =
from_floatx4<scalar_t>(d_out[token_depth]);
}
}
}
else
{
int rowid_8x8 = rowid/2;
int offset = rowid%2;
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
d_out[token_depth] *= inv_sum_scale;
// reinterpret the _B16x4 shared_logits element as a _T8x8 and pack fp8 logits into it
_T8x8& logits_8x8 = *reinterpret_cast<_T8x8*>(&shared_logits[warpid][token_depth][lane16id][rowid_8x8]);
logits_8x8.b16x4[offset * 2 ] = __builtin_amdgcn_cvt_pk_fp8_f32(d_out[token_depth][0], d_out[token_depth][1],0,false);
logits_8x8.b16x4[offset * 2 + 1] = __builtin_amdgcn_cvt_pk_fp8_f32(d_out[token_depth][2], d_out[token_depth][3],0,false);
}
}
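// Layout note (illustrative, not part of this diff): in the fp8 path two
// adjacent rows (rowid even/odd) pack into one 8-byte _T8x8 slot of
// shared_logits, each contributing four fp8 bytes via two
// __builtin_amdgcn_cvt_pk_fp8_f32 calls; the packed 64-bit value is later fed
// directly to the 16x16x32 fp8 MFMA in the softmax*V loop.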

@@ -693,6 +797,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
_B8x16 Vtmp8x16 = *reinterpret_cast<_B8x16*>(&Vtmp);
for (int j = 0; j < ELEMS16_ELEMS8_RATIO; j++) {
_B8x8 Vtmp8x8 = Vtmp8x16.xy[j];
#ifndef _MI308
_B16x8 Vlocaltmp = convert_b8x8_custom<scalar_t>(Vtmp8x8);
for (int i = 0; i < ELEMS8_ELEMS4_RATIO; i++) {
const int offset =
Expand All @@ -707,6 +812,21 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
shared_logits[vtoken_depth][offset2][lane16id][offset1],
tmp_out);
}
#else
for (int i = 0; i < ELEMS8_ELEMS4_RATIO/2; i++) {
const int offset =
rowid * ELEMS16_ELEMS8_RATIO * ELEMS8_ELEMS4_RATIO +
j * ELEMS8_ELEMS4_RATIO + i;
const int offset1 = (offset % ROWS_PER_WARP) / 2;
const int offset2 = offset / ROWS_PER_WARP;
// output format is 16 qheads across 16 lanes, 16 head elems
// spread across 4 rows
tmp_out = gcn_mfma16x16x32_instr<__hip_fp8_e4m3, 0, 0, 0>(
reinterpret_cast<_T8x8*>(&Vtmp8x8)->i64,
reinterpret_cast<_T8x8*>(&shared_logits[vtoken_depth][offset2][lane16id][offset1])->i64,
tmp_out);
}
#endif
}
}
}