Zhwang/llama gqa #2

Open
wants to merge 135 commits into base: zhwang/llama
Commits (135, showing changes from all commits)
All 135 commits are by sfc-gh-zhwang, each with the one-word message "commit"; only the short hashes and dates differ.

Sep 5, 2023: 1205add, 088bebb, ef8e906, 9eda6df, 98727db, b5eb6cf, 62557b6, 0d73d63, 03087ac, aff2408, 5a2ebcb, d0271a3, 699c569, 4455827, 7398a6e, 027c697, 4da2aa2, 07c2f5a, c6948f9, 4d5bbe2, e8cceb4
Sep 8, 2023: a32c9d2, 1ceed73, ad919b8, 630ced0, 161774c, 50a4215, f19a93e, a4d4743, 2f264c7, a42ab9d, 1a227ef, 23619dc, 7526221
Sep 9, 2023: c236b5d, 24ebefa, 4114f97, 8af8e0d, 135b5fd, 8d4b18d, 4c127bf, 50f1de5, 8a80dde, 794c0de, dc805b2, 7b9e51b, 5ab58d1, f703210, e2e516a, 459c26e, f0ab27f, d5b096b, aca0e79, 422a545, d9a8481, 29959a3, 45c9aee, 7e721f9, 4a03a09, ccbd727, e58d548, 941072c, bbf5ebe, 145788b, 6333c82, 210c439, a305abe, 4476f35, 2fab8b4, ff64dc2, 8efdd38, d805fcf, e47d2ac, 3b38507, d601f68
Sep 10, 2023: e161612, e655b1c, e037596, c74e6b8, ac982e8, e99edd4, e2d88e0, fce2324, 1f5bca8, f7ad862, 13d2c2c, 6b7b481, 38e1c61
Sep 11, 2023: 6f79374, 3e7cbc4, dd152ce, f2e763d, 8198daf, 8d19608, 94cd671, dda42a3, dbe4a43, 36094dc, b7d9ca7, 5e4fa2e, 8473391, 9fc5465, 8c6f8c4, 852edd5, 0c1a43d, 35d3028, f591af4, bc2714a, 83ec68f, 255e077, bf46234, 2acfc19, 281dcaf, 7b09acb, be8f65b, a7446b3, 45f5042, 1a8eb1c, b603941, 96bd93a, 5d8d95e, 6aac295, 69f420b, d761478, 839b558, 5eb2e79, c4057d1, 96eb6ad, 4b8303b, 0f7b363, fa3fa51, 0edb133, 56ad958, 7caf88b
Sep 12, 2023: a4b6dd9
4 changes: 3 additions & 1 deletion .vscode/settings.json
@@ -90,6 +90,8 @@
"__nullptr": "cpp",
"__string": "cpp",
"compare": "cpp",
"concepts": "cpp"
"concepts": "cpp",
"filesystem": "cpp",
"__memory": "cpp"
}
}
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -393,6 +393,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:cutlass_heuristic>
$<TARGET_OBJECTS:cutlass_preprocessors>
$<TARGET_OBJECTS:decoder_masked_multihead_attention>
$<TARGET_OBJECTS:decoder_masked_groupedquery_attention>
$<TARGET_OBJECTS:decoding_kernels>
$<TARGET_OBJECTS:fpA_intB_gemm>
$<TARGET_OBJECTS:gen_relative_pos_bias>
1 change: 1 addition & 0 deletions src/fastertransformer/kernels/CMakeLists.txt
@@ -15,6 +15,7 @@
cmake_minimum_required(VERSION 3.8)

add_subdirectory(cutlass_kernels)
add_subdirectory(llama)

add_library(image_shift_partition_kernels image_shift_partition_kernels.cu)
set_property(TARGET image_shift_partition_kernels PROPERTY POSITION_INDEPENDENT_CODE ON)
23 changes: 23 additions & 0 deletions src/fastertransformer/kernels/llama/CMakeLists.txt
@@ -0,0 +1,23 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

cmake_minimum_required(VERSION 3.8)

set(decoder_masked_groupedquery_attention_files
decoder_masked_groupedquery_attention.cu
)
file(GLOB decoder_masked_groupedquery_attention_files ${decoder_masked_groupedquery_attention_files} ./decoder_masked_groupedquery_attention/*.cu)
add_library(decoder_masked_groupedquery_attention STATIC ${decoder_masked_groupedquery_attention_files})
set_property(TARGET decoder_masked_groupedquery_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET decoder_masked_groupedquery_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
@@ -0,0 +1,101 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h"
#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

template<typename T, typename KERNEL_PARAMS_TYPE>
void groupedquery_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
switch (params.hidden_size_per_head) {
case 32:
mgqa_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 48:
mgqa_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 64:
mgqa_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 80:
mgqa_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 96:
mgqa_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 128:
mgqa_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 144:
mgqa_launch_kernel<T, 144, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 160:
mgqa_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 192:
mgqa_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 224:
mgqa_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
case 256:
mgqa_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
break;
default:
assert(false);
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_groupedquery_attention(const Masked_groupedquery_attention_params<float>& params, const cudaStream_t& stream)
{
groupedquery_attention_<float, Masked_groupedquery_attention_params<float>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_groupedquery_attention(const Masked_groupedquery_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
groupedquery_attention_<uint16_t, Masked_groupedquery_attention_params<uint16_t>>(params, stream);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream)
{
groupedquery_attention_<__nv_bfloat16, Masked_groupedquery_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_FP8
void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params,
const cudaStream_t& stream)
{
groupedquery_attention_<__nv_fp8_e4m3, Masked_groupedquery_attention_params<__nv_fp8_e4m3>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
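For orientation, a minimal host-side sketch of how this dispatch entry point might be called; it is not part of the PR. Only the fields that appear in this file and in the header below (batch_size, num_heads, num_kv_heads, hidden_size_per_head, timestep, cache_indir) are taken from the diff; the function name, the concrete sizes, and the elided buffer setup are assumptions for illustration.

// Hypothetical usage sketch, not code from this PR.
#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h"

void run_gqa_decode_step(cudaStream_t stream)
{
    Masked_groupedquery_attention_params<uint16_t> params;   // fp16 path
    params.batch_size           = 8;
    params.num_heads            = 32;       // query heads
    params.num_kv_heads         = 8;        // K/V heads shared across query heads
    params.hidden_size_per_head = 128;      // must hit one of the cases in the switch above
    params.timestep             = 64;       // current decoding step
    params.cache_indir          = nullptr;  // no beam search, so the HAS_BEAMS = false kernels are used
    // The Q/K/V and cache pointers inherited from Multihead_attention_params_base must
    // point at valid device buffers before the call (omitted here).
    masked_groupedquery_attention(params, stream);  // routes to mgqa_launch_kernel<T, 128, 128, ...>
}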
@@ -0,0 +1,54 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include "src/fastertransformer/utils/cuda_fp8_utils.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

template<typename T>
struct GroupedQuery_attention_params: public Multihead_attention_params_base<T> {
// allows the attention to exit early
bool* finished = nullptr;
int num_kv_heads = 0;
// required in case of masked attention with different lengths
const int* length_per_sample = nullptr;
};

template<class T>
using Masked_groupedquery_attention_params = GroupedQuery_attention_params<T>;

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_groupedquery_attention(const Masked_groupedquery_attention_params<float>& params, const cudaStream_t& stream);
void masked_groupedquery_attention(const Masked_groupedquery_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params,
const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params,
const cudaStream_t& stream);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
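As a reading aid for the new num_kv_heads field: under grouped-query attention several query heads share one K/V head, so a query head index is folded onto a K/V head index. A minimal sketch of that mapping, assuming (as the name suggests, though this header does not state it) that num_heads is a multiple of num_kv_heads:

// Illustrative only, not code from this PR.
inline int kv_head_for_query_head(int q_head, int num_heads, int num_kv_heads)
{
    const int q_per_kv = num_heads / num_kv_heads;  // query heads per shared K/V head
    return q_head / q_per_kv;                       // e.g. 32 Q heads, 8 KV heads: query heads 0-3 read KV head 0
}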
@@ -0,0 +1,86 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "decoder_masked_groupedquery_attention_template.hpp"
#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MGQA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_groupedquery_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
int tlength = params.timestep;
if (params.cache_indir == nullptr) {
if (tlength < 32) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
}
else if (tlength < 2048) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream);
}
else {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream);
}
}
else {
if (tlength < 32) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream);
}
else if (tlength < 2048) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream);
}
else {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream);
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template void mgqa_launch_kernel<float, 128, 128, GroupedQuery_attention_params<float>>(
const GroupedQuery_attention_params<float>& params, const cudaStream_t& stream);
template void mgqa_launch_kernel<uint16_t, 128, 128, GroupedQuery_attention_params<uint16_t>>(
const GroupedQuery_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, GroupedQuery_attention_params<__nv_bfloat16>>(
const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>(
const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif


#undef MGQA_LAUNCH_KERNEL
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "decoder_masked_groupedquery_attention_template.hpp"
#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h"
#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MGQA_LAUNCH_KERNEL( \
T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
dim3 grid(params.num_heads, params.batch_size); \
mmha::masked_groupedquery_attention_kernel<T, \
Dh, \
Dh_MAX, \
THDS_PER_KEY, \
THDS_PER_VALUE, \
THDS_PER_BLOCK, \
HAS_BEAMS><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16;
int tlength = params.timestep;
if (params.cache_indir == nullptr) {
if (tlength < 32) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
}
else if (tlength < 2048) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream);
}
else {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream);
}
}
else {
if (tlength < 32) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream);
}
else if (tlength < 2048) {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream);
}
else {
MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream);
}
}
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template void mgqa_launch_kernel<float, 144, 256, GroupedQuery_attention_params<float>>(
const GroupedQuery_attention_params<float>& params, const cudaStream_t& stream);
template void mgqa_launch_kernel<uint16_t, 144, 256, GroupedQuery_attention_params<uint16_t>>(
const GroupedQuery_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, GroupedQuery_attention_params<__nv_bfloat16>>(
const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream);
#endif
#ifdef ENABLE_FP8
template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>(
const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream);
#endif

#undef MGQA_LAUNCH_KERNEL
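As a worked check of the THREADS_PER_VALUE formula used above (Dh_MAX * sizeof(T) / 16, i.e. each thread covers a 16-byte slice of one value row), here for the Dh_MAX = 256 instantiations in this file; the static_asserts are illustrative, not part of the PR:

// Illustrative arithmetic only, not code from this PR.
#include <cstdint>
static_assert(256 * sizeof(uint16_t) / 16 == 32, "fp16, Dh_MAX = 256: 32 threads per value");
static_assert(256 * sizeof(float)    / 16 == 64, "fp32, Dh_MAX = 256: 64 threads per value");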