diff --git a/.vscode/settings.json b/.vscode/settings.json index b43758f14..bb1913955 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -90,6 +90,8 @@ "__nullptr": "cpp", "__string": "cpp", "compare": "cpp", - "concepts": "cpp" + "concepts": "cpp", + "filesystem": "cpp", + "__memory": "cpp" } } diff --git a/CMakeLists.txt b/CMakeLists.txt index 9289d3b5f..acfb53e4a 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -393,6 +393,7 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt index ef9a0cced..816bcc5e8 100644 --- a/src/fastertransformer/kernels/CMakeLists.txt +++ b/src/fastertransformer/kernels/CMakeLists.txt @@ -15,6 +15,7 @@ cmake_minimum_required(VERSION 3.8) add_subdirectory(cutlass_kernels) +add_subdirectory(llama) add_library(image_shift_partition_kernels image_shift_partition_kernels.cu) set_property(TARGET image_shift_partition_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt new file mode 100644 index 000000000..07fa20a03 --- /dev/null +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +set(decoder_masked_groupedquery_attention_files + decoder_masked_groupedquery_attention.cu +) +file(GLOB decoder_masked_groupedquery_attention_files ${decoder_masked_groupedquery_attention_files} ./decoder_masked_groupedquery_attention/*.cu) +add_library(decoder_masked_groupedquery_attention STATIC ${decoder_masked_groupedquery_attention_files}) +set_property(TARGET decoder_masked_groupedquery_attention PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET decoder_masked_groupedquery_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu new file mode 100644 index 000000000..1ec0b3d53 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h"
+#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp"
+#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h"
+#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
+#include <assert.h>
+#include <float.h>
+#include <type_traits>
+
+template<typename T, typename KERNEL_PARAMS_TYPE>
+void groupedquery_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
+{
+    switch (params.hidden_size_per_head) {
+        case 32:
+            mgqa_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 48:
+            mgqa_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 64:
+            mgqa_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 80:
+            mgqa_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 96:
+            mgqa_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 128:
+            mgqa_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 144:
+            mgqa_launch_kernel<T, 144, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 160:
+            mgqa_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 192:
+            mgqa_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 224:
+            mgqa_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        case 256:
+            mgqa_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
+            break;
+        default:
+            assert(false);
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<float>& params, const cudaStream_t& stream)
+{
+    groupedquery_attention_<float, Masked_groupedquery_attention_params<float>>(params, stream);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<uint16_t>& params, const cudaStream_t& stream)
+{
+    groupedquery_attention_<uint16_t, Masked_groupedquery_attention_params<uint16_t>>(params, stream);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_BF16
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params,
+                                   const cudaStream_t& stream)
+{
+    groupedquery_attention_<__nv_bfloat16, Masked_groupedquery_attention_params<__nv_bfloat16>>(params, stream);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#ifdef ENABLE_FP8
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params,
+                                   const cudaStream_t& stream)
+{
+    groupedquery_attention_<__nv_fp8_e4m3, Masked_groupedquery_attention_params<__nv_fp8_e4m3>>(params, stream);
+}
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h
new file mode 100644
index 000000000..b0968519f
--- /dev/null
+++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h"
+#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h"
+#include "src/fastertransformer/utils/cuda_bf16_wrapper.h"
+#include "src/fastertransformer/utils/cuda_fp8_utils.h"
+#include <cuda_fp16.h>
+#include <cuda_runtime_api.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+template<typename T>
+struct GroupedQuery_attention_params: public Multihead_attention_params_base<T> {
+    // allows to exit attention early
+    bool* finished = nullptr;
+    int num_kv_heads = 0;
+    // required in case of masked attention with different length
+    const int* length_per_sample = nullptr;
+};
+
+template<typename T>
+using Masked_groupedquery_attention_params = GroupedQuery_attention_params<T>;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<float>& params, const cudaStream_t& stream);
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<uint16_t>& params, const cudaStream_t& stream);
+#ifdef ENABLE_BF16
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params,
+                                   const cudaStream_t& stream);
+#endif
+#ifdef ENABLE_FP8
+void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params,
+                                   const cudaStream_t& stream);
+#endif
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu
new file mode 100644
index 000000000..3c96b2ce5
--- /dev/null
+++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION.  All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu new file mode 100644 index 000000000..7e20bdccc --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
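Every decoder_masked_groupedquery_attention_*.cu file added in this patch repeats the same launcher; only Dh/Dh_MAX and the explicit instantiations at the bottom change. A hedged reconstruction of the launcher body for the Dh = 128 variant follows, modeled on the upstream FasterTransformer multi-head attention launcher: the template-parameter lists given to mmha::smem_size_in_bytes and mmha::masked_groupedquery_attention_kernel are assumptions, while the branch thresholds, block sizes, grid shape and the THREADS_PER_VALUE rule are taken from the patch.

#define MGQA_LAUNCH_KERNEL(                                                                                            \
    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream)                                    \
    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);                              \
    dim3   grid(params.num_heads, params.batch_size);                                                                  \
    mmha::masked_groupedquery_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS> \
        <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    // Each thread loads 16 bytes of V, so THREADS_PER_VALUE == Dh_MAX * sizeof(T) / 16.
    constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
    const int     tlength           = params.timestep;
    // One block per (head, batch) pair; short sequences use more threads per key and smaller blocks.
    if (params.cache_indir == nullptr) {  // no beam-search indirection
        if (tlength < 32) {
            MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
        }
        else if (tlength < 2048) {
            MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream);
        }
        else {
            MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream);
        }
    }
    else {  // beam search present: instantiate with HAS_BEAMS = true
        if (tlength < 32) {
            MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream);
        }
        else if (tlength < 2048) {
            MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream);
        }
        else {
            MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream);
        }
    }
}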
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu new file mode 100644 index 000000000..57c6dd1aa --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu new file mode 100644 index 000000000..d8c349cad --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu new file mode 100644 index 000000000..03ff2cadd --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu @@ -0,0 +1,85 
@@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu new file mode 100644 index 000000000..fe496d4a7 --- /dev/null +++ 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu new file mode 100644 index 000000000..ceeb96484 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git 
a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu new file mode 100644 index 000000000..f225bef82 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const 
cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu new file mode 100644 index 000000000..7a9679952 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( 
+ const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu new file mode 100644 index 000000000..8af12155f --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu @@ -0,0 +1,85 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void 
mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu new file mode 100644 index 000000000..f91209194 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_groupedquery_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + 
const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp new file mode 100644 index 000000000..581d566ca --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp @@ -0,0 +1,1878 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +// #define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. 
At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. +// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_m_ { +}; + +template<> +struct Qk_vec_m_ { + using Type = float; +}; +template<> +struct Qk_vec_m_ { + using Type = float2; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint2; +}; +template<> +struct Qk_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_m_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 32> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 64> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 128> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 256> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_k_ { + using Type = typename Qk_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 32> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 64> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 128> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 256> { + using Type = float4; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_m_ { +}; + +template<> +struct K_vec_m_ { + using Type = float; +}; +template<> +struct K_vec_m_ { + using Type = float2; +}; +template<> +struct K_vec_m_ { + using Type = float4; +}; +template<> +struct K_vec_m_ { + using Type = uint32_t; +}; +template<> +struct K_vec_m_ { + using Type = uint2; +}; +template<> +struct K_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_m_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes +#ifdef ENABLE_FP8 +template<> +struct K_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct K_vec_m_<__nv_fp8_e4m3, 2> { + using Type = fp8_4_t; +}; // Defined for 
compilation-purpose only, do not use +template<> +struct K_vec_m_<__nv_fp8_e4m3, 1> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_k_ { + using Type = typename K_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct K_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct K_vec_k_<__nv_fp8_e4m3, 2> { + using Type = float4; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_k_<__nv_fp8_e4m3, 1> { + using Type = float4; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_m_ { +}; + +template<> +struct V_vec_m_ { + using Type = float; +}; +template<> +struct V_vec_m_ { + using Type = float2; +}; +template<> +struct V_vec_m_ { + using Type = float4; +}; +template<> +struct V_vec_m_ { + using Type = uint32_t; +}; +template<> +struct V_vec_m_ { + using Type = uint2; +}; +template<> +struct V_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_m_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template<> +struct V_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 8> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 16> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_k_ { + using Type = typename V_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct V_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 8> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 16> { + using Type = float4; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef 
ENABLE_FP8 +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct K_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct K_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif // MMHA_USE_FP32_ACUM_FOR_FMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +// template<> +// struct V_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct V_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} +#ifdef ENABLE_FP8 +// fp8_t +template<> +__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) +{ + return float(a); +} +template<> +__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) +{ + return __nv_fp8_e4m3(a); +} +// fp8_2_t +template<> +__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) +{ + return float2(a); +} +template<> +__inline__ __device__ fp8_2_t vec_conversion(const float2& a) +{ + return fp8_2_t(a); +} +// fp8_4_t +template<> +__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) +{ + return float4(a); +} +template<> +__inline__ __device__ fp8_4_t vec_conversion(const float4& a) +{ + return fp8_4_t(a); +} +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline 
__device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x), "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. 
+ return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
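+// Illustrative helper, not referenced by the kernel below: with grouped-query attention, every
+// head_n_rep = num_heads / num_kv_heads query heads share a single K/V head. The kernel computes
+// this mapping inline from blockIdx.x; the function below only spells out the arithmetic (the name
+// kv_head_index is illustrative, assuming num_heads is a multiple of num_kv_heads).
+inline __device__ __host__ int kv_head_index(int query_head, int num_heads, int num_kv_heads)
+{
+    // E.g. num_heads = 32, num_kv_heads = 8 -> head_n_rep = 4, so query heads 0..3 read KV head 0,
+    // query heads 4..7 read KV head 1, and so on.
+    const int head_n_rep = num_heads / num_kv_heads;
+    return query_head / head_n_rep;
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+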
+#ifdef ENABLE_FP8 +inline __device__ void convert_from_float(fp8_4_t& dst, float4 src) +{ + dst = fp8_4_t(src); +} +inline __device__ void convert_from_float(fp8_2_t& dst, float2 src) +{ + dst = fp8_2_t(src); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off 
+inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct kernel_type_t { + using Type = T; +}; + +#ifdef ENABLE_FP8 +template<> +struct kernel_type_t<__nv_fp8_e4m3> { + using Type = float; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + using Tk = typename kernel_type_t::Type; + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TDOD + logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk); + } + + // The max. + return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? 
uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool HAS_BEAMS> +__global__ void masked_groupedquery_attention_kernel(GroupedQuery_attention_params params) +{ + using Tk = typename kernel_type_t::Type; +#ifdef ENABLE_FP8 + // FP8 MHA Scales + constexpr bool FP8_MHA_KERNEL = std::is_same::value; +#else + constexpr bool FP8_MHA_KERNEL = false; +#endif + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += div_up(max_timesteps + 1, 4) * 16; + } + Tk* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + Tk* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec_k = typename Qk_vec_k_::Type; // with kernel-used precision + using Qk_vec_m = typename Qk_vec_m_::Type; // with memory-used precision + + // Use alignment for safely casting the shared buffers as Qk_vec_k. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX]; + + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. 
+ constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + const int head_n_rep = params.num_heads / params.num_kv_heads; + // const int head_n_rep = 1; + // The head. + const int hi = blockIdx.x; + const int kvhi = hi / head_n_rep; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + const int bkvhi = bi * params.num_kv_heads + kvhi; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbkvhi = bbi * params.beam_width * params.num_kv_heads + kvhi; + // The thread in the block. + const int tidx = threadIdx.x; + + constexpr bool handle_kv = true; + + // here. + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + int tlength = (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + + const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; + const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + + // Trigger the loads from the Q and K buffers. + Qk_vec_k q; + zero(q); + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto q_scaling = params.qkv_scale_out[0]; + const auto q_quant = + *reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]); + + convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); + } + else { + q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); + } + } + + Qk_vec_k k; + zero(k); + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); + } + else { + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + k; + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec_k q_bias; + zero(q_bias); + q_bias = + (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + vec_conversion(*reinterpret_cast(¶ms.q_bias[qk_bias_offset])) : + q_bias; + + Qk_vec_k k_bias; + zero(k_bias); + if (handle_kv) { + k_bias = + !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? 
+ vec_conversion(*reinterpret_cast(¶ms.k_bias[qk_bias_offset])) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (handle_kv) { + k = add(k, k_bias); + } + if (do_ia3 && !is_masked) { + k = mul( + k, + vec_conversion(*reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]))); + } + + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_kv) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_kv) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_kv) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_kv) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. 
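+        // Concretely (taking FP16 as an example, so x = QK_ELTS_IN_16B = 8): the offset below walks
+        // the [B, H_kv, Dh/x, L, x] cache as bkvhi * memory_max_len * Dh for the batch / KV-head
+        // block, plus co * memory_max_len * x for the 16B chunk of the head dimension, plus
+        // tlength_circ * x for the (circular) timestep, plus ci for the position inside the chunk.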
+ // int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // // params.timestep*QK_ELTS_IN_16B + + // tlength_circ * QK_ELTS_IN_16B + ci; + int offset = bkvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (handle_kv && bhi%head_n_rep==0) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec_k; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. + + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec_k = typename K_vec_k_::Type; + using K_vec_m = typename K_vec_m_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec_k q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. 
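+    // Because the cache is sized per KV head, the base pointer below is built from the beam-aware
+    // KV-head index (bbkvhi), so every query head of the same group walks the same cached keys.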
+ // T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbkvhi * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // The keys loaded from the key cache. + K_vec_k k[K_VECS_PER_THREAD]; + K_vec_k k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (HAS_BEAMS) { + // const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + const int beam_offset = beam_indices[ti_circ] * params.num_kv_heads * params.memory_max_len * Dh; + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + } + else { + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + if (params.linear_bias_slopes != nullptr) { + // Apply the linear position bias: (ki - qi) * slope[hi]. + // The padding token locates between the input context and the generated tokens. + // We need to remove the number of padding tokens in the distance computation. + // ti : 0 1 2 3 4 5 6 7 8 9(tlength) + // token: i i i i p p p o o o where i=input, p=pad, o=output. + // e.g. ti = 2, dist = (9 - 3) - 2 = 4. + int max_context_length = params.max_prefix_prompt_length + params.max_input_length; + float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; + + qk += mul(params.linear_bias_slopes[hi], dist); + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). 
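+// For example, with THREADS_PER_KEY = 4 only lanes 0, 4, 8, ... hold fresh qk_max values, and the
+// XOR shuffles with strides 16, 8 and 4 below combine exactly those lanes, so the loop can stop at
+// THREADS_PER_KEY instead of going all the way down to 1.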
+#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; +#ifdef FP8_MHA + float logit = 0.f; + if (FP8_MHA_KERNEL) { + logit = is_mask ? 0.f : + __expf((qk_smem[ti - first_step] - qk_max) * params.query_weight_output_scale[0] + * params.query_weight_output_scale[0]); + } + else { + logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + } +#else + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); +#endif + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec_k = typename V_vec_k_::Type; + using V_vec_m = typename V_vec_m_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + // The base pointer for the value in the cache buffer. + // if (bkvhi == 63) { + // printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); + // } + T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec_k v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = vec_conversion( + *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. 
+ __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec_k; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + + // Separate the ti < memory_max_len and ti > memory_max_len + // to prevent ti % memory_len when ti < memory_len, and + // the compiler cannot optimize the codes automatically. + const int min_length = min(tlength, params.memory_max_len); + for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) { + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; + // const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + if (ti < params.memory_max_len) { + // handled by previous loop + continue; + } + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + } + + // One group of threads computes the product(s) for the current timestep. + // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec_k v; + // Trigger the loads from the V buffer. 
+ const auto v_offset = qkv_base_offset + vi; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto v_scaling = params.qkv_scale_out[2]; + const auto v_quant = + *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + + convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); + } + else { + v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); + } + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + + // Compute the V values with bias. + v = add(v, v_bias); + + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + if (bhi % head_n_rep == 0) { + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS + // out = fma(logits_smem[params.timestep], v, out); +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = mul(1.0f / params.attention_qk_scale[0], logits_smem[tlength]); + logit = logits_smem[tlength - first_step]; + } + else { + logit = logits_smem[tlength - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + out = fma(logits_smem[tlength - first_step], v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + // Output the final values. 
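+    // Note that the output below is indexed by the query head (bhi), not the KV head: every query
+    // head produces its own context vector even though the heads of a group share K and V.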
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + if (FP8_MHA_KERNEL) { +#ifdef FP8_MHA + // float result_scale = params.attention_qk_scale[0] * params.query_weight_output_scale[0] * + // params.attention_output_weight_input_scale_inv[0]; + float result_scale = + params.query_weight_output_scale[0] * params.attention_output_weight_input_scale_inv[0]; + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), + mul(result_scale, out)); +#endif // FP8_MHA + } + else if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + out = mul(*params.attention_out_scale, out); + *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = + cast_to_int8(out); + } + else { + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); + } +#else // MMHA_USE_FP32_ACUM_FOR_OUT + // TODO: support int8_mode? + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = vec_conversion(out); +#endif // MMHA_USE_FP32_ACUM_FOR_OUT + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct threads_per_value_t { + static const int value = Dh_MAX * sizeof(T) / 16; +}; +#ifdef ENABLE_FP8 +template +struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> { + static const int value = Dh_MAX * 4 / 16; // DEBUG: float v +}; +#endif + +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d0fb0a197..4f0df238e 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1698,6 +1698,34 @@ __global__ void transpose_4d_batch_major_k_cache( } } +template +__global__ void transpose_4d_batch_major_k_cache( + T* k_dst, const T* k_src, const int head_n_rep, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + auto key_src = reinterpret_cast(k_src + batch_id * head_n_rep * kv_head_num * size_per_head * seq_len + + head_id * head_n_rep * size_per_head * seq_len); + auto key_dst = reinterpret_cast(k_dst + batch_id * kv_head_num * size_per_head * max_seq_len + + head_id * size_per_head * max_seq_len); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + if (out_idx >= size_per_head_div_x * max_seq_len) { + return; + } + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + idx = (idx - k_seq_len_id) / max_seq_len; + const int k_head_size_id = idx % size_per_head_div_x; + + if (k_seq_len_id < seq_len) { + key_dst[out_idx] = key_src[k_seq_len_id * size_per_head_div_x + k_head_size_id]; + } +} + template __global__ void transpose_4d_batch_major_v_cache( T* v_dst, const T* v_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len) @@ -1724,6 +1752,32 @@ __global__ void transpose_4d_batch_major_v_cache( val_dst[idx] = val_src[idx]; } +template +__global__ void transpose_4d_batch_major_v_cache( + T* v_dst, const T* v_src, const int head_n_rep, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + + // 16 byte loads will handle "x" dimension + auto val_src = reinterpret_cast(v_src + batch_id * kv_head_num * head_n_rep * size_per_head * seq_len + + head_id * head_n_rep * size_per_head * seq_len); + auto val_dst = reinterpret_cast(v_dst + batch_id * kv_head_num * size_per_head * max_seq_len + + head_id * size_per_head * max_seq_len); + + // idx is over output dimension L * size_per_head / x for values + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + const int size_per_head_div_x = size_per_head / X_ELEMS; + + if (idx >= size_per_head_div_x * seq_len) { + return; + } + + val_dst[idx] = val_src[idx]; +} + template void invokeTranspose4dBatchMajor(T* k_dst, T* v_dst, @@ -1749,6 +1803,33 @@ void invokeTranspose4dBatchMajor(T* k_dst, v_dst, v_src, local_head_num, size_per_head, seq_len, max_seq_len); } +template +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const int local_kv_head_num, + cudaStream_t stream) +{ + constexpr int block_sz = 128; + constexpr int x = (sizeof(T) == 4) ? 
4 : 8; + int size = max_seq_len * size_per_head / x; + int head_n_rep = local_head_num / local_kv_head_num; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + + transpose_4d_batch_major_k_cache<<>>( + k_dst, k_src, head_n_rep, local_kv_head_num, size_per_head, seq_len, max_seq_len); + + transpose_4d_batch_major_v_cache<<>>( + v_dst, v_src, head_n_rep, local_kv_head_num, size_per_head, seq_len, max_seq_len); +} + #define INSTANTIATETRANSPOSE4DBATCHMAJOR(T) \ template void invokeTranspose4dBatchMajor(T* k_dst, \ T* v_dst, \ @@ -1759,6 +1840,17 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int max_seq_len, \ const int size_per_head, \ const int local_head_num, \ + cudaStream_t stream); \ + template void invokeTranspose4dBatchMajor(T* k_dst, \ + T* v_dst, \ + const T* k_src, \ + const T* v_src, \ + const int local_batch_size, \ + const int seq_len, \ + const int max_seq_len, \ + const int size_per_head, \ + const int local_head_num, \ + const int local_kv_head_num, \ cudaStream_t stream) INSTANTIATETRANSPOSE4DBATCHMAJOR(float); INSTANTIATETRANSPOSE4DBATCHMAJOR(half); diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h index 7ac7604d4..569c40f81 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h @@ -189,6 +189,19 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int local_head_num, cudaStream_t stream); +template +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const int local_kv_head_num, + cudaStream_t stream); + template void invokeAddRelativeAttentionBias(T* qk_buf, const T* relative_attention_bias, diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt index 628b3083a..60bbcffba 100644 --- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt +++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt @@ -42,7 +42,7 @@ target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasM add_library(LlamaDecoderSelfAttentionLayer STATIC LlamaDecoderSelfAttentionLayer.cc) set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) +target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_groupedquery_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) add_library(LlamaContextAttentionLayer STATIC LlamaContextAttentionLayer.cc) set_property(TARGET LlamaContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 91de7d46d..f0a2076d9 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ 
b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -43,9 +43,10 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten // output_tensors: // hidden_features [token_num, hidden_dimension] - // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] - // value_cache [batch, local_head_num, max_seq_len, size_per_head] - printf("LlamaContextAttentionLayer::forward\n"); + // key_cache [batch, local_kv_head_num, size_per_head // x, max_seq_len, x] + // value_cache [batch, local_kv_head_num, max_seq_len, size_per_head] + printf("LlamaContextAttentionLayer::forward at layer: %d is_final: %d\n", input_tensors->getVal("layer_id"), input_tensors->at("is_final_layer").getVal()); + printf("is_free_buffer_after_forward_: %d\n", is_free_buffer_after_forward_); FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); FT_CHECK(output_tensors->at("value_cache").shape.size() == 4 @@ -356,11 +357,12 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten max_seq_len, size_per_head_, local_head_num_, + local_kv_head_num_, stream_); // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) sync_check_cuda_error(); - + printf("invokeTranspose4dBatchMajor done\n"); // TODO: fmha kernels doesn't support different seq lengths of q and kv if (attention_type == AttentionType::FUSED_MHA) { dispatcher_fp16->setup_causal_masked_fmha(request_seq_len, request_batch_size); diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index 5d12ff9a4..ed53c22d3 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -15,7 +15,7 @@ */ #include "src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/utils/logger.h" #include "src/fastertransformer/utils/memory_utils.h" #include "src/fastertransformer/kernels/repeat_kv_kernels.h" @@ -47,6 +47,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, const int beam_width, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const bool neox_rotary_style, @@ -70,7 +71,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, { using DataType = typename SATypeConverter::Type; // Prepare the parameters. 
- Masked_multihead_attention_params params; + Masked_groupedquery_attention_params params; memset(¶ms, 0, sizeof(params)); int hidden_units = head_num * size_per_head; if (qkv_bias != nullptr) { @@ -112,6 +113,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation params.timestep = step + max_prefix_prompt_length - 1; params.num_heads = head_num; + params.num_kv_heads = kv_head_num; params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; params.neox_rotary_style = neox_rotary_style; @@ -142,7 +144,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, } PUSH_RANGE("scaled dot-product fusion"); - masked_multihead_attention(params, stream); + masked_groupedquery_attention(params, stream); POP_RANGE; } @@ -160,6 +162,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, \ const int beam_width, \ const int head_num, \ + const int kv_head_num, \ const int size_per_head, \ const int rotary_embedding_dim, \ const bool neox_rotary_style, \ @@ -629,6 +632,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* output_tens batch_size, beam_width, local_head_num_, + local_kv_head_num_, size_per_head_, rotary_embedding_dim_, neox_rotary_style_, diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc index 32b022820..f7e892998 100644 --- a/src/fastertransformer/models/llama/Llama.cc +++ b/src/fastertransformer/models/llama/Llama.cc @@ -104,7 +104,7 @@ void Llama::allocateBuffer( FT_LOG_DEBUG(__PRETTY_FUNCTION__); const size_t batchxbeam = batch_size * beam_width; const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_cache_seq_len - * hidden_units_ / tensor_para_.world_size_; + * kv_head_num_ * size_per_head_ / tensor_para_.world_size_; if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ = @@ -597,13 +597,13 @@ void Llama::forward(std::unordered_map* output_ten const std::vector self_k_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, - local_head_num_, + local_kv_head_num_, size_per_head_ / (16 / sizeof(T)), max_cache_seq_len, 16 / sizeof(T)}; const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, - local_head_num_, + local_kv_head_num_, max_cache_seq_len, size_per_head_}; diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index b2402fdcd..de8c666c8 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -444,7 +444,7 @@ void LlamaContextDecoder::forward(std::unordered_map* ite_cache_offset *= *t; } cache_offset += ite_cache_offset; - + printf("cache_offset: %d\n", cache_offset); T* k_cache_ptr = use_shared_contexts ? k_cache_layer_ : k_cache.getPtrWithOffset(cache_offset); T* v_cache_ptr = use_shared_contexts ? v_cache_layer_ : v_cache.getPtrWithOffset(cache_offset); @@ -478,13 +478,13 @@ void LlamaContextDecoder::forward(std::unordered_map* } #endif - + printf("use_shared_contexts: %d\n", use_shared_contexts); if (use_shared_contexts) { // Even with local batches, we must process the whole K/V caches as any // element in batch_idx_to_compact_idx may reference the local batch // we're processing. 
We also need to discard references that aren't in // that particular local batch. - const size_t cache_stride_per_batch = hidden_units_ / tensor_para_.world_size_ * max_seq_len; + const size_t cache_stride_per_batch = kv_head_num_ * size_per_head_ / tensor_para_.world_size_ * max_seq_len; const size_t cache_layer_offset = (l - getFirstLayerParallelId()) * request_batch_size * cache_stride_per_batch; invokeUnCompactCaches(k_cache.getPtrWithOffset(cache_layer_offset), @@ -493,7 +493,7 @@ void LlamaContextDecoder::forward(std::unordered_map* v_cache_layer_, input_tensors->at("batch_to_compact_idx").getPtr(), request_batch_size, // batch_size (uncompact) - v_cache.shape[2], // local_head_num + v_cache.shape[2], // local_kv_head_num max_seq_len, seq_len, size_per_head_, @@ -572,7 +572,6 @@ void LlamaContextDecoder::forward(std::unordered_map* } sync_check_cuda_error(); - #define ENABLE_FLEX_DEBUG #ifdef ENABLE_FLEX_DEBUG if (l == 1) { printf("%d %d: %d %d\n", l, ite, h_token_num, hidden_units_); diff --git a/src/fastertransformer/models/llama/LlamaDecoder.cc b/src/fastertransformer/models/llama/LlamaDecoder.cc index c82de8568..cb4c2f623 100644 --- a/src/fastertransformer/models/llama/LlamaDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaDecoder.cc @@ -241,6 +241,7 @@ void LlamaDecoder::forward(std::unordered_map* for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { self_k_cache_size.push_back(*t); } + #define ENABLE_FLEX_DEBUG #ifdef ENABLE_FLEX_DEBUG printf("self_k_cache_size: "); for (int i=0; i