From 1205addae11f00263e4dc8386c30a738501d16cd Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 21:25:07 -0700 Subject: [PATCH 001/135] commit --- src/fastertransformer/models/llama/Llama.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/models/llama/Llama.cc b/src/fastertransformer/models/llama/Llama.cc index 32b022820..f7e892998 100644 --- a/src/fastertransformer/models/llama/Llama.cc +++ b/src/fastertransformer/models/llama/Llama.cc @@ -104,7 +104,7 @@ void Llama::allocateBuffer( FT_LOG_DEBUG(__PRETTY_FUNCTION__); const size_t batchxbeam = batch_size * beam_width; const size_t self_cache_size = (num_layer_ / pipeline_para_.world_size_) * batchxbeam * max_cache_seq_len - * hidden_units_ / tensor_para_.world_size_; + * kv_head_num_ * size_per_head_ / tensor_para_.world_size_; if (vocab_size_ != vocab_size_padded_) { padded_embedding_kernel_ = @@ -597,13 +597,13 @@ void Llama::forward(std::unordered_map* output_ten const std::vector self_k_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, - local_head_num_, + local_kv_head_num_, size_per_head_ / (16 / sizeof(T)), max_cache_seq_len, 16 / sizeof(T)}; const std::vector self_v_cache_shape = {num_layer_ / pipeline_para_.world_size_, batch_size * beam_width, - local_head_num_, + local_kv_head_num_, max_cache_seq_len, size_per_head_}; From 088bebb8718e3d05221e1539e3b9b1c6ac364041 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 21:28:43 -0700 Subject: [PATCH 002/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 91de7d46d..06fa92eb4 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -43,8 +43,8 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten // output_tensors: // hidden_features [token_num, hidden_dimension] - // key_cache [batch, local_head_num, size_per_head // x, max_seq_len, x] - // value_cache [batch, local_head_num, max_seq_len, size_per_head] + // key_cache [batch, local_kv_head_num, size_per_head // x, max_seq_len, x] + // value_cache [batch, local_kv_head_num, max_seq_len, size_per_head] printf("LlamaContextAttentionLayer::forward\n"); FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); From ef8e9067792dda7f322972ccf10d69647dd4c25c Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 21:57:41 -0700 Subject: [PATCH 003/135] commit --- .../attention_layers/LlamaContextAttentionLayer.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 06fa92eb4..4028f2def 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -46,6 +46,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten // key_cache [batch, local_kv_head_num, size_per_head // x, max_seq_len, x] // value_cache [batch, local_kv_head_num, max_seq_len, size_per_head] 
printf("LlamaContextAttentionLayer::forward\n"); + printf("is_free_buffer_after_forward_: %d\n", is_free_buffer_after_forward_); FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); FT_CHECK(output_tensors->at("value_cache").shape.size() == 4 @@ -328,7 +329,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten k_buf_2_, v_buf_2_, param, // prefix prompt - qkv_buf_, + qkv_buf_tmp_, attention_weights->query_weight.bias, padding_offset, request_batch_size, @@ -355,7 +356,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten max_prompt_length + request_seq_len, // max input length + prefix prompt length max_seq_len, size_per_head_, - local_head_num_, + local_kv_head_num_, stream_); // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) @@ -727,9 +728,10 @@ void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, size_t seq qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, type_size * 3 * batch_size * seq_len * local_hidden_units_, true); size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_; qkv_buf_tmp_ = (T*)allocator_->reMalloc(qkv_buf_tmp_, type_size * batch_size * seq_len * local_qkv_size, true); - q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true); + // q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true); + q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * local_qkv_size, true); k_buf_2_ = q_buf_2_ + batch_size * seq_len * local_hidden_units_; - v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_hidden_units_; + v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_kv_head_num_ * size_per_head_; // save memory usage when using fmha if (allocate_qk_buf) { From 9eda6dfe41477006816d688c3788f99d23357c51 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:05:59 -0700 Subject: [PATCH 004/135] commit --- .../kernels/unfused_attention_kernels.cu | 92 +++++++++++++++++++ .../kernels/unfused_attention_kernels.h | 13 +++ .../LlamaContextAttentionLayer.cc | 2 +- 3 files changed, 106 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d0fb0a197..991cf4e9b 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1698,6 +1698,35 @@ __global__ void transpose_4d_batch_major_k_cache( } } +template +__global__ void transpose_4d_batch_major_k_cache( + T* k_dst, const T* k_src, const int head_num, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + + auto key_src = reinterpret_cast(k_src + batch_id * head_num * size_per_head * seq_len + + head_id * size_per_head * seq_len); + auto key_dst = reinterpret_cast(k_dst + batch_id * kv_head_num * size_per_head * max_seq_len + + head_id * size_per_head * max_seq_len); + + const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; + int size_per_head_div_x = size_per_head / X_ELEMS; + if (out_idx >= size_per_head_div_x * max_seq_len) { + return; + } + + int idx = out_idx; + const int k_seq_len_id = idx % max_seq_len; + idx = (idx - k_seq_len_id) / max_seq_len; + const int k_head_size_id = idx % size_per_head_div_x; + + if (k_seq_len_id < seq_len) { + key_dst[out_idx] = key_src[k_seq_len_id * size_per_head_div_x + k_head_size_id]; + } +} + template __global__ void transpose_4d_batch_major_v_cache( T* v_dst, const T* v_src, const int head_num, const int size_per_head, const int seq_len, const int max_seq_len) @@ -1724,6 +1753,32 @@ __global__ void transpose_4d_batch_major_v_cache( val_dst[idx] = val_src[idx]; } +template +__global__ void transpose_4d_batch_major_v_cache( + T* v_dst, const T* v_src, const int head_num, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) +{ + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + + // 16 byte loads will handle "x" dimension + auto val_src = reinterpret_cast(v_src + batch_id * head_num * size_per_head * seq_len + + head_id * size_per_head * seq_len); + auto val_dst = reinterpret_cast(v_dst + batch_id * kv_head_num * size_per_head * max_seq_len + + head_id * size_per_head * max_seq_len); + + // idx is over output dimension L * size_per_head / x for values + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + + constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; + const int size_per_head_div_x = size_per_head / X_ELEMS; + + if (idx >= size_per_head_div_x * seq_len) { + return; + } + + val_dst[idx] = val_src[idx]; +} + template void invokeTranspose4dBatchMajor(T* k_dst, T* v_dst, @@ -1749,6 +1804,32 @@ void invokeTranspose4dBatchMajor(T* k_dst, v_dst, v_src, local_head_num, size_per_head, seq_len, max_seq_len); } +template +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const int local_kv_head_num, + cudaStream_t stream) +{ + constexpr int block_sz = 128; + constexpr int x = (sizeof(T) == 4) ? 
4 : 8; + int size = max_seq_len * size_per_head / x; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + + transpose_4d_batch_major_k_cache<<>>( + k_dst, k_src, local_head_num, local_kv_head_num, size_per_head, seq_len, max_seq_len); + + transpose_4d_batch_major_v_cache<<>>( + v_dst, v_src, local_head_num, local_kv_head_num, size_per_head, seq_len, max_seq_len); +} + #define INSTANTIATETRANSPOSE4DBATCHMAJOR(T) \ template void invokeTranspose4dBatchMajor(T* k_dst, \ T* v_dst, \ @@ -1760,6 +1841,17 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int size_per_head, \ const int local_head_num, \ cudaStream_t stream) + template void invokeTranspose4dBatchMajor(T* k_dst, \ + T* v_dst, \ + const T* k_src, \ + const T* v_src, \ + const int local_batch_size, \ + const int seq_len, \ + const int max_seq_len, \ + const int size_per_head, \ + const int local_head_num, \ + const int local_kv_head_num, \ + cudaStream_t stream) INSTANTIATETRANSPOSE4DBATCHMAJOR(float); INSTANTIATETRANSPOSE4DBATCHMAJOR(half); #ifdef ENABLE_BF16 diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.h b/src/fastertransformer/kernels/unfused_attention_kernels.h index 7ac7604d4..569c40f81 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.h +++ b/src/fastertransformer/kernels/unfused_attention_kernels.h @@ -189,6 +189,19 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int local_head_num, cudaStream_t stream); +template +void invokeTranspose4dBatchMajor(T* k_dst, + T* v_dst, + const T* k_src, + const T* v_src, + const int local_batch_size, + const int seq_len, + const int max_seq_len, + const int size_per_head, + const int local_head_num, + const int local_kv_head_num, + cudaStream_t stream); + template void invokeAddRelativeAttentionBias(T* qk_buf, const T* relative_attention_bias, diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 4028f2def..e22aa3870 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -329,7 +329,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten k_buf_2_, v_buf_2_, param, // prefix prompt - qkv_buf_tmp_, + qkv_buf_, attention_weights->query_weight.bias, padding_offset, request_batch_size, From 98727db1ba83327c5fcfd6d355e07cf0e86398a5 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:06:49 -0700 Subject: [PATCH 005/135] commit --- src/fastertransformer/kernels/unfused_attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index 991cf4e9b..2866ff49b 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1840,7 +1840,7 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int max_seq_len, \ const int size_per_head, \ const int local_head_num, \ - cudaStream_t stream) + cudaStream_t stream); template void invokeTranspose4dBatchMajor(T* k_dst, \ T* v_dst, \ const T* k_src, \ From b5eb6cf4968e65b94200046cd84b004f11b63f16 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:07:23 -0700 Subject: [PATCH 006/135] 
commit --- src/fastertransformer/kernels/unfused_attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index 2866ff49b..e4f707033 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1840,7 +1840,7 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int max_seq_len, \ const int size_per_head, \ const int local_head_num, \ - cudaStream_t stream); + cudaStream_t stream); \ template void invokeTranspose4dBatchMajor(T* k_dst, \ T* v_dst, \ const T* k_src, \ From 62557b62105f70f2dcc24e9e1410dcb0d56c803d Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:08:45 -0700 Subject: [PATCH 007/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index e22aa3870..c61edc4a4 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -728,10 +728,9 @@ void LlamaContextAttentionLayer::allocateBuffer(size_t batch_size, size_t seq qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, type_size * 3 * batch_size * seq_len * local_hidden_units_, true); size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_; qkv_buf_tmp_ = (T*)allocator_->reMalloc(qkv_buf_tmp_, type_size * batch_size * seq_len * local_qkv_size, true); - // q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true); - q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * local_qkv_size, true); + q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true); k_buf_2_ = q_buf_2_ + batch_size * seq_len * local_hidden_units_; - v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_kv_head_num_ * size_per_head_; + v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_hidden_units_; // save memory usage when using fmha if (allocate_qk_buf) { From 0d73d631dd19aeab339f3eccc3ef8d15fdfb96a7 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:10:00 -0700 Subject: [PATCH 008/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index c61edc4a4..fd4073e5b 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -356,7 +356,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten max_prompt_length + request_seq_len, // max input length + prefix prompt length max_seq_len, size_per_head_, - local_kv_head_num_, + local_head_num_, stream_); // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) From 03087acbfc8032e447b41bdec1a268d2bcc37dd4 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:10:17 -0700 Subject: [PATCH 
009/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index fd4073e5b..525388421 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -357,6 +357,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten max_seq_len, size_per_head_, local_head_num_, + local_kv_head_num_, stream_); // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) From aff24080dc909cb5b6156e39ffa907a5992a10da Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:12:55 -0700 Subject: [PATCH 010/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 525388421..592821d23 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -45,7 +45,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten // hidden_features [token_num, hidden_dimension] // key_cache [batch, local_kv_head_num, size_per_head // x, max_seq_len, x] // value_cache [batch, local_kv_head_num, max_seq_len, size_per_head] - printf("LlamaContextAttentionLayer::forward\n"); + printf("LlamaContextAttentionLayer::forward at layer: %d\n", input_tensors->getVal("layer_id")); printf("is_free_buffer_after_forward_: %d\n", is_free_buffer_after_forward_); FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); From 5a2ebcb86df67bdc3b65491c6c356510e9a3ff19 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:16:34 -0700 Subject: [PATCH 011/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index 592821d23..b79ec47ea 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -362,7 +362,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten // IDEA : after this, k_cache = (batch_size, num_heads, Dh/x, prefix_prompt_len + L, x) // k_cache = (batch_size, num_heads, prefix_prompt_len + L, Dh) sync_check_cuda_error(); - + printf("invokeTranspose4dBatchMajor done\n"); // TODO: fmha kernels doesn't support different seq lengths of q and kv if (attention_type == AttentionType::FUSED_MHA) { dispatcher_fp16->setup_causal_masked_fmha(request_seq_len, request_batch_size); From d0271a382dff162b30e1889868e474f8003f28d0 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:17:00 -0700 Subject: [PATCH 012/135] commit --- .../layers/attention_layers/LlamaContextAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc index b79ec47ea..f0a2076d9 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaContextAttentionLayer.cc @@ -45,7 +45,7 @@ void LlamaContextAttentionLayer::forward(TensorMap* output_ten // hidden_features [token_num, hidden_dimension] // key_cache [batch, local_kv_head_num, size_per_head // x, max_seq_len, x] // value_cache [batch, local_kv_head_num, max_seq_len, size_per_head] - printf("LlamaContextAttentionLayer::forward at layer: %d\n", input_tensors->getVal("layer_id")); + printf("LlamaContextAttentionLayer::forward at layer: %d is_final: %d\n", input_tensors->getVal("layer_id"), input_tensors->at("is_final_layer").getVal()); printf("is_free_buffer_after_forward_: %d\n", is_free_buffer_after_forward_); FT_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); FT_CHECK(output_tensors->at("key_cache").shape.size() == 5); From 699c569051995ff49523ec41afc6d6729b8a2de7 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:20:11 -0700 Subject: [PATCH 013/135] commit --- src/fastertransformer/models/llama/LlamaContextDecoder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index b2402fdcd..6e0d449b5 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -444,7 +444,7 @@ void LlamaContextDecoder::forward(std::unordered_map* ite_cache_offset *= *t; } cache_offset += ite_cache_offset; - + printf("cache_offset: %d\n", cache_offset); T* k_cache_ptr = use_shared_contexts ? k_cache_layer_ : k_cache.getPtrWithOffset(cache_offset); T* v_cache_ptr = use_shared_contexts ? 
v_cache_layer_ : v_cache.getPtrWithOffset(cache_offset); From 44558272db6a2a48de28808e8d1a38a20ef916b2 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:21:57 -0700 Subject: [PATCH 014/135] commit --- src/fastertransformer/models/llama/LlamaContextDecoder.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 6e0d449b5..1b56eb550 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -478,7 +478,7 @@ void LlamaContextDecoder::forward(std::unordered_map* } #endif - + printf("use_shared_contexts: %d\n", use_shared_contexts); if (use_shared_contexts) { // Even with local batches, we must process the whole K/V caches as any // element in batch_idx_to_compact_idx may reference the local batch From 7398a6e33475c7ef4d5af1cf5321491f509bb8bc Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:27:30 -0700 Subject: [PATCH 015/135] commit --- src/fastertransformer/models/llama/LlamaContextDecoder.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 1b56eb550..7d2a8bbed 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -479,6 +479,7 @@ void LlamaContextDecoder::forward(std::unordered_map* } #endif printf("use_shared_contexts: %d\n", use_shared_contexts); + use_shared_contexts = false; if (use_shared_contexts) { // Even with local batches, we must process the whole K/V caches as any // element in batch_idx_to_compact_idx may reference the local batch From 027c697c388a930f9c45be32b35623ef0260529a Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:28:00 -0700 Subject: [PATCH 016/135] commit --- src/fastertransformer/models/llama/LlamaContextDecoder.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 7d2a8bbed..6bb465a07 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -479,8 +479,7 @@ void LlamaContextDecoder::forward(std::unordered_map* } #endif printf("use_shared_contexts: %d\n", use_shared_contexts); - use_shared_contexts = false; - if (use_shared_contexts) { + if (use_shared_contexts && false) { // Even with local batches, we must process the whole K/V caches as any // element in batch_idx_to_compact_idx may reference the local batch // we're processing. 
We also need to discard references that aren't in From 4da2aa222240a07d69c00441456736eb1d2d4c29 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:31:38 -0700 Subject: [PATCH 017/135] commit --- src/fastertransformer/models/llama/LlamaContextDecoder.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 6bb465a07..268dee769 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -479,12 +479,12 @@ void LlamaContextDecoder::forward(std::unordered_map* } #endif printf("use_shared_contexts: %d\n", use_shared_contexts); - if (use_shared_contexts && false) { + if (use_shared_contexts) { // Even with local batches, we must process the whole K/V caches as any // element in batch_idx_to_compact_idx may reference the local batch // we're processing. We also need to discard references that aren't in // that particular local batch. - const size_t cache_stride_per_batch = hidden_units_ / tensor_para_.world_size_ * max_seq_len; + const size_t cache_stride_per_batch = kv_head_num_ * size_per_head_ / tensor_para_.world_size_ * max_seq_len; const size_t cache_layer_offset = (l - getFirstLayerParallelId()) * request_batch_size * cache_stride_per_batch; invokeUnCompactCaches(k_cache.getPtrWithOffset(cache_layer_offset), @@ -493,7 +493,7 @@ void LlamaContextDecoder::forward(std::unordered_map* v_cache_layer_, input_tensors->at("batch_to_compact_idx").getPtr(), request_batch_size, // batch_size (uncompact) - v_cache.shape[2], // local_head_num + v_cache.shape[2], // local_kv_head_num max_seq_len, seq_len, size_per_head_, From 07c2f5a07c922a4ddf4d4f3d76b04b33b7999d1e Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:33:33 -0700 Subject: [PATCH 018/135] commit --- src/fastertransformer/models/llama/LlamaDecoder.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/models/llama/LlamaDecoder.cc b/src/fastertransformer/models/llama/LlamaDecoder.cc index c82de8568..cb4c2f623 100644 --- a/src/fastertransformer/models/llama/LlamaDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaDecoder.cc @@ -241,6 +241,7 @@ void LlamaDecoder::forward(std::unordered_map* for (auto t = k_cache.shape.begin() + 2; t != k_cache.shape.end(); ++t) { self_k_cache_size.push_back(*t); } + #define ENABLE_FLEX_DEBUG #ifdef ENABLE_FLEX_DEBUG printf("self_k_cache_size: "); for (int i=0; i Date: Mon, 4 Sep 2023 22:49:52 -0700 Subject: [PATCH 019/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index 5d12ff9a4..fc104f254 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -47,6 +47,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, const int beam_width, const int head_num, + const int kv_head_num, const int size_per_head, const int rotary_embedding_dim, const bool neox_rotary_style, From 4d5bbe20b3be53dadded7cc64204d74498332e04 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:51:41 -0700 Subject: [PATCH 020/135] commit --- 
.../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index fc104f254..b3cf534f5 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -161,6 +161,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, const int inference_batch_size, \ const int beam_width, \ const int head_num, \ + const int kv_head_num, \ const int size_per_head, \ const int rotary_embedding_dim, \ const bool neox_rotary_style, \ From e8cceb48fc3b7c7de115708e1bced4f3bbbb845f Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 4 Sep 2023 22:52:20 -0700 Subject: [PATCH 021/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index b3cf534f5..c61a11bad 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -631,6 +631,7 @@ void LlamaDecoderSelfAttentionLayer::forward(TensorMap* output_tens batch_size, beam_width, local_head_num_, + local_kv_head_num_, size_per_head_, rotary_embedding_dim_, neox_rotary_style_, From a32c9d28880a5d7b2c416a2d8fce7db5456a7212 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 15:05:44 -0700 Subject: [PATCH 022/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index c61a11bad..a834ad1ee 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -15,7 +15,7 @@ */ #include "src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" #include "src/fastertransformer/utils/logger.h" #include "src/fastertransformer/utils/memory_utils.h" #include "src/fastertransformer/kernels/repeat_kv_kernels.h" From 1ceed7372822b3e0bf2eb0c67000ba2e65f56360 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 15:05:46 -0700 Subject: [PATCH 023/135] commit --- .../decoder_masked_multihead_attention.cu | 135 ++ .../decoder_masked_multihead_attention.h | 190 ++ .../decoder_masked_multihead_attention_128.cu | 103 + .../decoder_masked_multihead_attention_144.cu | 101 + .../decoder_masked_multihead_attention_160.cu | 101 + .../decoder_masked_multihead_attention_192.cu | 101 + .../decoder_masked_multihead_attention_224.cu | 101 + .../decoder_masked_multihead_attention_256.cu | 101 + .../decoder_masked_multihead_attention_32.cu | 101 + .../decoder_masked_multihead_attention_48.cu | 101 + .../decoder_masked_multihead_attention_64.cu | 101 + .../decoder_masked_multihead_attention_80.cu | 101 + 
.../decoder_masked_multihead_attention_96.cu | 101 + ...er_masked_multihead_attention_template.hpp | 1959 +++++++++++++++++ 14 files changed, 3397 insertions(+) create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_144.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_48.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu new file mode 100644 index 000000000..175bdf9a9 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +template +void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + switch (params.hidden_size_per_head) { + case 32: + mmha_launch_kernel(params, stream); + break; + case 48: + mmha_launch_kernel(params, stream); + break; + case 64: + mmha_launch_kernel(params, stream); + break; + case 80: + mmha_launch_kernel(params, stream); + break; + case 96: + mmha_launch_kernel(params, stream); + break; + case 128: + mmha_launch_kernel(params, stream); + break; + case 144: + mmha_launch_kernel(params, stream); + break; + case 160: + mmha_launch_kernel(params, stream); + break; + case 192: + mmha_launch_kernel(params, stream); + break; + case 224: + mmha_launch_kernel(params, stream); + break; + case 256: + mmha_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_FP8 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_FP8 +void cross_multihead_attention(const 
Cross_multihead_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_fp8_e4m3, Cross_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h new file mode 100644 index 000000000..5a768184c --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -0,0 +1,190 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + +template +struct Multihead_attention_params_base { + + // The output buffer. Dimensions B x D. + T* out = nullptr; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q = nullptr, *q_bias = nullptr; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k = nullptr, *k_bias = nullptr; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v = nullptr, *v_bias = nullptr; + + // The cache for the Ks. The size must be at least B x L x D. + T* k_cache = nullptr; + // The cache for the Vs. The size must be at least B x L x D. + T* v_cache = nullptr; + // The indirections to use for cache when beam sampling. + const int* cache_indir = nullptr; + + // scales + const float* query_weight_output_scale = nullptr; + const float* attention_qk_scale = nullptr; + const float* attention_output_weight_input_scale_inv = nullptr; + + // Stride to handle the case when KQV is a single buffer + int stride = 0; + + // The batch size. + int batch_size = 0; + // The beam width + int beam_width = 0; + // The sequence length. + int memory_max_len = 0; + // The number of heads (H). + int num_heads = 0; + // The hidden dimension per head (Dh). 
+ int hidden_size_per_head = 0; + // The per-head latent space reserved for rotary embeddings. + int rotary_embedding_dim = 0; + bool neox_rotary_style = false; + // The maximum length of input sentences. + int max_input_length = 0; + // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? + int timestep = 0; + // The current timestep of each sentences (support different timestep for different sentences) + + // The 1.f / sqrt(Dh). Computed on the host. + float inv_sqrt_dh = 0.0f; + + // Used when we have some input context like gpt + const int* total_padding_tokens = nullptr; + + const bool* masked_tokens = nullptr; + const int* prefix_prompt_lengths = nullptr; + int max_prefix_prompt_length = 0; + + const T* relative_attention_bias = nullptr; + int relative_attention_bias_stride = 0; + // The slope per head of linear position bias to attention score (H). + const T* linear_bias_slopes = nullptr; + + const T* ia3_key_weights = nullptr; + const T* ia3_value_weights = nullptr; + const int* ia3_tasks = nullptr; + + const float* qkv_scale_out = nullptr; + const float* attention_out_scale = nullptr; + int int8_mode = 0; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + // will need it here till if constexpr in c++17 + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +using Masked_multihead_attention_params = Multihead_attention_params; + +template +using Cross_multihead_attention_params = Multihead_attention_params; + +template +struct outputCrossAttentionParam { + // max decoder output length + int max_decoder_seq_len = 0; + T* cross_attention_out = nullptr; + bool is_return_cross_attentions = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream); +#endif +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const 
cudaStream_t& stream); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu new file mode 100644 index 000000000..9b4f7c393 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu @@ -0,0 +1,103 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +#include "decoder_masked_multihead_attention_template.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 128, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 128, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_144.cu new file mode 100644 index 000000000..0da2134e9 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_144.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 144, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 144, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 144, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 144, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git 
a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu new file mode 100644 index 000000000..86153f37a --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 160, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 160, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 160, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 160, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu new file mode 100644 index 000000000..12c6e22bf --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 192, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 192, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 192, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 192, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git 
a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu new file mode 100644 index 000000000..7b17ae7b7 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 224, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 224, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 224, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 224, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu new file mode 100644 index 000000000..e17fa03ae --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 256, 256, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 256, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 256, 256, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 256, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git 
a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu new file mode 100644 index 000000000..91ecc2f46 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 32, 32, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 32, 32, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 32, 32, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 32, 32, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_48.cu new file mode 100644 index 000000000..79bf3ca83 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_48.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 48, 64, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 48, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 48, 64, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 48, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL \ No newline at end of file 
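The per-head-size .cu files in this patch differ only in the Dh / Dh_MAX values they instantiate. For orientation, below is a hedged sketch of how a caller would select one of these instantiations at runtime; it assumes a dispatch on params.hidden_size_per_head analogous to the existing masked_multihead_attention entry point, with the launcher template declaration in scope. The wrapper name and the elided sizes are illustrative and not part of this patch.

template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    // Dh_MAX (second template argument) is the head size rounded up to the next
    // power of two, matching the explicit instantiations in the files above.
    switch (params.hidden_size_per_head) {
        case 32:  mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 48:  mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 64:  mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 80:  mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 96:  mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 160: mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 192: mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 224: mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 256: mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        // ... other supported head sizes (e.g. 128) follow the same pattern ...
        default: assert(false && "unsupported size_per_head");
    }
}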
diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu new file mode 100644 index 000000000..a4156e071 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 64, 64, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 64, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 64, 64, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 64, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu new file mode 100644 index 000000000..b94345952 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 80, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 80, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 80, 128, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 80, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git 
a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu new file mode 100644 index 000000000..6e754fd14 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_multihead_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + } + } + else { + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 96, 128, Masked_multihead_attention_params<__nv_bfloat16>>( + const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 96, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( + const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +template void mmha_launch_kernel>( + const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mmha_launch_kernel<__nv_bfloat16, 96, 128, Cross_multihead_attention_params<__nv_bfloat16>>( + const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mmha_launch_kernel<__nv_fp8_e4m3, 96, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( + const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp new file mode 100644 index 000000000..b9c1329f8 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -0,0 +1,1959 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#pragma once + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +// #define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. 
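// Worked illustration of the layouts described above (index arithmetic only, not
// used by the kernels): with the key cache laid out as [B, H, Dh/x, L, x] and
// x = 16 / sizeof(T) (8 for FP16, 4 for FP32), channel d of head h in sequence b
// at timestep t is stored at the linear offset
//
//     (((b * H + h) * (Dh / x) + d / x) * L + t) * x + (d % x)
//
// so each 16-byte chunk holds x consecutive channels of a single timestep, and the
// Q*K^T loop can fetch one key fragment per 16-byte load. The value cache
// [B, H, L, Dh] is addressed simply as ((b * H + h) * L + t) * Dh + d. Here L is
// memory_max_len and H is the number of heads held in the cache.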
+// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_m_ { +}; + +template<> +struct Qk_vec_m_ { + using Type = float; +}; +template<> +struct Qk_vec_m_ { + using Type = float2; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint2; +}; +template<> +struct Qk_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_m_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 32> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 64> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 128> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 256> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_k_ { + using Type = typename Qk_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 32> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 64> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 128> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 256> { + using Type = float4; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_m_ { +}; + +template<> +struct K_vec_m_ { + using Type = float; +}; +template<> +struct K_vec_m_ { + using Type = float2; +}; +template<> +struct K_vec_m_ { + using Type = float4; +}; +template<> +struct K_vec_m_ { + using Type = uint32_t; +}; +template<> +struct K_vec_m_ { + using Type = uint2; +}; +template<> +struct K_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_m_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes +#ifdef ENABLE_FP8 +template<> +struct K_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct K_vec_m_<__nv_fp8_e4m3, 2> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_m_<__nv_fp8_e4m3, 1> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_k_ { + using Type = typename K_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct K_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct K_vec_k_<__nv_fp8_e4m3, 2> { + using Type = float4; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_k_<__nv_fp8_e4m3, 1> { + 
using Type = float4; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_m_ { +}; + +template<> +struct V_vec_m_ { + using Type = float; +}; +template<> +struct V_vec_m_ { + using Type = float2; +}; +template<> +struct V_vec_m_ { + using Type = float4; +}; +template<> +struct V_vec_m_ { + using Type = uint32_t; +}; +template<> +struct V_vec_m_ { + using Type = uint2; +}; +template<> +struct V_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_m_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template<> +struct V_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 8> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 16> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_k_ { + using Type = typename V_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct V_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 8> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 16> { + using Type = float4; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + 
using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct K_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct K_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif // MMHA_USE_FP32_ACUM_FOR_FMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +// template<> +// struct V_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct V_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} +#ifdef ENABLE_FP8 +// fp8_t +template<> +__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) +{ + return float(a); +} +template<> +__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) +{ + return __nv_fp8_e4m3(a); +} +// fp8_2_t +template<> +__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) +{ + return float2(a); +} +template<> +__inline__ __device__ fp8_2_t vec_conversion(const float2& a) +{ + return fp8_2_t(a); +} +// fp8_4_t +template<> +__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) +{ + return float4(a); +} +template<> +__inline__ __device__ fp8_4_t vec_conversion(const float4& a) +{ + return fp8_4_t(a); +} +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. 
+ float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x), "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. 
+ return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
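// Usage note (illustrative sketch, not code from this patch): when
// MMHA_USE_FP32_ACUM_FOR_OUT is enabled, the value loop accumulates in float
// (V_vec_acum_fp32_) and one of the convert_from_float() overloads above narrows
// the accumulator back to the storage type when the output is written, e.g. for
// FP16 a Float8_ accumulator is packed into a uint4 (8 half values):
//
//     Float8_ out_acc = ...;                  // fp32 accumulator for 8 lanes
//     uint4   out_packed;
//     convert_from_float(out_packed, out_acc);
//     // out_packed is then stored to params.out for this head / thread slice.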
+#ifdef ENABLE_FP8 +inline __device__ void convert_from_float(fp8_4_t& dst, float4 src) +{ + dst = fp8_4_t(src); +} +inline __device__ void convert_from_float(fp8_2_t& dst, float2 src) +{ + dst = fp8_2_t(src); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off 
+inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct kernel_type_t { + using Type = T; +}; + +#ifdef ENABLE_FP8 +template<> +struct kernel_type_t<__nv_fp8_e4m3> { + using Type = float; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const Multihead_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + using Tk = typename kernel_type_t::Type; + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TDOD + logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(Tk) : + div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk); + } + + // The max. 
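    // Worked example (illustrative, assuming fp16 logits kept in shared memory, i.e.
    // MMHA_USE_FP32_ACUM_FOR_LOGITS not defined): for Tk = uint16_t, Dh = 128,
    // threads_per_value = 16, threads_per_block = 256, no cross attention and
    // timestep = memory_max_len = 1024 with no neox rotary,
    //   qk_sz      = div_up(1025, 4) * 16     = 257 * 16 = 4112 bytes,
    //   logits_sz  = div_up(1025, 4) * 4 * 2  = 2056 bytes,
    //   softmax_sz = 4112 + 2056              = 6168 bytes,
    //   red_sz     = (256 / 16) * 128 * 2 / 2 = 2048 bytes,
    // so the function returns max(6168, 2048) = 6168 bytes of dynamic shared memory.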
+ return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool DO_CROSS_ATTENTION, + bool HAS_BEAMS> +__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) +{ + + using Tk = typename kernel_type_t::Type; +#ifdef ENABLE_FP8 + // FP8 MHA Scales + constexpr bool FP8_MHA_KERNEL = std::is_same::value; +#else + constexpr bool FP8_MHA_KERNEL = false; +#endif + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += + (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + } + Tk* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + Tk* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec_k = typename Qk_vec_k_::Type; // with kernel-used precision + using Qk_vec_m = typename Qk_vec_m_::Type; // with memory-used precision + + // Use alignment for safely casting the shared buffers as Qk_vec_k. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX]; + + // This is one of the reasons we should have a separate kernel for cross attention + __shared__ __align__(sizeof(Qk_vec_k)) Tk bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; + + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. 
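    // Worked example (illustrative): assuming Qk_vec_m is a 16-byte vector for this
    // configuration, fp16 with Dh_MAX = 128 gives QK_VEC_SIZE = 16 / sizeof(uint16_t) = 8, so
    // the line below yields QK_VECS_PER_WARP = 128 / 8 = 16: sixteen threads cooperate to load
    // one Q (or K) head of 128 halves, one 16-byte chunk per thread.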
+ constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + // The head. + const int hi = blockIdx.x; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + // The thread in the block. + const int tidx = threadIdx.x; + + const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : + (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + + const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; + const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + + // Trigger the loads from the Q and K buffers. + Qk_vec_k q; + zero(q); + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto q_scaling = params.qkv_scale_out[0]; + const auto q_quant = + *reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]); + + convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); + } + else { + q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); + } + } + + Qk_vec_k k; + zero(k); + if (DO_CROSS_ATTENTION) { + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? 
+ vec_conversion(*reinterpret_cast(¶ms.k_cache[offset])) : + k; + } + else { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); + } + else { + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + k; + } + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec_k q_bias; + zero(q_bias); + q_bias = + (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + vec_conversion(*reinterpret_cast(¶ms.q_bias[qk_bias_offset])) : + q_bias; + + Qk_vec_k k_bias; + zero(k_bias); + if (handle_kv) { + k_bias = + !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? + vec_conversion(*reinterpret_cast(¶ms.k_bias[qk_bias_offset])) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (handle_kv) { + k = add(k, k_bias); + } + if (do_ia3 && !is_masked) { + k = mul( + k, + vec_conversion(*reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]))); + } + + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_kv) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_kv) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? 
QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_kv) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_kv) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Store Dh values of k_bias into smem, since will need to add later + // if params.timestep == 0 + if (DO_CROSS_ATTENTION && params.timestep == 0) { + *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; + } + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (handle_kv) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec_k; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. + + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. 
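    // Added note (illustrative): the xor-shuffle loop above reduces the per-lane partial dot
    // products across the QK_VECS_PER_WARP lanes that loaded Q and K. With QK_VECS_PER_WARP = 16
    // the masks are 8, 4, 2, 1, so after four steps every participating lane holds the complete
    // q.k for the current timestep; thread 0 has scaled it by inv_sqrt_dh, added the optional
    // relative-attention bias and written it to qk_smem[tlength - first_step]. The barrier below
    // publishes it to the whole block before the softmax over past timesteps begins.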
+ __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec_k = typename K_vec_k_::Type; + using K_vec_m = typename K_vec_m_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec_k q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + K_vec_k k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; + if (DO_CROSS_ATTENTION && params.timestep == 0) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. + T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // The keys loaded from the key cache. 
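        // Added note (illustrative, assuming fp16): QK_ELTS_IN_16B = 16 / sizeof(T) = 8, so the
        // cache layout [B, H, Dh/x, L, x] keeps 8 consecutive elements of the head dimension for
        // one timestep in a single 16-byte chunk. With Dh = 128 and memory_max_len = 1024, the
        // chunk holding elements 8..15 of timestep 5 of a given (batch, head) sits at element
        // offset (1 * 1024 + 5) * 8 = 8232 from that (batch, head) base, and every 16-byte read
        // below walks this layout. ti is an absolute timestep index, while ti_circ =
        // ti % memory_max_len is the slot that timestep occupies once the cache wraps around.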
+ K_vec_k k[K_VECS_PER_THREAD]; + K_vec_k k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (HAS_BEAMS) { + const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + } + else { + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + } + } + // add bias and update k_cache + if (DO_CROSS_ATTENTION && params.timestep == 0) { + k[ii] = add(k[ii], k_bias_vec[ii]); + + if (do_ia3) { + k[ii] = mul( + k[ii], + vec_conversion(*reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE]))); + } + + if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { + *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = + vec_conversion(k[ii]); + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + if (params.linear_bias_slopes != nullptr) { + // Apply the linear position bias: (ki - qi) * slope[hi]. + // The padding token locates between the input context and the generated tokens. + // We need to remove the number of padding tokens in the distance computation. + // ti : 0 1 2 3 4 5 6 7 8 9(tlength) + // token: i i i i p p p o o o where i=input, p=pad, o=output. + // e.g. ti = 2, dist = (9 - 3) - 2 = 4. + int max_context_length = params.max_prefix_prompt_length + params.max_input_length; + float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; + + qk += mul(params.linear_bias_slopes[hi], dist); + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? 
red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; +#ifdef FP8_MHA + float logit = 0.f; + if (FP8_MHA_KERNEL) { + logit = is_mask ? 0.f : + __expf((qk_smem[ti - first_step] - qk_max) * params.query_weight_output_scale[0] + * params.query_weight_output_scale[0]); + } + else { + logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + } +#else + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); +#endif + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + const size_t cross_attention_out_offset = + params.is_return_cross_attentions ? + bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : + 0; + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + if (params.is_return_cross_attentions) { + params.cross_attention_out[cross_attention_out_offset + ti] = logit; + } + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec_k = typename V_vec_k_::Type; + using V_vec_m = typename V_vec_m_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + + // The base pointer for the value in the cache buffer. + T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec_k v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (handle_kv) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = vec_conversion( + *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); + } + if (DO_CROSS_ATTENTION) { + *reinterpret_cast(&bias_smem[vi]) = vec_conversion(v_bias); + } + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. + __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec_k; +#endif + // The partial outputs computed by each thread. 
+ V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + + // Separate the ti < memory_max_len and ti > memory_max_len + // to prevent ti % memory_len when ti < memory_len, and + // the compiler cannot optimize the codes automatically. + const int min_length = min(tlength, params.memory_max_len); + for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) { + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + if (DO_CROSS_ATTENTION && params.timestep == 0) { + v = add(v, vec_conversion(*reinterpret_cast(&bias_smem[vi]))); + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + *reinterpret_cast(&v_cache[ti * Dh]) = vec_conversion(v); + } + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + if (ti < params.memory_max_len) { + // handled by previous loop + continue; + } + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + if (DO_CROSS_ATTENTION && params.timestep == 0) { + v = add(v, vec_conversion(*reinterpret_cast(&bias_smem[vi]))); + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + *reinterpret_cast(&v_cache[ti * Dh]) = vec_conversion(v); + } + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + } + + // One group of threads computes the product(s) for the current timestep. 
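    // Added note (illustrative): in the two loops above, beam search is handled purely through
    // indexing. cache_indir[bi_seq_len_offset + ti] records, for each past timestep, which beam
    // of this batch entry actually produced that token, and beam_offset jumps to that beam's
    // slice of the V cache (beam_src * num_heads * memory_max_len * Dh elements). When
    // beam_width == 1 the kernel is instantiated with HAS_BEAMS == false and beam_src is simply
    // 0, so the reads come from the sample's own cache.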
+ // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec_k v; + if (DO_CROSS_ATTENTION) { + v = vec_conversion(*reinterpret_cast(&v_cache[tlength * Dh])); + } + else { + // Trigger the loads from the V buffer. + const auto v_offset = qkv_base_offset + vi; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto v_scaling = params.qkv_scale_out[2]; + const auto v_quant = + *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + + convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); + } + else { + v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); + } + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + } + + // Compute the V values with bias. + if (handle_kv) { + v = add(v, v_bias); + + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS + // out = fma(logits_smem[params.timestep], v, out); +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = mul(1.0f / params.attention_qk_scale[0], logits_smem[tlength]); + logit = logits_smem[tlength - first_step]; + } + else { + logit = logits_smem[tlength - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + out = fma(logits_smem[tlength - first_step], v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + // Output the final values. 
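    // Worked example (illustrative): the loop above folds the V_PER_ITER partial sums together
    // pairwise through shared memory. With V_PER_ITER = 8, active_groups takes the values
    // 8, 4, 2: groups 4..7 first spill their accumulators to out_smem and groups 0..3 add them
    // in, then groups 2..3 spill and 0..1 add, then group 1 spills and group 0 adds. After those
    // three rounds the vo == 0 threads hold the full weighted sum over all timesteps and write
    // it out below.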
+ if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + if (FP8_MHA_KERNEL) { +#ifdef FP8_MHA + // float result_scale = params.attention_qk_scale[0] * params.query_weight_output_scale[0] * + // params.attention_output_weight_input_scale_inv[0]; + float result_scale = + params.query_weight_output_scale[0] * params.attention_output_weight_input_scale_inv[0]; + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), + mul(result_scale, out)); +#endif // FP8_MHA + } + else if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + out = mul(*params.attention_out_scale, out); + *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = + cast_to_int8(out); + } + else { + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); + } +#else // MMHA_USE_FP32_ACUM_FOR_OUT + // TODO: support int8_mode? + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = vec_conversion(out); +#endif // MMHA_USE_FP32_ACUM_FOR_OUT + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct threads_per_value_t { + static const int value = Dh_MAX * sizeof(T) / 16; +}; +#ifdef ENABLE_FP8 +template +struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> { + static const int value = Dh_MAX * 4 / 16; // DEBUG: float v +}; +#endif + +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); From ad919b819c4d021483ea9c59e25a2b8686f3bcf6 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 15:06:51 -0700 Subject: [PATCH 024/135] commit --- .../kernels/llama/decoder_masked_multihead_attention.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index 5a768184c..c7292f47a 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -49,7 +49,7 @@ // Dh: Hidden dimension per head - Dh = D / H. template -struct Multihead_attention_params_base { +struct Multihead_attention_params_base2 { // The output buffer. Dimensions B x D. 
T* out = nullptr; From 630ced0afa9594a58ce1f2d3d588e14249896444 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 15:58:40 -0700 Subject: [PATCH 025/135] commit --- .../decoder_masked_multihead_attention.cu | 16 +- .../decoder_masked_multihead_attention.h | 144 ++---------------- 2 files changed, 17 insertions(+), 143 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu index 175bdf9a9..c91004b6b 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu @@ -66,35 +66,35 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_>(params, stream); + multihead_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_>(params, stream); + multihead_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, +void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream) { - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); + multihead_attention_<__nv_bfloat16, Masked_llama_multihead_attention_params<__nv_bfloat16>>(params, stream); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream) { - multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); + multihead_attention_<__nv_fp8_e4m3, Masked_llama_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); } #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index c7292f47a..89c7a0852 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -16,6 +16,7 @@ #pragma once +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" #include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" @@ -25,140 +26,20 @@ #include #include 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. - template -struct Multihead_attention_params_base2 { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // scales - const float* query_weight_output_scale = nullptr; - const float* attention_qk_scale = nullptr; - const float* attention_output_weight_input_scale_inv = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - // The hidden dimension per head (Dh). - int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). 
- const T* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - +struct Llama_multihead_attention_params: public Multihead_attention_params_base { // allows to exist attention eary bool* finished = nullptr; - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - // required in case of masked attention with different length const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - // required in case of masked attention with different length - const int* length_per_sample = nullptr; + // number of kv heads(kvH) + int num_kv_heads = 0; }; template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; +using Masked_llama_multihead_attention_params = Llama_multihead_attention_params; template struct outputCrossAttentionParam { @@ -170,21 +51,14 @@ struct outputCrossAttentionParam { //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, +void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif - //////////////////////////////////////////////////////////////////////////////////////////////////// From 161774c415860986b85365057485992b387c037e Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:00:18 -0700 Subject: [PATCH 026/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 3 ++- 1 
file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index a834ad1ee..987cf8013 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -71,7 +71,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, { using DataType = typename SATypeConverter::Type; // Prepare the parameters. - Masked_multihead_attention_params params; + Masked_llama_multihead_attention_params params; memset(¶ms, 0, sizeof(params)); int hidden_units = head_num * size_per_head; if (qkv_bias != nullptr) { @@ -113,6 +113,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation params.timestep = step + max_prefix_prompt_length - 1; params.num_heads = head_num; + params.num_kv_heads = kv_head_num; params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; params.neox_rotary_style = neox_rotary_style; From 50a4215489c920cc3aa30e7b703cccedefddeaf8 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:02:29 -0700 Subject: [PATCH 027/135] commit --- .../kernels/llama/decoder_masked_multihead_attention.h | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index 89c7a0852..e0d928009 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -27,7 +27,7 @@ #include template -struct Llama_multihead_attention_params: public Multihead_attention_params_base { +struct Llama_multihead_attention_params: public Multihead_attention_params_base { // allows to exist attention eary bool* finished = nullptr; @@ -41,14 +41,6 @@ struct Llama_multihead_attention_params: public Multihead_attention_params_ba template using Masked_llama_multihead_attention_params = Llama_multihead_attention_params; -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - //////////////////////////////////////////////////////////////////////////////////////////////////// void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream); From f19a93e3bb72db6fe44526ecd2c47883a280a5d4 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:02:57 -0700 Subject: [PATCH 028/135] commit --- .../kernels/llama/decoder_masked_multihead_attention.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index e0d928009..c8b6912be 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -16,7 +16,7 @@ #pragma once -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" 
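// NOTE (illustrative aside, not part of this patch): num_kv_heads is what enables grouped-query
// attention for LLaMA-style models: several query heads share one K/V head, so only
// num_kv_heads * Dh elements per timestep have to be kept in each of the K and V caches instead
// of num_heads * Dh. A hypothetical helper spelling out the mapping these patches rely on (the
// exact spot where the kernel applies it is outside this excerpt):

inline int kv_head_index(int query_head, int num_heads, int num_kv_heads)
{
    // num_heads must be a multiple of num_kv_heads, e.g. 32 query heads over 8 KV heads puts
    // query heads 0..3 on KV head 0, 4..7 on KV head 1, and so on.
    const int queries_per_kv = num_heads / num_kv_heads;
    return query_head / queries_per_kv;
}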
#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" From a4d4743c435fbf658e0bc4195b1efb7a212f45cd Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:05:49 -0700 Subject: [PATCH 029/135] commit --- .../kernels/llama/CMakeLists.txt | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 src/fastertransformer/kernels/llama/CMakeLists.txt diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt new file mode 100644 index 000000000..7848e5413 --- /dev/null +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -0,0 +1,23 @@ +# Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +cmake_minimum_required(VERSION 3.8) + +set(decoder_masked_multihead_attention_files + decoder_masked_multihead_attention.cu +) +file(GLOB decoder_masked_multihead_attention_files ${decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) +add_library(decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) +set_property(TARGET decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) From 2f264c757f34b772d979e5a32652a7daf380d99c Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:11:40 -0700 Subject: [PATCH 030/135] commit --- CMakeLists.txt | 1 + src/fastertransformer/kernels/llama/CMakeLists.txt | 2 +- src/fastertransformer/layers/attention_layers/CMakeLists.txt | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9289d3b5f..b3b9b857e 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -393,6 +393,7 @@ add_library(transformer-shared SHARED $ $ $ + $ $ $ $ diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt index 7848e5413..0640034c9 100644 --- a/src/fastertransformer/kernels/llama/CMakeLists.txt +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -18,6 +18,6 @@ set(decoder_masked_multihead_attention_files decoder_masked_multihead_attention.cu ) file(GLOB decoder_masked_multihead_attention_files ${decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) -add_library(decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) +add_library(llama_decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) set_property(TARGET decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt 
index 628b3083a..0f927b8bd 100644 --- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt +++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt @@ -42,7 +42,7 @@ target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasM add_library(LlamaDecoderSelfAttentionLayer STATIC LlamaDecoderSelfAttentionLayer.cc) set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) -target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) +target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils llama_decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils) add_library(LlamaContextAttentionLayer STATIC LlamaContextAttentionLayer.cc) set_property(TARGET LlamaContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON) From a42ab9dd2709c916b38263abceb1af7fde353847 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:12:19 -0700 Subject: [PATCH 031/135] commit --- src/fastertransformer/kernels/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/CMakeLists.txt b/src/fastertransformer/kernels/CMakeLists.txt index ef9a0cced..816bcc5e8 100644 --- a/src/fastertransformer/kernels/CMakeLists.txt +++ b/src/fastertransformer/kernels/CMakeLists.txt @@ -15,6 +15,7 @@ cmake_minimum_required(VERSION 3.8) add_subdirectory(cutlass_kernels) +add_subdirectory(llama) add_library(image_shift_partition_kernels image_shift_partition_kernels.cu) set_property(TARGET image_shift_partition_kernels PROPERTY POSITION_INDEPENDENT_CODE ON) From 1a227efa992982f325d437d81a4136923f12b76c Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:13:04 -0700 Subject: [PATCH 032/135] commit --- src/fastertransformer/kernels/llama/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt index 0640034c9..28fd48b05 100644 --- a/src/fastertransformer/kernels/llama/CMakeLists.txt +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -19,5 +19,5 @@ set(decoder_masked_multihead_attention_files ) file(GLOB decoder_masked_multihead_attention_files ${decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) add_library(llama_decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) -set_property(TARGET decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) -set_property(TARGET decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) +set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) +set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) From 23619dc7a9ecdd12c3e0991860d33149df553a1d Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:19:30 -0700 Subject: [PATCH 033/135] commit --- src/fastertransformer/kernels/llama/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt index 28fd48b05..37f6014d9 100644 --- 
a/src/fastertransformer/kernels/llama/CMakeLists.txt +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -14,10 +14,10 @@ cmake_minimum_required(VERSION 3.8) -set(decoder_masked_multihead_attention_files +set(llama_decoder_masked_multihead_attention_files decoder_masked_multihead_attention.cu ) -file(GLOB decoder_masked_multihead_attention_files ${decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) -add_library(llama_decoder_masked_multihead_attention STATIC ${decoder_masked_multihead_attention_files}) +file(GLOB llama_decoder_masked_multihead_attention_files ${llama_decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) +add_library(llama_decoder_masked_multihead_attention STATIC ${llama_decoder_masked_multihead_attention_files}) set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON) From 75262213e2ae41c2af9d72dc1c5795679de0cc39 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 16:49:38 -0700 Subject: [PATCH 034/135] commit --- .../decoder_masked_multihead_attention.cu | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu index c91004b6b..4431532bb 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu @@ -99,37 +99,3 @@ void masked_multihead_attention(const Masked_llama_multihead_attention_params<__ #endif //////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_FP8 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_fp8_e4m3, Cross_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// From c236b5dd2cde103e173710fc3c43ef64ed1f1f94 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:09:30 -0700 Subject: [PATCH 035/135] commit --- .../decoder_masked_multihead_attention.cu | 54 +++++-- .../decoder_masked_multihead_attention.h | 152 ++++++++++++++++-- 2 files changed, 187 insertions(+), 19 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu 
b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu index 4431532bb..4618673d8 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu @@ -14,8 +14,8 @@ * limitations under the License. */ -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include @@ -66,35 +66,69 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream) +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_>(params, stream); + multihead_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream) +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_>(params, stream); + multihead_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_bfloat16>& params, +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream) { - multihead_attention_<__nv_bfloat16, Masked_llama_multihead_attention_params<__nv_bfloat16>>(params, stream); + multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 -void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream) { - multihead_attention_<__nv_fp8_e4m3, Masked_llama_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); + multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + 
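// A minimal sketch of what the float and half cross_multihead_attention wrappers above
// presumably expand to, inferred by analogy with the __nv_bfloat16 and __nv_fp8_e4m3
// overloads in the same hunk (the element type and parameter struct here are assumptions,
// not taken verbatim from the patch):
//
//     multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
//     multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);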
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_FP8 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_fp8_e4m3, Cross_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); } #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index c8b6912be..5a768184c 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -16,7 +16,6 @@ #pragma once -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" @@ -26,31 +25,166 @@ #include #include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + template -struct Llama_multihead_attention_params: public Multihead_attention_params_base { +struct Multihead_attention_params_base { + + // The output buffer. Dimensions B x D. + T* out = nullptr; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q = nullptr, *q_bias = nullptr; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k = nullptr, *k_bias = nullptr; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v = nullptr, *v_bias = nullptr; + + // The cache for the Ks. The size must be at least B x L x D. + T* k_cache = nullptr; + // The cache for the Vs. The size must be at least B x L x D. + T* v_cache = nullptr; + // The indirections to use for cache when beam sampling. + const int* cache_indir = nullptr; + + // scales + const float* query_weight_output_scale = nullptr; + const float* attention_qk_scale = nullptr; + const float* attention_output_weight_input_scale_inv = nullptr; + + // Stride to handle the case when KQV is a single buffer + int stride = 0; + + // The batch size. + int batch_size = 0; + // The beam width + int beam_width = 0; + // The sequence length. + int memory_max_len = 0; + // The number of heads (H). + int num_heads = 0; + // The hidden dimension per head (Dh). 
+ int hidden_size_per_head = 0; + // The per-head latent space reserved for rotary embeddings. + int rotary_embedding_dim = 0; + bool neox_rotary_style = false; + // The maximum length of input sentences. + int max_input_length = 0; + // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? + int timestep = 0; + // The current timestep of each sentences (support different timestep for different sentences) + + // The 1.f / sqrt(Dh). Computed on the host. + float inv_sqrt_dh = 0.0f; + + // Used when we have some input context like gpt + const int* total_padding_tokens = nullptr; + + const bool* masked_tokens = nullptr; + const int* prefix_prompt_lengths = nullptr; + int max_prefix_prompt_length = 0; + + const T* relative_attention_bias = nullptr; + int relative_attention_bias_stride = 0; + // The slope per head of linear position bias to attention score (H). + const T* linear_bias_slopes = nullptr; + + const T* ia3_key_weights = nullptr; + const T* ia3_value_weights = nullptr; + const int* ia3_tasks = nullptr; + + const float* qkv_scale_out = nullptr; + const float* attention_out_scale = nullptr; + int int8_mode = 0; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + // allows to exist attention eary bool* finished = nullptr; + // required in case of cross attention + // will need it here till if constexpr in c++17 + int* memory_length_per_sample = nullptr; + // required in case of masked attention with different length const int* length_per_sample = nullptr; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; - // number of kv heads(kvH) - int num_kv_heads = 0; + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; }; template -using Masked_llama_multihead_attention_params = Llama_multihead_attention_params; +using Masked_multihead_attention_params = Multihead_attention_params; + +template +using Cross_multihead_attention_params = Multihead_attention_params; + +template +struct outputCrossAttentionParam { + // max decoder output length + int max_decoder_seq_len = 0; + T* cross_attention_out = nullptr; + bool is_return_cross_attentions = false; +}; //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const Masked_llama_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_bfloat16>& params, +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -void 
masked_multihead_attention(const Masked_llama_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif + //////////////////////////////////////////////////////////////////////////////////////////////////// From 24ebefa2ff0497adc312cd07230e8b3be1423d01 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:09:53 -0700 Subject: [PATCH 036/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index 987cf8013..dacf003fa 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -71,7 +71,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, { using DataType = typename SATypeConverter::Type; // Prepare the parameters. - Masked_llama_multihead_attention_params params; + Masked_multihead_attention_params params; memset(¶ms, 0, sizeof(params)); int hidden_units = head_num * size_per_head; if (qkv_bias != nullptr) { From 4114f970b425ddd52195e77d204608c90b23a6d5 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:10:07 -0700 Subject: [PATCH 037/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index dacf003fa..d847b6a67 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -113,7 +113,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation params.timestep = step + max_prefix_prompt_length - 1; params.num_heads = head_num; - params.num_kv_heads = kv_head_num; + // params.num_kv_heads = kv_head_num; params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; params.neox_rotary_style = neox_rotary_style; From 8af8e0d1de9959c4f11b1c4655dd7eb42154cea3 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:13:21 -0700 Subject: [PATCH 038/135] commit --- .../kernels/llama/decoder_masked_multihead_attention.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu index 4618673d8..175bdf9a9 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu @@ -14,8 +14,8 @@ * limitations under the 
License. */ -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include From 135b5fd97afbfa6b24356462c521c084c8906168 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:14:55 -0700 Subject: [PATCH 039/135] commit --- .../decoder_masked_multihead_attention.h | 144 +----------------- 1 file changed, 1 insertion(+), 143 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index 5a768184c..1139d23a6 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -16,6 +16,7 @@ #pragma once +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" #include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" @@ -27,149 +28,6 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define CHECK_CUDA(call) \ - do { \ - cudaError_t status_ = call; \ - if (status_ != cudaSuccess) { \ - fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ - exit(1); \ - } \ - } while (0) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// The structure of parameters for the masked multihead attention kernel. -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. - -template -struct Multihead_attention_params_base { - - // The output buffer. Dimensions B x D. - T* out = nullptr; - - // The input Qs and the associated bias. Dimensions B x D and D, resp. - const T *q = nullptr, *q_bias = nullptr; - // The input Ks and the associated bias. Dimensions B x D and D, resp. - const T *k = nullptr, *k_bias = nullptr; - // The input Vs and the associated bias. Dimensions B x D and D, resp. - const T *v = nullptr, *v_bias = nullptr; - - // The cache for the Ks. The size must be at least B x L x D. - T* k_cache = nullptr; - // The cache for the Vs. The size must be at least B x L x D. - T* v_cache = nullptr; - // The indirections to use for cache when beam sampling. - const int* cache_indir = nullptr; - - // scales - const float* query_weight_output_scale = nullptr; - const float* attention_qk_scale = nullptr; - const float* attention_output_weight_input_scale_inv = nullptr; - - // Stride to handle the case when KQV is a single buffer - int stride = 0; - - // The batch size. - int batch_size = 0; - // The beam width - int beam_width = 0; - // The sequence length. - int memory_max_len = 0; - // The number of heads (H). - int num_heads = 0; - // The hidden dimension per head (Dh). 
- int hidden_size_per_head = 0; - // The per-head latent space reserved for rotary embeddings. - int rotary_embedding_dim = 0; - bool neox_rotary_style = false; - // The maximum length of input sentences. - int max_input_length = 0; - // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? - int timestep = 0; - // The current timestep of each sentences (support different timestep for different sentences) - - // The 1.f / sqrt(Dh). Computed on the host. - float inv_sqrt_dh = 0.0f; - - // Used when we have some input context like gpt - const int* total_padding_tokens = nullptr; - - const bool* masked_tokens = nullptr; - const int* prefix_prompt_lengths = nullptr; - int max_prefix_prompt_length = 0; - - const T* relative_attention_bias = nullptr; - int relative_attention_bias_stride = 0; - // The slope per head of linear position bias to attention score (H). - const T* linear_bias_slopes = nullptr; - - const T* ia3_key_weights = nullptr; - const T* ia3_value_weights = nullptr; - const int* ia3_tasks = nullptr; - - const float* qkv_scale_out = nullptr; - const float* attention_out_scale = nullptr; - int int8_mode = 0; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - // will need it here till if constexpr in c++17 - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -struct Multihead_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; -}; - -template -using Masked_multihead_attention_params = Multihead_attention_params; - -template -using Cross_multihead_attention_params = Multihead_attention_params; - -template -struct outputCrossAttentionParam { - // max decoder output length - int max_decoder_seq_len = 0; - T* cross_attention_out = nullptr; - bool is_return_cross_attentions = false; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 From 8d4b18d2283abfdcfb264486d2eaa7de9b43029e Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:18:01 -0700 Subject: [PATCH 040/135] commit --- .../decoder_masked_multihead_attention.cu | 42 ++----------------- .../decoder_masked_multihead_attention.h | 14 ++----- 2 files changed, 8 insertions(+), 48 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu index 175bdf9a9..ae821226e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu +++ 
b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu @@ -66,14 +66,14 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { multihead_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { multihead_attention_>(params, stream); } @@ -81,7 +81,7 @@ void masked_multihead_attention(const Masked_multihead_attention_params& params, +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream) { multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); @@ -91,7 +91,7 @@ void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfl //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream) { multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); @@ -99,37 +99,3 @@ void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8 #endif //////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_FP8 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_fp8_e4m3, Cross_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h index 1139d23a6..4f77cf979 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ 
b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h @@ -28,21 +28,15 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -void masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif //////////////////////////////////////////////////////////////////////////////////////////////////// From 4c127bf8007b5850157ad63ad0aa6df1bf7f7fcb Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:19:36 -0700 Subject: [PATCH 041/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp index b9c1329f8..b221df419 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -15,8 +15,8 @@ */ #pragma once -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" #include "src/fastertransformer/utils/cuda_type_utils.cuh" From 50f1de52d10bfebcc81f618a12323cb7973ee4e6 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:19:42 -0700 Subject: [PATCH 042/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp index b221df419..fbb0230ea 100644 --- 
a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -16,7 +16,7 @@ #pragma once #include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" #include "src/fastertransformer/utils/cuda_type_utils.cuh" From 8a80dde1e5984b22f46594ca9ee71bc518eb5ab7 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:24:40 -0700 Subject: [PATCH 043/135] commit --- .../kernels/llama/CMakeLists.txt | 2 +- .../decoder_masked_multihead_attention.cu | 101 ------------------ .../decoder_masked_multihead_attention.h | 42 -------- 3 files changed, 1 insertion(+), 144 deletions(-) delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt index 37f6014d9..2b28e2932 100644 --- a/src/fastertransformer/kernels/llama/CMakeLists.txt +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -15,7 +15,7 @@ cmake_minimum_required(VERSION 3.8) set(llama_decoder_masked_multihead_attention_files - decoder_masked_multihead_attention.cu + decoder_masked_groupedquery_attention.cu ) file(GLOB llama_decoder_masked_multihead_attention_files ${llama_decoder_masked_multihead_attention_files} ./decoder_masked_multihead_attention/*.cu) add_library(llama_decoder_masked_multihead_attention STATIC ${llama_decoder_masked_multihead_attention_files}) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu deleted file mode 100644 index ae821226e..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.cu +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - switch (params.hidden_size_per_head) { - case 32: - mmha_launch_kernel(params, stream); - break; - case 48: - mmha_launch_kernel(params, stream); - break; - case 64: - mmha_launch_kernel(params, stream); - break; - case 80: - mmha_launch_kernel(params, stream); - break; - case 96: - mmha_launch_kernel(params, stream); - break; - case 128: - mmha_launch_kernel(params, stream); - break; - case 144: - mmha_launch_kernel(params, stream); - break; - case 160: - mmha_launch_kernel(params, stream); - break; - case 192: - mmha_launch_kernel(params, stream); - break; - case 224: - mmha_launch_kernel(params, stream); - break; - case 256: - mmha_launch_kernel(params, stream); - break; - default: - assert(false); - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) -{ - multihead_attention_>(params, stream); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_FP8 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, - const cudaStream_t& stream) -{ - multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h deleted file mode 100644 index 4f77cf979..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" -#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include "src/fastertransformer/utils/cuda_fp8_utils.h" -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, - const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, - const cudaStream_t& stream); -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// From 794c0de1271473512e3bb9e10af03e1649623ba6 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:24:49 -0700 Subject: [PATCH 044/135] commit --- .../decoder_masked_groupedquery_attention.cu | 101 ++++++++++++++++++ .../decoder_masked_groupedquery_attention.h | 42 ++++++++ 2 files changed, 143 insertions(+) create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu new file mode 100644 index 000000000..ae821226e --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +template +void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + switch (params.hidden_size_per_head) { + case 32: + mmha_launch_kernel(params, stream); + break; + case 48: + mmha_launch_kernel(params, stream); + break; + case 64: + mmha_launch_kernel(params, stream); + break; + case 80: + mmha_launch_kernel(params, stream); + break; + case 96: + mmha_launch_kernel(params, stream); + break; + case 128: + mmha_launch_kernel(params, stream); + break; + case 144: + mmha_launch_kernel(params, stream); + break; + case 160: + mmha_launch_kernel(params, stream); + break; + case 192: + mmha_launch_kernel(params, stream); + break; + case 224: + mmha_launch_kernel(params, stream); + break; + case 256: + mmha_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_FP8 +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h new file mode 100644 index 000000000..4f77cf979 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
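// A minimal sketch of the head-size dispatcher defined in the new
// decoder_masked_groupedquery_attention.cu above, with the template arguments written out.
// The exact instantiation pattern is an assumption inferred from the __nv_bfloat16 and
// __nv_fp8_e4m3 wrappers nearby, which spell theirs out:
//
//     template<typename T, typename KERNEL_PARAMS_TYPE>
//     void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
//     {
//         switch (params.hidden_size_per_head) {
//             case 64:  mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream); break;
//             case 128: mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
//             // ... the remaining supported head sizes (32, 48, 80, 96, 144, 160, 192, 224, 256)
//             // follow the same shape ...
//             default:  assert(false);
//         }
//     }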
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/layers/attention_layers_fp8/AttentionFP8Weight.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, + const cudaStream_t& stream); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// From dc805b20635092cd852a16fc3ca3d0669d3b9738 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:25:39 -0700 Subject: [PATCH 045/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp index fbb0230ea..64d78b5d1 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp @@ -15,7 +15,7 @@ */ #pragma once -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include "src/fastertransformer/utils/cuda_fp8_utils.h" From 7b9e51ba1fe7661561f2fa52f22dece4dc02b507 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:25:59 -0700 Subject: [PATCH 046/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index ae821226e..3fe0ec8a6 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" From 5ab58d16e87a48b0421d3fcb2522fb7fd8a9b6ac Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:28:10 -0700 Subject: [PATCH 047/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index d847b6a67..51e5f6357 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -15,7 +15,7 @@ */ #include "src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.h" -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/utils/logger.h" #include "src/fastertransformer/utils/memory_utils.h" #include "src/fastertransformer/kernels/repeat_kv_kernels.h" From f703210ea7f87743865688ee4cc775f42844067b Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Fri, 8 Sep 2023 19:29:12 -0700 Subject: [PATCH 048/135] commit --- .../llama/decoder_masked_groupedquery_attention.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index 3fe0ec8a6..51d7c5e39 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -66,14 +66,14 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& //////////////////////////////////////////////////////////////////////////////////////////////////// -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { multihead_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { multihead_attention_>(params, stream); } @@ -81,7 +81,7 @@ void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, +void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream) { multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); @@ -91,7 +91,7 @@ void llama_masked_multihead_attention(const Masked_multihead_attention_params<__ 
//////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream) { multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); From e2e516a14eb12f3043046625686320e401f7b0ec Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:21:36 -0700 Subject: [PATCH 049/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 4f77cf979..0aec64936 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -28,14 +28,14 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void llama_masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, +void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -void llama_masked_multihead_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif From 459c26e92df0631838882a4344f6782037060b5a Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:23:33 -0700 Subject: [PATCH 050/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index 51e5f6357..bc49490c4 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -144,7 +144,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, } PUSH_RANGE("scaled dot-product fusion"); - masked_multihead_attention(params, stream); + masked_groupedquery_attention(params, stream); POP_RANGE; } From f0ab27f95e771fed5d98aa68b662164b850c86a9 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:24:18 -0700 Subject: [PATCH 051/135] commit --- .../decoder_masked_multihead_attention_128.cu | 0 .../decoder_masked_multihead_attention_144.cu | 0 .../decoder_masked_multihead_attention_160.cu | 0 .../decoder_masked_multihead_attention_192.cu | 0 .../decoder_masked_multihead_attention_224.cu | 0 .../decoder_masked_multihead_attention_256.cu | 0 
.../decoder_masked_multihead_attention_32.cu | 0 .../decoder_masked_multihead_attention_48.cu | 0 .../decoder_masked_multihead_attention_64.cu | 0 .../decoder_masked_multihead_attention_80.cu | 0 .../decoder_masked_multihead_attention_96.cu | 0 .../decoder_masked_multihead_attention_template.hpp | 0 12 files changed, 0 insertions(+), 0 deletions(-) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_128.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_144.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_160.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_192.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_224.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_256.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_32.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_48.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_64.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_80.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_96.cu (100%) rename src/fastertransformer/kernels/llama/{decoder_masked_multihead_attention => decoder_masked_groupedquery_attention}/decoder_masked_multihead_attention_template.hpp (100%) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_144.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_160.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_192.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_224.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_256.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_32.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_48.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_64.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu diff --git 
a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_80.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_96.cu rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu diff --git a/src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp similarity index 100% rename from src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp rename to src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp From d5b096b95d6b2d0cbd7d454984ef40c226f009ac Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:24:42 -0700 Subject: [PATCH 052/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index 51d7c5e39..33fc4170f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -15,7 +15,7 @@ */ #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/llama/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include From aca0e797b9072b6521a1a3b9ca821bdf2f38e413 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:24:55 -0700 Subject: [PATCH 053/135] commit --- src/fastertransformer/kernels/llama/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt index 2b28e2932..85b96b113 100644 --- a/src/fastertransformer/kernels/llama/CMakeLists.txt +++ b/src/fastertransformer/kernels/llama/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.8) set(llama_decoder_masked_multihead_attention_files decoder_masked_groupedquery_attention.cu ) -file(GLOB llama_decoder_masked_multihead_attention_files ${llama_decoder_masked_multihead_attention_files} 
./decoder_masked_multihead_attention/*.cu)
+file(GLOB llama_decoder_masked_multihead_attention_files ${llama_decoder_masked_multihead_attention_files} ./decoder_masked_groupedquery_attention/*.cu)
 add_library(llama_decoder_masked_multihead_attention STATIC ${llama_decoder_masked_multihead_attention_files})
 set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

From 422a5456800378bd544578ba1ea2b30a1d0fe6b9 Mon Sep 17 00:00:00 2001
From: sfc-gh-zhwang
Date: Sat, 9 Sep 2023 12:26:15 -0700
Subject: [PATCH 054/135] commit

---
 src/fastertransformer/kernels/llama/CMakeLists.txt | 10 +++++-----
 .../layers/attention_layers/CMakeLists.txt | 2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/fastertransformer/kernels/llama/CMakeLists.txt b/src/fastertransformer/kernels/llama/CMakeLists.txt
index 85b96b113..07fa20a03 100644
--- a/src/fastertransformer/kernels/llama/CMakeLists.txt
+++ b/src/fastertransformer/kernels/llama/CMakeLists.txt
@@ -14,10 +14,10 @@ cmake_minimum_required(VERSION 3.8)
-set(llama_decoder_masked_multihead_attention_files
+set(decoder_masked_groupedquery_attention_files
     decoder_masked_groupedquery_attention.cu
 )
-file(GLOB llama_decoder_masked_multihead_attention_files ${llama_decoder_masked_multihead_attention_files} ./decoder_masked_groupedquery_attention/*.cu)
-add_library(llama_decoder_masked_multihead_attention STATIC ${llama_decoder_masked_multihead_attention_files})
-set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
-set_property(TARGET llama_decoder_masked_multihead_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
+file(GLOB decoder_masked_groupedquery_attention_files ${decoder_masked_groupedquery_attention_files} ./decoder_masked_groupedquery_attention/*.cu)
+add_library(decoder_masked_groupedquery_attention STATIC ${decoder_masked_groupedquery_attention_files})
+set_property(TARGET decoder_masked_groupedquery_attention PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET decoder_masked_groupedquery_attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
diff --git a/src/fastertransformer/layers/attention_layers/CMakeLists.txt b/src/fastertransformer/layers/attention_layers/CMakeLists.txt
index 0f927b8bd..60bbcffba 100644
--- a/src/fastertransformer/layers/attention_layers/CMakeLists.txt
+++ b/src/fastertransformer/layers/attention_layers/CMakeLists.txt
@@ -42,7 +42,7 @@ target_link_libraries(DecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasM
 add_library(LlamaDecoderSelfAttentionLayer STATIC LlamaDecoderSelfAttentionLayer.cc)
 set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET LlamaDecoderSelfAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
-target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils llama_decoder_masked_multihead_attention fpA_intB_gemm int8_gemm tensor nvtx_utils)
+target_link_libraries(LlamaDecoderSelfAttentionLayer PUBLIC -lcublas -lcudart cublasMMWrapper memory_utils decoder_masked_groupedquery_attention fpA_intB_gemm int8_gemm tensor nvtx_utils)
 add_library(LlamaContextAttentionLayer STATIC LlamaContextAttentionLayer.cc)
 set_property(TARGET LlamaContextAttentionLayer PROPERTY POSITION_INDEPENDENT_CODE ON)
 set_property(TARGET LlamaContextAttentionLayer PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

From d9a84818878dc9673376080ca7074c14cfbc744a Mon Sep 17 00:00:00 2001
From: sfc-gh-zhwang
Date: Sat, 9 Sep 2023 
12:27:19 -0700 Subject: [PATCH 055/135] commit --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b3b9b857e..acfb53e4a 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -393,7 +393,7 @@ add_library(transformer-shared SHARED $ $ $ - $ + $ $ $ $ From 29959a3b5781a1e7be8884944b64e4e2454c080d Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:30:22 -0700 Subject: [PATCH 056/135] commit --- .../llama/decoder_masked_groupedquery_attention.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index 33fc4170f..bfa18c15b 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -23,7 +23,7 @@ #include template -void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void groupedquery_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { switch (params.hidden_size_per_head) { case 32: @@ -68,14 +68,14 @@ void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_>(params, stream); + groupedquery_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_>(params, stream); + groupedquery_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -84,7 +84,7 @@ void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) { - multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); + groupedquery_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); } #endif @@ -94,7 +94,7 @@ void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_ void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream) { - multihead_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); + groupedquery_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); } #endif From 45c9aeefd6f0319ef92c61984364db10b881c8b0 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:31:28 -0700 Subject: [PATCH 057/135] commit --- .../decoder_masked_groupedquery_attention.cu | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index bfa18c15b..1f07b286e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -27,37 +27,37 @@ void groupedquery_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_ { switch (params.hidden_size_per_head) { case 32: - 
mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 48: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 64: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 80: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 96: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 128: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 144: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 160: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 192: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 224: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; case 256: - mmha_launch_kernel(params, stream); + mgqa_launch_kernel(params, stream); break; default: assert(false); From 7e721f9944005edbcd9ef679fb84a8bc637da2ff Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:33:46 -0700 Subject: [PATCH 058/135] commit --- .../decoder_masked_multihead_attention_128.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_144.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_160.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_192.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_224.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_256.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_32.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_48.cu | 20 +++++++++---------- .../decoder_masked_multihead_attention_64.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_80.cu | 18 ++++++++--------- .../decoder_masked_multihead_attention_96.cu | 18 ++++++++--------- ...er_masked_multihead_attention_template.hpp | 2 +- 12 files changed, 101 insertions(+), 101 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu index 9b4f7c393..540154229 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu @@ -42,7 +42,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -74,29 +74,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 128, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 128, 128, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 128, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu index 0da2134e9..fe973d739 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 144, 256, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 144, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 144, 256, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 144, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu index 86153f37a..bc5624747 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 160, 256, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 160, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 160, 256, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 160, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu index 12c6e22bf..07b84f11e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 192, 256, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 192, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 192, 256, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 192, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu index 7b17ae7b7..9e01c7588 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 224, 256, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 224, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 224, 256, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 224, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu index e17fa03ae..66b45c0ea 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 256, 256, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 256, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 256, 256, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 256, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index 91ecc2f46..46ef3f32a 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 32, 32, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 32, 32, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 32, 32, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 32, 32, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu index 79bf3ca83..4f555cbd8 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,30 +72,30 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 48, 64, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 48, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 48, 64, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 48, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL \ No newline at end of file +#undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu index a4156e071..f172ab56f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 64, 64, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 64, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 64, 64, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 64, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu index b94345952..ea978d743 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 80, 128, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 80, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 80, 128, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 80, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu index 6e754fd14..b4bdac2ec 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu @@ -41,7 +41,7 @@ // !!! 
Specialize the launcher for Cross attention template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; @@ -72,29 +72,29 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Masked_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 96, 128, Masked_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, Masked_multihead_attention_params<__nv_bfloat16>>( const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 96, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mmha_launch_kernel>( +template void mgqa_launch_kernel>( const Cross_multihead_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mmha_launch_kernel<__nv_bfloat16, 96, 128, Cross_multihead_attention_params<__nv_bfloat16>>( +template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, Cross_multihead_attention_params<__nv_bfloat16>>( const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mmha_launch_kernel<__nv_fp8_e4m3, 96, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( +template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 64d78b5d1..e91f671a4 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1956,4 +1956,4 @@ struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> { #endif template -void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); From 4a03a09d1db348edf8b15dff8351692faa25cfc1 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:40:00 -0700 Subject: [PATCH 059/135] commit --- .../decoder_masked_groupedquery_attention.h | 17 +++++++++++++++++ 1 file changed, 17 
insertions(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 0aec64936..7549de547 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -26,6 +26,23 @@ #include #include +template +struct GroupedQuery_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); From ccbd7278dd51858825c84d69f1aa6a7da0f0cd9f Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:40:20 -0700 Subject: [PATCH 060/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 7549de547..2db0ff85b 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -27,7 +27,7 @@ #include template -struct GroupedQuery_attention_params: public Multihead_attention_params_base { +struct GroupedQuery_attention_params: public Multihead_attention_params_base { // output cross attentions float* cross_attention_out = nullptr; int max_decoder_seq_len = 0; From e58d5487dd1867ee2d0010b3c5f2c688f22ea865 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:41:18 -0700 Subject: [PATCH 061/135] commit --- .../llama/decoder_masked_groupedquery_attention.h | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 2db0ff85b..6b390cc2c 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -28,19 +28,9 @@ template struct GroupedQuery_attention_params: public Multihead_attention_params_base { - // output cross attentions - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - // allows to exist attention eary - bool* finished = nullptr; - - // required in case of cross attention - int* memory_length_per_sample = nullptr; - - // required in case of masked attention with different length - const int* length_per_sample = nullptr; + bool* finished = nullptr; + int num_kv_heads = 0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// From 941072c3e6eeeadb40364d97d998da5774ee2db2 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:43:36 -0700 Subject: [PATCH 062/135] commit --- 
.../kernels/llama/decoder_masked_groupedquery_attention.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 6b390cc2c..6797ee84a 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -35,14 +35,14 @@ struct GroupedQuery_attention_params: public Multihead_attention_params_base //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); -void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const GroupedQuery_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, +void masked_groupedquery_attention(const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_groupedquery_attention(const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif From bbf5ebeb75b18fb4822746ea9ddd70ff2f61cf07 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:45:50 -0700 Subject: [PATCH 063/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index bc49490c4..89984bcbd 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -71,7 +71,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, { using DataType = typename SATypeConverter::Type; // Prepare the parameters. 
- Masked_multihead_attention_params params; + Masked_groupedquery_attention_params params; memset(¶ms, 0, sizeof(params)); int hidden_units = head_num * size_per_head; if (qkv_bias != nullptr) { From 145788b7010a3a32127162660ffaa2f95b38c86a Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:47:41 -0700 Subject: [PATCH 064/135] commit --- .../decoder_masked_groupedquery_attention.cu | 16 ++++++++-------- .../decoder_masked_groupedquery_attention.h | 11 +++++++---- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index 1f07b286e..ec6bb68e8 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -66,35 +66,35 @@ void groupedquery_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_ //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream) { - groupedquery_attention_>(params, stream); + groupedquery_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_groupedquery_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream) { - groupedquery_attention_>(params, stream); + groupedquery_attention_>(params, stream); } //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_BF16 -void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream) { - groupedquery_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); + groupedquery_attention_<__nv_bfloat16, Masked_groupedquery_attention_params<__nv_bfloat16>>(params, stream); } #endif //////////////////////////////////////////////////////////////////////////////////////////////////// #ifdef ENABLE_FP8 -void masked_groupedquery_attention(const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream) { - groupedquery_attention_<__nv_fp8_e4m3, Masked_multihead_attention_params<__nv_fp8_e4m3>>(params, stream); + groupedquery_attention_<__nv_fp8_e4m3, Masked_groupedquery_attention_params<__nv_fp8_e4m3>>(params, stream); } #endif diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 6797ee84a..4c9975211 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -33,16 +33,19 @@ struct GroupedQuery_attention_params: public Multihead_attention_params_base int num_kv_heads = 0; }; +template +using Masked_groupedquery_attention_params 
= Multihead_attention_params; + //////////////////////////////////////////////////////////////////////////////////////////////////// -void masked_groupedquery_attention(const GroupedQuery_attention_params& params, const cudaStream_t& stream); -void masked_groupedquery_attention(const GroupedQuery_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream); +void masked_groupedquery_attention(const Masked_groupedquery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -void masked_groupedquery_attention(const GroupedQuery_attention_params<__nv_bfloat16>& params, +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -void masked_groupedquery_attention(const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, +void masked_groupedquery_attention(const Masked_groupedquery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif From 6333c82ee8a4dd9d29a48a56e44cab3f4b9887fe Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:48:19 -0700 Subject: [PATCH 065/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 4c9975211..7ad248aa7 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -34,7 +34,7 @@ struct GroupedQuery_attention_params: public Multihead_attention_params_base }; template -using Masked_groupedquery_attention_params = Multihead_attention_params; +using Masked_groupedquery_attention_params = GroupedQuery_attention_params; //////////////////////////////////////////////////////////////////////////////////////////////////// From 210c439abbd39035d1b30788f598b4bc32e83f7d Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:50:55 -0700 Subject: [PATCH 066/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 7ad248aa7..b0968519f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -31,6 +31,8 @@ struct GroupedQuery_attention_params: public Multihead_attention_params_base // allows to exist attention eary bool* finished = nullptr; int num_kv_heads = 0; + // required in case of masked attention with different length + const int* length_per_sample = nullptr; }; template From a305abed0661555b8039c5cc3cd54ba761b9dcf1 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:55:30 -0700 Subject: [PATCH 067/135] commit --- .../decoder_masked_multihead_attention_128.cu | 28 +++++------------- .../decoder_masked_multihead_attention_144.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_160.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_192.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_224.cu | 29 +++++-------------- 
.../decoder_masked_multihead_attention_256.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_32.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_48.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_64.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_80.cu | 29 +++++-------------- .../decoder_masked_multihead_attention_96.cu | 29 +++++-------------- 11 files changed, 88 insertions(+), 230 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu index 540154229..1bc911bf0 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu @@ -74,30 +74,18 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu index fe973d739..29d02d6b6 100644 --- 
a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu index bc5624747..42fa04686 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template 
void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu index 07b84f11e..193c7dcf4 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void 
mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu index 9e01c7588..b1dfbd050 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const 
GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu index 66b45c0ea..f2d8058c3 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index 46ef3f32a..6d17c1a59 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st 
//////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu index 4f555cbd8..e3d959fde 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void 
mgqa_launch_kernel<__nv_bfloat16, 48, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu index f172ab56f..6f2aed91d 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, Cross_multihead_attention_params<__nv_bfloat16>>( - const 
Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu index ea978d743..14609ce69 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu index 
b4bdac2ec..61a779109 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu @@ -72,30 +72,17 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st //////////////////////////////////////////////////////////////////////////////////////////////////// -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Masked_multihead_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); #ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, Masked_multihead_attention_params<__nv_bfloat16>>( - const Masked_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); #endif #ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, Masked_multihead_attention_params<__nv_fp8_e4m3>>( - const Masked_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const Cross_multihead_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, Cross_multihead_attention_params<__nv_bfloat16>>( - const Cross_multihead_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, Cross_multihead_attention_params<__nv_fp8_e4m3>>( - const Cross_multihead_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif #undef MMHA_LAUNCH_KERNEL From 4476f35298a0adbc3a99b83899d466c7ef15645f Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 12:57:20 -0700 Subject: [PATCH 068/135] commit --- .../decoder_masked_multihead_attention_128.cu | 5 ++--- .../decoder_masked_multihead_attention_144.cu | 2 +- .../decoder_masked_multihead_attention_160.cu | 2 +- .../decoder_masked_multihead_attention_192.cu | 2 +- .../decoder_masked_multihead_attention_224.cu | 2 +- .../decoder_masked_multihead_attention_256.cu | 2 +- .../decoder_masked_multihead_attention_32.cu | 2 +- .../decoder_masked_multihead_attention_48.cu | 2 +- .../decoder_masked_multihead_attention_64.cu | 2 +- .../decoder_masked_multihead_attention_80.cu | 2 +- .../decoder_masked_multihead_attention_96.cu | 2 +- 11 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu index 
1bc911bf0..5c69c18ca 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu @@ -14,15 +14,14 @@ * limitations under the License. */ -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include #include #include -#include "decoder_masked_multihead_attention_template.hpp" - //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu index 29d02d6b6..961f36f40 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu index 42fa04686..8982021c2 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu index 193c7dcf4..ea86f2508 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include 
"src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu index b1dfbd050..e9d63534e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu index f2d8058c3..bfa56b428 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index 6d17c1a59..5d73f6f3e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu index e3d959fde..2f53d0d56 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include 
"src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu index 6f2aed91d..80a2ba640 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu index 14609ce69..2ca904d01 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu index 61a779109..05754e848 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu @@ -15,7 +15,7 @@ */ #include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention.h" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include From 2fab8b471df6f5bb38622d91a969ef76023afebc Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:03:33 -0700 Subject: [PATCH 069/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index e91f671a4..df9f7e6c7 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1056,8 +1056,8 @@ struct kernel_type_t<__nv_fp8_e4m3> { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -inline size_t smem_size_in_bytes(const Multihead_attention_params& params, +template +inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& params, int threads_per_value, int threads_per_block) { @@ -1114,9 +1114,8 @@ template< int THREADS_PER_VALUE, // The number of threads in a threadblock. int THREADS_PER_BLOCK, - bool DO_CROSS_ATTENTION, bool HAS_BEAMS> -__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) +__global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params params) { using Tk = typename kernel_type_t::Type; From ff64dc299e9af2c692f41f9f86739c531db34168 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:06:18 -0700 Subject: [PATCH 070/135] commit --- .../decoder_masked_multihead_attention_32.cu | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index 5d73f6f3e..8423a01d0 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -25,15 +25,14 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) @@ -44,28 +43,29 @@ template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + //constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + constexpr bool DO_CROSS_ATTENTION = false; int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } From 8efdd38486693ec4c5244f6c35854e64afe34442 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:06:59 -0700 Subject: [PATCH 071/135] commit --- .../decoder_masked_multihead_attention_64.cu | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu index 80a2ba640..e8a6e9410 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,28 +43,28 @@ template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } From d805fcf184b4f60cbd49679cef3694d270842d66 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:08:53 -0700 Subject: [PATCH 072/135] commit --- .../decoder_masked_multihead_attention_128.cu | 19 +++++++++---------- .../decoder_masked_multihead_attention_144.cu | 17 ++++++++--------- .../decoder_masked_multihead_attention_160.cu | 17 ++++++++--------- .../decoder_masked_multihead_attention_192.cu | 17 ++++++++--------- .../decoder_masked_multihead_attention_224.cu | 17 ++++++++--------- .../decoder_masked_multihead_attention_256.cu | 17 ++++++++--------- .../decoder_masked_multihead_attention_48.cu | 19 +++++++++---------- .../decoder_masked_multihead_attention_80.cu | 17 ++++++++--------- .../decoder_masked_multihead_attention_96.cu | 17 ++++++++--------- 9 files changed, 74 insertions(+), 83 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu index 5c69c18ca..210fb8f77 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -46,27 +45,27 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = 
std::is_same>::value; int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu index 961f36f40..d4922de73 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu index 8982021c2..9bd6da2a0 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu index ea86f2508..651cc5c6c 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu index e9d63534e..0bf1e9aaf 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu index bfa56b428..a8bcd0180 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu index 2f53d0d56..c2c65e6b0 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,28 +43,28 @@ template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu index 2ca904d01..ecf73b9ed 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu index 05754e848..c6b52930f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu @@ -25,8 +25,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define MMHA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ mmha::masked_multihead_attention_kernel<<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -48,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, false, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, true, stream); + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } From e47d2ac26ce474175cb46b115965edce81465fb8 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:09:23 -0700 Subject: [PATCH 073/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index df9f7e6c7..91470d386 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1117,7 +1117,8 @@ template< bool HAS_BEAMS> __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params params) { - + // TODO(zhwang): hacky + constexpr bool DO_CROSS_ATTENTION = false; using Tk = typename kernel_type_t::Type; #ifdef ENABLE_FP8 // FP8 MHA Scales From 3b38507717d9adf76e668d16c2c8cf21ed7020c5 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:09:47 -0700 Subject: [PATCH 074/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 91470d386..066010385 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1061,6 +1061,8 @@ inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& int threads_per_value, int threads_per_block) { + // TODO(zhwang): hacky + constexpr bool DO_CROSS_ATTENTION = false; using Tk = typename kernel_type_t::Type; // The amount of shared memory needed to store the Q*K^T values in float. 
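    // Rough worked example (illustrative numbers, not from the diff): with
    // params.timestep = 1000 and params.memory_max_len = 2048, max_timesteps is 1000,
    // so the Q*K^T scratch computed below is qk_sz = div_up(1000 + 1, 4) * 16 = 4016 bytes,
    // independent of memory_max_len now that the cross-attention branch is forced off.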
const int max_timesteps = min(params.timestep, params.memory_max_len); From d601f685a7d03f457f3ce964c2f774fc6c5ad1f2 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 13:12:01 -0700 Subject: [PATCH 075/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.h | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index b0968519f..30890b042 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -33,6 +33,13 @@ struct GroupedQuery_attention_params: public Multihead_attention_params_base int num_kv_heads = 0; // required in case of masked attention with different length const int* length_per_sample = nullptr; + + // output cross attentions + // TODO(zhwang): remove + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + int* memory_length_per_sample = nullptr; }; template From e1616129da27776631104fb1fbe75d1c875601a9 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 19:27:07 -0700 Subject: [PATCH 076/135] commit --- .../decoder_masked_multihead_attention_32.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index 8423a01d0..bb3a42a64 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -33,7 +33,6 @@ Dh_MAX, \ THDS_PER_KEY, \ THDS_PER_VALUE, \ - DO_CROSS_ATTENTION, \ HAS_BEAMS><<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// From e655b1c1170f71fc2ef81eddf5c6b8ac444b9e7e Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 19:28:09 -0700 Subject: [PATCH 077/135] commit --- .../decoder_masked_multihead_attention_32.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index bb3a42a64..e8d24cf5a 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -33,6 +33,7 @@ Dh_MAX, \ THDS_PER_KEY, \ THDS_PER_VALUE, \ + THDS_PER_BLOCK, \ HAS_BEAMS><<>>(params) //////////////////////////////////////////////////////////////////////////////////////////////////// From e0375960e4204c0321ea8c71815e9965f2b18f33 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 19:31:45 -0700 Subject: [PATCH 078/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 066010385..6a35f0309 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1066,15 +1066,14 @@ inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& using Tk = typename kernel_type_t::Type; // The amount of shared memory needed to store the Q*K^T values in float. const int max_timesteps = min(params.timestep, params.memory_max_len); - size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + size_t qk_sz = div_up(max_timesteps + 1, 4) * 16; // The extra memory needed if we are not using floats for the final logits. size_t logits_sz = 0; #ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS if (sizeof(Tk) != 4) { // TDOD - logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(Tk) : - div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); + logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); } #endif From c74e6b8d34fba39f1713f8f52b9affff313fec26 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 19:32:08 -0700 Subject: [PATCH 079/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 6a35f0309..d271088a9 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1061,8 +1061,6 @@ inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& int threads_per_value, int threads_per_block) { - // TODO(zhwang): hacky - constexpr bool DO_CROSS_ATTENTION = false; using Tk = typename kernel_type_t::Type; // The amount of shared memory needed to store the Q*K^T values in float. const int max_timesteps = min(params.timestep, params.memory_max_len); From ac982e8e258449871480146eb74f662d6702a3c1 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 19:34:03 -0700 Subject: [PATCH 080/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index d271088a9..ab1b5d669 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1147,8 +1147,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (sizeof(Tk) != 4) { // TODO - change to tlength const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += - (DO_CROSS_ATTENTION) ? 
div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + logits_smem_ += div_up(max_timesteps + 1, 4) * 16; } Tk* logits_smem = reinterpret_cast(logits_smem_); #else @@ -1208,7 +1207,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The thread in the block. const int tidx = threadIdx.x; - const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); + constexpr bool handle_kv = true; // While doing the product Q*K^T for the different keys we track the max. float qk_max = -FLT_MAX; From e99edd4c25fae397fc70134247be3952fc1787ee Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sat, 9 Sep 2023 19:34:10 -0700 Subject: [PATCH 081/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index ab1b5d669..f71fe9667 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1209,6 +1209,8 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr bool handle_kv = true; + // here. + // While doing the product Q*K^T for the different keys we track the max. float qk_max = -FLT_MAX; From e2d88e04f71b1694298c6d9538aaeef5332acdff Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 12:57:08 -0700 Subject: [PATCH 082/135] commit --- ...er_masked_multihead_attention_template.hpp | 48 +++++++------------ 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index f71fe9667..00d7ed6b8 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1259,35 +1259,19 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< Qk_vec_k k; zero(k); - if (DO_CROSS_ATTENTION) { - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength * QK_ELTS_IN_16B + ci; - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? 
- vec_conversion(*reinterpret_cast(¶ms.k_cache[offset])) : - k; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); } else { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : - k; - } + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + k; } // Trigger the loads from the Q and K bias buffers. @@ -1388,11 +1372,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Store the Q values to shared memory. *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - // Store Dh values of k_bias into smem, since will need to add later - // if params.timestep == 0 - if (DO_CROSS_ATTENTION && params.timestep == 0) { - *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - } + // // Store Dh values of k_bias into smem, since will need to add later + // // if params.timestep == 0 + // if (DO_CROSS_ATTENTION && params.timestep == 0) { + // *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; + // } // Write the K values to the global memory cache. // From fce23244a744b9c976da046cae89e947ffc671c5 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 13:08:54 -0700 Subject: [PATCH 083/135] commit --- ...er_masked_multihead_attention_template.hpp | 121 ++++-------------- 1 file changed, 28 insertions(+), 93 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 00d7ed6b8..af5e24520 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1469,14 +1469,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); } - K_vec_k k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; - if (DO_CROSS_ATTENTION && params.timestep == 0) { -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - } - // The number of timesteps loaded per iteration. constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; // The number of keys per warp. 
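The grouped-query edits in this and the following commits all hinge on one mapping: several query heads share a single KV head's cache. A minimal sketch of that mapping, assuming num_heads is an exact multiple of num_kv_heads and using the head_n_rep naming adopted later in this series (the helper itself is illustrative, not part of the patch):

// Map a query head index to the KV head whose cache entries it reads.
inline int kv_head_index(int head_idx, int num_heads, int num_kv_heads)
{
    const int head_n_rep = num_heads / num_kv_heads;  // e.g. 32 query heads over 4 KV heads -> 8
    return head_idx / head_n_rep;                     // consecutive query heads share one KV head
}
// Example: kv_head_index(10, 32, 4) == 1, so query head 10 attends against KV head 1's cache.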
@@ -1525,23 +1517,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); } } - // add bias and update k_cache - if (DO_CROSS_ATTENTION && params.timestep == 0) { - k[ii] = add(k[ii], k_bias_vec[ii]); - - if (do_ia3) { - k[ii] = mul( - k[ii], - vec_conversion(*reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki - + ii * THREADS_PER_KEY * K_VEC_SIZE]))); - } - - if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { - *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = - vec_conversion(k[ii]); - } - } } } @@ -1634,16 +1609,8 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Normalize the logits. float inv_sum = __fdividef(1.f, sum + 1.e-6f); - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - const size_t cross_attention_out_offset = - params.is_return_cross_attentions ? - bhi * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : - 0; for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { float logit = qk_smem[ti - first_step] * inv_sum; - if (params.is_return_cross_attentions) { - params.cross_attention_out[cross_attention_out_offset + ti] = logit; - } convert_from_float(logits_smem[ti - first_step], logit); } @@ -1674,16 +1641,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< zero(v_bias); // if( vo == params.timestep % V_PER_ITER ) { if (Dh == Dh_MAX || vi < Dh) { - if (handle_kv) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = vec_conversion( - *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); - } - if (DO_CROSS_ATTENTION) { - *reinterpret_cast(&bias_smem[vi]) = vec_conversion(v_bias); - } + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = vec_conversion( + *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); } } } @@ -1717,16 +1679,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Load the values from the cache. V_vec_k v = vec_conversion( *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, vec_conversion(*reinterpret_cast(&bias_smem[vi]))); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = vec_conversion(v); - } // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1763,16 +1715,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Load the values from the cache. V_vec_k v = vec_conversion( *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); - if (DO_CROSS_ATTENTION && params.timestep == 0) { - v = add(v, vec_conversion(*reinterpret_cast(&bias_smem[vi]))); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - *reinterpret_cast(&v_cache[ti * Dh]) = vec_conversion(v); - } // Load the logits from shared memory. 
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1803,44 +1745,37 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { V_vec_k v; - if (DO_CROSS_ATTENTION) { - v = vec_conversion(*reinterpret_cast(&v_cache[tlength * Dh])); + // Trigger the loads from the V buffer. + const auto v_offset = qkv_base_offset + vi; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto v_scaling = params.qkv_scale_out[2]; + const auto v_quant = + *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + + convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); } else { - // Trigger the loads from the V buffer. - const auto v_offset = qkv_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); } + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); // Compute the V values with bias. - if (handle_kv) { - v = add(v, v_bias); + v = add(v, v_bias); - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - - // Store the values with bias back to global memory in the cache for V. - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); } + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + // Initialize the output value with the current timestep. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) // out = fma(logits_smem[params.timestep], cast_to_float(v), out); From 1f5bca86e49638930e86439bd79e5dae4a2c79da Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 13:10:31 -0700 Subject: [PATCH 084/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index af5e24520..8a91810b4 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1372,12 +1372,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Store the Q values to shared memory. 
*reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - // // Store Dh values of k_bias into smem, since will need to add later - // // if params.timestep == 0 - // if (DO_CROSS_ATTENTION && params.timestep == 0) { - // *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; - // } - // Write the K values to the global memory cache. // // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory From f7ad8621b9753d1f9ba523f01cddae2cd4828986 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 13:21:08 -0700 Subject: [PATCH 085/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 8a91810b4..f741de0c6 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1116,8 +1116,6 @@ template< bool HAS_BEAMS> __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params params) { - // TODO(zhwang): hacky - constexpr bool DO_CROSS_ATTENTION = false; using Tk = typename kernel_type_t::Type; #ifdef ENABLE_FP8 // FP8 MHA Scales @@ -1168,9 +1166,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Shared memory to store Q inputs. __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX]; - // This is one of the reasons we should have a separate kernel for cross attention - __shared__ __align__(sizeof(Qk_vec_k)) Tk bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; - // The number of elements per vector. constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T); // Make sure the hidden size per head is a multiple of the vector size. @@ -1220,9 +1215,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const size_t bi_seq_len_offset = bi * params.memory_max_len; - // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : - (params.length_per_sample == nullptr) ? + int tlength = (params.length_per_sample == nullptr) ? 
params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; const int first_step = max(0, tlength + 1 - params.memory_max_len); From 13d2c2c333c3f876686e3ee23688a35d52f3d7f9 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 16:25:18 -0700 Subject: [PATCH 086/135] commit --- .../layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc index 89984bcbd..ed53c22d3 100644 --- a/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc +++ b/src/fastertransformer/layers/attention_layers/LlamaDecoderSelfAttentionLayer.cc @@ -113,7 +113,7 @@ void fusedQKV_masked_attention_dispatch(const T* qkv_buf, // timestep adding max_prefix_prompt_length for shared memory size calculation and rotary embedding computation params.timestep = step + max_prefix_prompt_length - 1; params.num_heads = head_num; - // params.num_kv_heads = kv_head_num; + params.num_kv_heads = kv_head_num; params.hidden_size_per_head = size_per_head; params.rotary_embedding_dim = rotary_embedding_dim; params.neox_rotary_style = neox_rotary_style; From 6b7b481f99b01cdfbf86bf0ebb1a5059454a6c67 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 16:37:05 -0700 Subject: [PATCH 087/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index f741de0c6..70d4a40e6 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1197,6 +1197,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int hi = blockIdx.x; // Combine the batch and the head indices. const int bhi = bi * params.num_heads + hi; + printf("%d\n", bhi); // Combine the "beam-aware" batch idx and the head indices. const int bbhi = bbi * params.beam_width * params.num_heads + hi; // The thread in the block. From 38e1c6141115094125e7e4b25dc1c7b1a525f104 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 16:38:08 -0700 Subject: [PATCH 088/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 70d4a40e6..78e582f9e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1197,7 +1197,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int hi = blockIdx.x; // Combine the batch and the head indices. 
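    // Worked example (illustrative numbers, not from the diff): with bi = 1,
    // params.num_heads = 32, params.num_kv_heads = 4 and hi = 10, the grouped-query
    // indices introduced below work out to head_n_rep = 32 / 4 = 8, kvhi = 10 / 8 = 1,
    // bhi = 1 * 32 + 10 = 42 and bkvhi = 1 * 4 + 1 = 5.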
const int bhi = bi * params.num_heads + hi; - printf("%d\n", bhi); + printf("%d %d %d\n", bi, params.num_heads, hi); // Combine the "beam-aware" batch idx and the head indices. const int bbhi = bbi * params.beam_width * params.num_heads + hi; // The thread in the block. From 6f79374539804ca80e7bb2dd91b3c171653e98da Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 17:53:12 -0700 Subject: [PATCH 089/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 78e582f9e..9caf9f81e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1193,11 +1193,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beami = bi % params.beam_width; // The "beam-aware" batch idx const int bbi = bi / params.beam_width; + const int head_n_rep = params.num_heads / params.num_kv_heads; // The head. const int hi = blockIdx.x; // Combine the batch and the head indices. const int bhi = bi * params.num_heads + hi; - printf("%d %d %d\n", bi, params.num_heads, hi); // Combine the "beam-aware" batch idx and the head indices. const int bbhi = bbi * params.beam_width * params.num_heads + hi; // The thread in the block. From 3e7cbc4b42770fc66442c2dc622b74d030bffeda Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:01:58 -0700 Subject: [PATCH 090/135] commit --- ...er_masked_multihead_attention_template.hpp | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 9caf9f81e..ec9af2e4b 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1195,11 +1195,14 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int bbi = bi / params.beam_width; const int head_n_rep = params.num_heads / params.num_kv_heads; // The head. - const int hi = blockIdx.x; + const int hi = blockIdx.x; + const int kvhi = hi / head_n_rep; // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; + const int bhi = bi * params.num_heads + hi; + const int bkvhi = bi * params.num_kv_heads + kvhi; // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbkvhi = bbi * params.beam_width * params.num_kv_heads + kvhi; // The thread in the block. const int tidx = threadIdx.x; @@ -1386,7 +1389,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (handle_kv) { // Trigger the stores to global memory. 
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[offset/params.num_kv_heads]) = vec_conversion(k); } } @@ -1463,9 +1466,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; + T* k_cache = ¶ms.k_cache[(bhi * params.memory_max_len * Dh + ki)/params.num_kv_heads]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; + T* k_cache_batch = ¶ms.k_cache[(bbhi * params.memory_max_len * Dh + ki)/params.num_kv_heads]; // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; @@ -1498,11 +1501,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (HAS_BEAMS) { const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + (*reinterpret_cast(&k_cache_batch[(beam_offset + jj * QK_ELTS_IN_16B)/params.num_kv_heads]))); } else { k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + (*reinterpret_cast(&k_cache_batch[(jj * QK_ELTS_IN_16B)/params.num_kv_heads]))); } } } @@ -1617,9 +1620,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; + T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/params.num_kv_heads]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; + T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi)/params.num_kv_heads]; // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; @@ -1666,7 +1669,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh)/params.num_kv_heads])); // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1702,7 +1705,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti_circ * Dh)/params.num_kv_heads])); // Load the logits from shared memory. 
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1762,7 +1765,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Store the values with bias back to global memory in the cache for V. //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + *reinterpret_cast(&v_cache[(tlength_circ * Dh)/params.num_kv_heads]) = vec_conversion(v); // Initialize the output value with the current timestep. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) From dd152ce13f8ec32d10e9b9b2c81391139dd24149 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:04:14 -0700 Subject: [PATCH 091/135] commit --- ...er_masked_multihead_attention_template.hpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index ec9af2e4b..7ff8e05df 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1389,7 +1389,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (handle_kv) { // Trigger the stores to global memory. if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset/params.num_kv_heads]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[offset/head_n_repeat]) = vec_conversion(k); } } @@ -1466,9 +1466,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[(bhi * params.memory_max_len * Dh + ki)/params.num_kv_heads]; + T* k_cache = ¶ms.k_cache[(bhi * params.memory_max_len * Dh + ki)/head_n_repeat]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[(bbhi * params.memory_max_len * Dh + ki)/params.num_kv_heads]; + T* k_cache_batch = ¶ms.k_cache[(bbhi * params.memory_max_len * Dh + ki)/head_n_repeat]; // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; @@ -1501,11 +1501,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (HAS_BEAMS) { const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[(beam_offset + jj * QK_ELTS_IN_16B)/params.num_kv_heads]))); + (*reinterpret_cast(&k_cache_batch[(beam_offset + jj * QK_ELTS_IN_16B)/head_n_repeat]))); } else { k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[(jj * QK_ELTS_IN_16B)/params.num_kv_heads]))); + (*reinterpret_cast(&k_cache_batch[(jj * QK_ELTS_IN_16B)/head_n_repeat]))); } } } @@ -1620,9 +1620,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. 
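    // Illustrative note (assumed cache layout, not taken from the diff): if the V cache is
    // allocated with num_kv_heads heads per sequence, roughly [batch, num_kv_heads,
    // memory_max_len, Dh], the per-head base can be expressed directly from the KV-head
    // index, e.g. &params.v_cache[bkvhi * params.memory_max_len * Dh + vi]; the commits
    // here instead derive the base by rescaling the original per-query-head offset.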
- T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/params.num_kv_heads]; + T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/head_n_repeat]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi)/params.num_kv_heads]; + T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi)/head_n_repeat]; // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; @@ -1669,7 +1669,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh)/params.num_kv_heads])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh)/head_n_repeat])); // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1705,7 +1705,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti_circ * Dh)/params.num_kv_heads])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti_circ * Dh)/head_n_repeat])); // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1765,7 +1765,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Store the values with bias back to global memory in the cache for V. //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[(tlength_circ * Dh)/params.num_kv_heads]) = vec_conversion(v); + *reinterpret_cast(&v_cache[(tlength_circ * Dh)/head_n_repeat]) = vec_conversion(v); // Initialize the output value with the current timestep. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) From f2e763d568fc419450e6202457c5340934f0b83c Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:04:57 -0700 Subject: [PATCH 092/135] commit --- ...er_masked_multihead_attention_template.hpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 7ff8e05df..2ebd94967 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1389,7 +1389,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (handle_kv) { // Trigger the stores to global memory. 
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset/head_n_repeat]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[offset/head_n_rep]) = vec_conversion(k); } } @@ -1466,9 +1466,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[(bhi * params.memory_max_len * Dh + ki)/head_n_repeat]; + T* k_cache = ¶ms.k_cache[(bhi * params.memory_max_len * Dh + ki)/head_n_rep]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[(bbhi * params.memory_max_len * Dh + ki)/head_n_repeat]; + T* k_cache_batch = ¶ms.k_cache[(bbhi * params.memory_max_len * Dh + ki)/head_n_rep]; // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; @@ -1501,11 +1501,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (HAS_BEAMS) { const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[(beam_offset + jj * QK_ELTS_IN_16B)/head_n_repeat]))); + (*reinterpret_cast(&k_cache_batch[(beam_offset + jj * QK_ELTS_IN_16B)/head_n_rep]))); } else { k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[(jj * QK_ELTS_IN_16B)/head_n_repeat]))); + (*reinterpret_cast(&k_cache_batch[(jj * QK_ELTS_IN_16B)/head_n_rep]))); } } } @@ -1620,9 +1620,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/head_n_repeat]; + T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/head_n_rep]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi)/head_n_repeat]; + T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi)/head_n_rep]; // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; @@ -1669,7 +1669,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh)/head_n_repeat])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh)/head_n_rep])); // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1705,7 +1705,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti_circ * Dh)/head_n_repeat])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti_circ * Dh)/head_n_rep])); // Load the logits from shared memory. 
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1765,7 +1765,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Store the values with bias back to global memory in the cache for V. //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[(tlength_circ * Dh)/head_n_repeat]) = vec_conversion(v); + *reinterpret_cast(&v_cache[(tlength_circ * Dh)/head_n_rep]) = vec_conversion(v); // Initialize the output value with the current timestep. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) From 8198daf3487eb3930f06c98ea3887ce3b5d14e71 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:07:41 -0700 Subject: [PATCH 093/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 2ebd94967..9b6b96fec 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1193,7 +1193,8 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beami = bi % params.beam_width; // The "beam-aware" batch idx const int bbi = bi / params.beam_width; - const int head_n_rep = params.num_heads / params.num_kv_heads; + // const int head_n_rep = params.num_heads / params.num_kv_heads; + const int head_n_rep = 1 // The head. const int hi = blockIdx.x; const int kvhi = hi / head_n_rep; From 8d1960855176fcac9dba6591c2e82bf3c2d25307 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:07:53 -0700 Subject: [PATCH 094/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 9b6b96fec..f9c48be94 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1194,7 +1194,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The "beam-aware" batch idx const int bbi = bi / params.beam_width; // const int head_n_rep = params.num_heads / params.num_kv_heads; - const int head_n_rep = 1 + const int head_n_rep = 1; // The head. 
const int hi = blockIdx.x; const int kvhi = hi / head_n_rep; From 94cd671689d1a01bebd8ff8f1fa89cd0384b2d3b Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:10:43 -0700 Subject: [PATCH 095/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index f9c48be94..158844bde 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1197,6 +1197,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int head_n_rep = 1; // The head. const int hi = blockIdx.x; + printf("%d \n", hi) const int kvhi = hi / head_n_rep; // Combine the batch and the head indices. const int bhi = bi * params.num_heads + hi; From dda42a315d06348c1c6e1aea5d37e38a80a79e9b Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:11:02 -0700 Subject: [PATCH 096/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 158844bde..05a29280a 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1197,7 +1197,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int head_n_rep = 1; // The head. const int hi = blockIdx.x; - printf("%d \n", hi) + printf("%d \n", hi); const int kvhi = hi / head_n_rep; // Combine the batch and the head indices. const int bhi = bi * params.num_heads + hi; From dbe4a4332caf359bb8cf099ce9907db05355cc8e Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:18:42 -0700 Subject: [PATCH 097/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 05a29280a..ef1a5cc8e 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1197,7 +1197,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int head_n_rep = 1; // The head. const int hi = blockIdx.x; - printf("%d \n", hi); const int kvhi = hi / head_n_rep; // Combine the batch and the head indices. 
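// With grouped-query attention, head_n_rep query heads share one KV head:
//   head_n_rep = params.num_heads / params.num_kv_heads  (assumes num_heads % num_kv_heads == 0)
//   kvhi       = hi / head_n_rep                          (KV head serving query head hi)
// Worked example (hypothetical sizes): num_heads = 32, num_kv_heads = 4 gives
// head_n_rep = 8, so query heads 0..7 map to kvhi = 0, heads 8..15 to kvhi = 1, and so on.
// bhi below combines batch and query head; later hunks also index the caches with a
// KV-side analogue whose definition is outside the quoted context and is assumed here:
//   const int bkvhi = bi * params.num_kv_heads + kvhi;   // assumed, for illustration only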
const int bhi = bi * params.num_heads + hi; @@ -1620,7 +1619,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int vo = tidx / THREADS_PER_VALUE; // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - + printf("%d\n", vi); // The base pointer for the value in the cache buffer. T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/head_n_rep]; // Base pointer for the beam's batch, before offsetting with indirection buffer From 36094dc0fe420c8ba5982d197b84819e49b2dbfb Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:22:46 -0700 Subject: [PATCH 098/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index ef1a5cc8e..057f717d4 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1621,9 +1621,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; printf("%d\n", vi); // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi)/head_n_rep]; + T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi/8)/head_n_rep]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi)/head_n_rep]; + T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi/8)/head_n_rep]; // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; @@ -1670,7 +1670,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh)/head_n_rep])); + *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh / 8)/head_n_rep])); // Load the logits from shared memory. 
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; From b7d9ca7667d1d24ce554feead2042f9c24df3bcf Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:23:06 -0700 Subject: [PATCH 099/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 - src/fastertransformer/models/llama/LlamaContextDecoder.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 057f717d4..fcca38b46 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1619,7 +1619,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int vo = tidx / THREADS_PER_VALUE; // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - printf("%d\n", vi); // The base pointer for the value in the cache buffer. T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi/8)/head_n_rep]; // Base pointer for the beam's batch, before offsetting with indirection buffer diff --git a/src/fastertransformer/models/llama/LlamaContextDecoder.cc b/src/fastertransformer/models/llama/LlamaContextDecoder.cc index 268dee769..de8c666c8 100644 --- a/src/fastertransformer/models/llama/LlamaContextDecoder.cc +++ b/src/fastertransformer/models/llama/LlamaContextDecoder.cc @@ -572,7 +572,6 @@ void LlamaContextDecoder::forward(std::unordered_map* } sync_check_cuda_error(); - #define ENABLE_FLEX_DEBUG #ifdef ENABLE_FLEX_DEBUG if (l == 1) { printf("%d %d: %d %d\n", l, ite, h_token_num, hidden_units_); From 5e4fa2edce4f8ddae7c51624bbfe82f702e99bfb Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:26:05 -0700 Subject: [PATCH 100/135] commit --- ...er_masked_multihead_attention_template.hpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index fcca38b46..96e6574de 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1390,7 +1390,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (handle_kv) { // Trigger the stores to global memory. if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset/head_n_rep]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); } } @@ -1467,9 +1467,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. 
- T* k_cache = ¶ms.k_cache[(bhi * params.memory_max_len * Dh + ki)/head_n_rep]; + T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki)]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[(bbhi * params.memory_max_len * Dh + ki)/head_n_rep]; + T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; @@ -1502,11 +1502,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (HAS_BEAMS) { const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[(beam_offset + jj * QK_ELTS_IN_16B)/head_n_rep]))); + (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } else { k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[(jj * QK_ELTS_IN_16B)/head_n_rep]))); + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); } } } @@ -1620,9 +1620,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[(bhi * params.memory_max_len * Dh + vi/8)/head_n_rep]; + T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[(bbhi * params.memory_max_len * Dh + vi/8)/head_n_rep]; + T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; // The number of values processed per iteration of the loop. constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; @@ -1669,7 +1669,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti * Dh / 8)/head_n_rep])); + *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1705,7 +1705,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[(beam_offset + ti_circ * Dh)/head_n_rep])); + *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); // Load the logits from shared memory. #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) float logit = logits_smem[ti - first_step]; @@ -1765,7 +1765,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Store the values with bias back to global memory in the cache for V. //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[(tlength_circ * Dh)/head_n_rep]) = vec_conversion(v); + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); // Initialize the output value with the current timestep. 
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) From 84733913d2a277a61f80509cf9509eb76bad3d76 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 21:26:22 -0700 Subject: [PATCH 101/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 96e6574de..3fe9baead 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1467,7 +1467,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki)]; + T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; From 9fc546584cb56a65fb2012f5e0535855057d4876 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:26:52 -0700 Subject: [PATCH 102/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 3fe9baead..8a562b6bd 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1390,7 +1390,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< if (handle_kv) { // Trigger the stores to global memory. 
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[offset/8]) = vec_conversion(k); } } From 8c6f8c406145abb61a784ad565084e988f141cbe Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:28:46 -0700 Subject: [PATCH 103/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 8a562b6bd..0d10400c6 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1193,8 +1193,8 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< const int beami = bi % params.beam_width; // The "beam-aware" batch idx const int bbi = bi / params.beam_width; - // const int head_n_rep = params.num_heads / params.num_kv_heads; - const int head_n_rep = 1; + const int head_n_rep = params.num_heads / params.num_kv_heads; + // const int head_n_rep = 1; // The head. const int hi = blockIdx.x; const int kvhi = hi / head_n_rep; @@ -1387,7 +1387,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // params.timestep*QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci; - if (handle_kv) { + if (handle_kv && bhi%head_n_rep==0) { // Trigger the stores to global memory. if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { *reinterpret_cast(¶ms.k_cache[offset/8]) = vec_conversion(k); From 852edd5c9b909c559528fb94cf6e3e5015dbe759 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:31:07 -0700 Subject: [PATCH 104/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 0d10400c6..1a1316e3c 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1383,14 +1383,14 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. - int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + int offset = bkvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + // params.timestep*QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci; if (handle_kv && bhi%head_n_rep==0) { // Trigger the stores to global memory. 
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset/8]) = vec_conversion(k); + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); } } From 0c1a43d138831013159182b8046fc18d6b518d17 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:32:57 -0700 Subject: [PATCH 105/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 1a1316e3c..62b328f5f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1383,6 +1383,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + // int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // // params.timestep*QK_ELTS_IN_16B + + // tlength_circ * QK_ELTS_IN_16B + ci; int offset = bkvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + // params.timestep*QK_ELTS_IN_16B + tlength_circ * QK_ELTS_IN_16B + ci; From 35d30285d34de6876741d0f84a192837e63fb9d7 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:36:30 -0700 Subject: [PATCH 106/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 62b328f5f..b25735eef 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1470,9 +1470,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. - T* k_cache = ¶ms.k_cache[bhi * params.memory_max_len * Dh + ki]; + T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbhi * params.memory_max_len * Dh + ki]; + T* k_cache_batch = ¶ms.k_cache[bbkvhi * params.memory_max_len * Dh + ki]; // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). 
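// K-cache layout sketch (from the cache shapes set up in Llama.cc), with
// x = 16 / sizeof(T) elements per 16-byte chunk:
//   k_cache: [batch * beam, num_kv_heads, Dh / x, memory_max_len, x]
// so the slab for one (batch, KV head) pair starts at bkvhi * memory_max_len * Dh,
// which is the base offset the pointers above use. bbkvhi is the beam-aware
// analogue (defined outside these hunks; presumably bbi * params.num_kv_heads + kvhi)
// and is used before the cache_indir indirection is applied.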
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; From f591af4b50451b69ce19123e6c82f4a6da92705f Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:38:01 -0700 Subject: [PATCH 107/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index b25735eef..064cee13d 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1503,7 +1503,8 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< } else { if (HAS_BEAMS) { - const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + // const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + const int beam_offset = beam_indices[ti_circ] * params.num_kv_heads * params.memory_max_len * Dh; k[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } From bc2714ab0b703fe7a2576790da5d540efbe294b6 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:38:48 -0700 Subject: [PATCH 108/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 064cee13d..381e20d6d 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1470,7 +1470,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; // The base pointer for the key in the cache buffer. 
- T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; + // T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* k_cache_batch = ¶ms.k_cache[bbkvhi * params.memory_max_len * Dh + ki]; From 83ec68ff13ce9b3570df8cad3bda6835ca1a33a9 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:41:39 -0700 Subject: [PATCH 109/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 381e20d6d..2b1d750a3 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,7 +1624,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; + // T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; From 255e077c248bb3ae225c427c1526c47b98e686b6 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 22:42:42 -0700 Subject: [PATCH 110/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 2b1d750a3..7530a14f1 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,9 +1624,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - // T* v_cache = ¶ms.v_cache[bhi * params.memory_max_len * Dh + vi]; + T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbhi * params.memory_max_len * Dh + vi]; + T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; // The number of values processed per iteration of the loop. 
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; From bf462344ad10b0dadfaf732595baf90de0ddc8c4 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:01:18 -0700 Subject: [PATCH 111/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 7530a14f1..9b2279d39 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1670,7 +1670,8 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) { // Fetch offset based on cache_indir when beam sampling const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; - const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + // const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. V_vec_k v = vec_conversion( *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); @@ -1706,7 +1707,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // Fetch offset based on cache_indir when beam sampling const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; // Load the values from the cache. 
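// The per-beam stride now matches the smaller cache: each beam slot holds
// num_kv_heads (not num_heads) heads of memory_max_len * Dh values, hence
//   beam_offset = beam_src * params.num_kv_heads * params.memory_max_len * Dh
// Quick size check with hypothetical sizes num_heads = 32, num_kv_heads = 4,
// memory_max_len = 2048, Dh = 128: the per-beam V cache is 4 * 2048 * 128
// elements instead of 32 * 2048 * 128, i.e. 8x smaller.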
V_vec_k v = vec_conversion( *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); From 2acfc1920305c79adb45b9ebee2d854fb4dde0a5 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:03:27 -0700 Subject: [PATCH 112/135] commit --- .vscode/settings.json | 3 ++- .../decoder_masked_multihead_attention_template.hpp | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index b43758f14..4b17335e2 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -90,6 +90,7 @@ "__nullptr": "cpp", "__string": "cpp", "compare": "cpp", - "concepts": "cpp" + "concepts": "cpp", + "filesystem": "cpp" } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 9b2279d39..901ee516b 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1224,7 +1224,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; + const int tlength_circ = tlength % params.memory_max_len / 8; // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. const bool is_masked = tidx >= QK_VECS_PER_WARP; From 281dcaf2b0cdd65e0cab91ffdac380f0add65d39 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:04:30 -0700 Subject: [PATCH 113/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 901ee516b..ac8fe32d1 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1224,7 +1224,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len / 8; + const int tlength_circ = tlength / 8 % params.memory_max_len; // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. 
const bool is_masked = tidx >= QK_VECS_PER_WARP; From 7b09acb35160d0369bedcc4fd2893e585906ecfe Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:05:16 -0700 Subject: [PATCH 114/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index ac8fe32d1..9b2279d39 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1224,7 +1224,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength / 8 % params.memory_max_len; + const int tlength_circ = tlength % params.memory_max_len; // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. const bool is_masked = tidx >= QK_VECS_PER_WARP; From be8f65b8b9d9da1157aa29e55fcfb58cb6c2eaae Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:05:46 -0700 Subject: [PATCH 115/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.h | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h index 30890b042..b0968519f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h @@ -33,13 +33,6 @@ struct GroupedQuery_attention_params: public Multihead_attention_params_base int num_kv_heads = 0; // required in case of masked attention with different length const int* length_per_sample = nullptr; - - // output cross attentions - // TODO(zhwang): remove - float* cross_attention_out = nullptr; - int max_decoder_seq_len = 0; - bool is_return_cross_attentions = false; - int* memory_length_per_sample = nullptr; }; template From a7446b3cadc812c96dcd4cf5eb99653100666760 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:12:18 -0700 Subject: [PATCH 116/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 9b2279d39..46238da95 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1223,6 +1223,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int tlength = (params.length_per_sample == nullptr) ? 
params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; + printf("%d\n", tlength); const int first_step = max(0, tlength + 1 - params.memory_max_len); const int tlength_circ = tlength % params.memory_max_len; From 45f50428d5d8c50625ded841e90b706ea1909c03 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:12:35 -0700 Subject: [PATCH 117/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 46238da95..4e569c9a2 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1223,7 +1223,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int tlength = (params.length_per_sample == nullptr) ? params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; - printf("%d\n", tlength); + printf("%d %d\n", tlength, params.memory_max_len); const int first_step = max(0, tlength + 1 - params.memory_max_len); const int tlength_circ = tlength % params.memory_max_len; From 1a8eb1cf208e5045de8a9f00c2c5d27470e41b48 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:39:59 -0700 Subject: [PATCH 118/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 4e569c9a2..ac37c5841 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1223,7 +1223,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< int tlength = (params.length_per_sample == nullptr) ? params.timestep : params.length_per_sample[bi] + params.max_prefix_prompt_length; - printf("%d %d\n", tlength, params.memory_max_len); const int first_step = max(0, tlength + 1 - params.memory_max_len); const int tlength_circ = tlength % params.memory_max_len; @@ -1768,10 +1767,11 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< *reinterpret_cast( ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); } - - // Store the values with bias back to global memory in the cache for V. - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + if (bhi % head_n_rep == 0) { + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); + } // Initialize the output value with the current timestep. 
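// Only one query head per KV group writes the current timestep back to the
// shared V cache. Because params.num_heads is a multiple of head_n_rep,
//   bhi % head_n_rep == (bi * params.num_heads + hi) % head_n_rep == hi % head_n_rep,
// so the guard above fires exactly for the first query head of each group and
// the remaining head_n_rep - 1 heads skip what would be a duplicate store.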
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) From b60394127d86c47b66d9e464c8eda40b021a5049 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:41:41 -0700 Subject: [PATCH 119/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index ac37c5841..1e5566ee1 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,6 +1624,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. + printf("%d %d %d %d", tidx, THREADS_PER_VALUE, V_VEC_SIZE, vi); T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; From 96bd93a2e4edaaecf51de39376a76d5a6928a0bc Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Sun, 10 Sep 2023 23:45:48 -0700 Subject: [PATCH 120/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 1e5566ee1..afb10a237 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,7 +1624,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. 
- printf("%d %d %d %d", tidx, THREADS_PER_VALUE, V_VEC_SIZE, vi); + printf("%d %d %d %d\n", tidx, THREADS_PER_VALUE, V_VEC_SIZE, vi); T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; From 5d8d95e1194bf63c5a27e8433522fca7ed11fc24 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:01:23 -0700 Subject: [PATCH 121/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index afb10a237..ac37c5841 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,7 +1624,6 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - printf("%d %d %d %d\n", tidx, THREADS_PER_VALUE, V_VEC_SIZE, vi); T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; From 6aac295ab8541443f336fe23e4b412f56cdf748a Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:02:25 -0700 Subject: [PATCH 122/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index ac37c5841..1793eb6e0 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,6 +1624,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. 
+ printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; From 69f420b9f30ad78adc98ae0ba444b40ed6989fc9 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:05:28 -0700 Subject: [PATCH 123/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index 1793eb6e0..a6cdbc9a4 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,7 +1624,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); + if (bhi == 0) { + printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); + } T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; From d761478ecaa0b0795435cd783d6496e5f354ed6f Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:16:52 -0700 Subject: [PATCH 124/135] commit --- .../decoder_masked_multihead_attention_template.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index a6cdbc9a4..e853668e8 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,7 +1624,7 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. 
- if (bhi == 0) { + if (bkvhi == 63) { printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); } T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; From 839b558530741cab709013fe0fa20b0b43f52510 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:31:21 -0700 Subject: [PATCH 125/135] commit --- ...er_masked_multihead_attention_template.hpp | 6 ++--- .../kernels/unfused_attention_kernels.cu | 24 ++++++++++++------- 2 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp index e853668e8..e942a0a4d 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp @@ -1624,9 +1624,9 @@ __global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params< // The hidden dimensions computed by this particular thread. int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; // The base pointer for the value in the cache buffer. - if (bkvhi == 63) { - printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); - } + // if (bkvhi == 63) { + // printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); + // } T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; // Base pointer for the beam's batch, before offsetting with indirection buffer T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index e4f707033..8ff470167 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1702,14 +1702,17 @@ template __global__ void transpose_4d_batch_major_k_cache( T* k_dst, const T* k_src, const int head_num, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) { - const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; - constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; - + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + constexpr int X_ELEMS = (sizeof(T) == 4) ? 
4 : 8; + const int head_n_rep = head_num / kv_head_num; + if (head_id % head_n_rep != 0) { + return; + } auto key_src = reinterpret_cast(k_src + batch_id * head_num * size_per_head * seq_len + head_id * size_per_head * seq_len); auto key_dst = reinterpret_cast(k_dst + batch_id * kv_head_num * size_per_head * max_seq_len - + head_id * size_per_head * max_seq_len); + + head_id / head_n_rep * size_per_head * max_seq_len); const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; int size_per_head_div_x = size_per_head / X_ELEMS; @@ -1757,14 +1760,19 @@ template __global__ void transpose_4d_batch_major_v_cache( T* v_dst, const T* v_src, const int head_num, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) { - const int batch_id = blockIdx.y; - const int head_id = blockIdx.z; + const int batch_id = blockIdx.y; + const int head_id = blockIdx.z; + const int head_n_rep = head_num / kv_head_num; + + if (head_n_rep % head_n_rep != 0) { + return; + } // 16 byte loads will handle "x" dimension auto val_src = reinterpret_cast(v_src + batch_id * head_num * size_per_head * seq_len + head_id * size_per_head * seq_len); auto val_dst = reinterpret_cast(v_dst + batch_id * kv_head_num * size_per_head * max_seq_len - + head_id * size_per_head * max_seq_len); + + head_id / head_n_rep * size_per_head * max_seq_len); // idx is over output dimension L * size_per_head / x for values const int idx = blockIdx.x * blockDim.x + threadIdx.x; From 5eb2e7930caa026c6b47367485d8c5fcc0ca0b93 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:37:08 -0700 Subject: [PATCH 126/135] commit --- src/fastertransformer/kernels/unfused_attention_kernels.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index 8ff470167..d7129bce7 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1828,8 +1828,8 @@ void invokeTranspose4dBatchMajor(T* k_dst, constexpr int block_sz = 128; constexpr int x = (sizeof(T) == 4) ? 
4 : 8; int size = max_seq_len * size_per_head / x; - dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); - dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); transpose_4d_batch_major_k_cache<<>>( k_dst, k_src, local_head_num, local_kv_head_num, size_per_head, seq_len, max_seq_len); From c4057d1d3f350ee29d646e5fa49f6bf08c2019aa Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:43:35 -0700 Subject: [PATCH 127/135] commit --- .../kernels/unfused_attention_kernels.cu | 40 ++++++++----------- 1 file changed, 16 insertions(+), 24 deletions(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d7129bce7..d71b52480 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1700,19 +1700,15 @@ __global__ void transpose_4d_batch_major_k_cache( template __global__ void transpose_4d_batch_major_k_cache( - T* k_dst, const T* k_src, const int head_num, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) + T* k_dst, const T* k_src, const int head_n_rep, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) { const int batch_id = blockIdx.y; const int head_id = blockIdx.z; constexpr int X_ELEMS = (sizeof(T) == 4) ? 4 : 8; - const int head_n_rep = head_num / kv_head_num; - if (head_id % head_n_rep != 0) { - return; - } - auto key_src = reinterpret_cast(k_src + batch_id * head_num * size_per_head * seq_len - + head_id * size_per_head * seq_len); + auto key_src = reinterpret_cast(k_src + batch_id * head_n_rep * kv_head_num * size_per_head * seq_len + + head_id * head_n_rep * size_per_head * seq_len); auto key_dst = reinterpret_cast(k_dst + batch_id * kv_head_num * size_per_head * max_seq_len - + head_id / head_n_rep * size_per_head * max_seq_len); + + head_id * size_per_head * max_seq_len); const int out_idx = blockIdx.x * blockDim.x + threadIdx.x; int size_per_head_div_x = size_per_head / X_ELEMS; @@ -1758,21 +1754,16 @@ __global__ void transpose_4d_batch_major_v_cache( template __global__ void transpose_4d_batch_major_v_cache( - T* v_dst, const T* v_src, const int head_num, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) + T* v_dst, const T* v_src, const int head_n_rep, const int kv_head_num, const int size_per_head, const int seq_len, const int max_seq_len) { const int batch_id = blockIdx.y; const int head_id = blockIdx.z; - const int head_n_rep = head_num / kv_head_num; - - if (head_n_rep % head_n_rep != 0) { - return; - } // 16 byte loads will handle "x" dimension - auto val_src = reinterpret_cast(v_src + batch_id * head_num * size_per_head * seq_len - + head_id * size_per_head * seq_len); + auto val_src = reinterpret_cast(v_src + batch_id * kv_head_num * head_n_rep * size_per_head * seq_len + + head_id * head_n_rep * size_per_head * seq_len); auto val_dst = reinterpret_cast(v_dst + batch_id * kv_head_num * size_per_head * max_seq_len - + head_id / head_n_rep * size_per_head * max_seq_len); + + head_id * size_per_head * max_seq_len); // idx is over output dimension L * size_per_head / x for values const int idx = blockIdx.x * 
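////////////////////////////////////////////////////////////////////////////////////////////////////
// Sketch (illustrative): a plain host-side reference of what the rewritten
// transpose_4d_batch_major_k_cache computes after this commit, assuming the standard
// FasterTransformer layouts -- source [batch, kv_head_num * head_n_rep, seq_len, size_per_head]
// (only the first query head of each group is read) and destination
// [batch, kv_head_num, size_per_head / x, max_seq_len, x]. Function and parameter names are the
// sketch's own.

template<typename T>
void transpose_k_cache_reference(T* k_dst, const T* k_src, int batch_size, int head_n_rep,
                                 int kv_head_num, int size_per_head, int seq_len, int max_seq_len)
{
    const int x = 16 / sizeof(T);
    for (int b = 0; b < batch_size; ++b) {
        for (int h = 0; h < kv_head_num; ++h) {
            const T* src = k_src
                           + ((size_t)b * kv_head_num * head_n_rep + (size_t)h * head_n_rep) * seq_len
                                 * size_per_head;
            T* dst = k_dst + ((size_t)b * kv_head_num + h) * (size_t)size_per_head * max_seq_len;
            for (int s = 0; s < seq_len; ++s) {
                for (int d = 0; d < size_per_head; ++d) {
                    const int chunk = d / x;
                    const int inner = d % x;
                    // [seq, dim] row-major in, [dim / x, seq, x] blocked out
                    dst[((size_t)chunk * max_seq_len + s) * x + inner] = src[(size_t)s * size_per_head + d];
                }
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////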
blockDim.x + threadIdx.x; @@ -1825,17 +1816,18 @@ void invokeTranspose4dBatchMajor(T* k_dst, const int local_kv_head_num, cudaStream_t stream) { - constexpr int block_sz = 128; - constexpr int x = (sizeof(T) == 4) ? 4 : 8; - int size = max_seq_len * size_per_head / x; - dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_head_num); - dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_head_num); + constexpr int block_sz = 128; + constexpr int x = (sizeof(T) == 4) ? 4 : 8; + int size = max_seq_len * size_per_head / x; + int head_n_rep = head_num / kv_head_num; + dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); + dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); transpose_4d_batch_major_k_cache<<>>( - k_dst, k_src, local_head_num, local_kv_head_num, size_per_head, seq_len, max_seq_len); + k_dst, k_src, head_n_rep, local_kv_head_num, size_per_head, seq_len, max_seq_len); transpose_4d_batch_major_v_cache<<>>( - v_dst, v_src, local_head_num, local_kv_head_num, size_per_head, seq_len, max_seq_len); + v_dst, v_src, head_n_rep, local_kv_head_num, size_per_head, seq_len, max_seq_len); } #define INSTANTIATETRANSPOSE4DBATCHMAJOR(T) \ From 96eb6adef5bc818c17c4cd3f497dc0d84de6a665 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:43:59 -0700 Subject: [PATCH 128/135] commit --- src/fastertransformer/kernels/unfused_attention_kernels.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/unfused_attention_kernels.cu b/src/fastertransformer/kernels/unfused_attention_kernels.cu index d71b52480..4f0df238e 100644 --- a/src/fastertransformer/kernels/unfused_attention_kernels.cu +++ b/src/fastertransformer/kernels/unfused_attention_kernels.cu @@ -1819,7 +1819,7 @@ void invokeTranspose4dBatchMajor(T* k_dst, constexpr int block_sz = 128; constexpr int x = (sizeof(T) == 4) ? 
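////////////////////////////////////////////////////////////////////////////////////////////////////
// Sketch (illustrative): the matching host-side reference for transpose_4d_batch_major_v_cache as
// launched above -- one grid z slot per KV head, with head_n_rep passed in so each KV head reads
// the first query head of its group. The V cache stays row-major
// [batch, kv_head_num, max_seq_len, size_per_head], so within a head it is a plain copy of the
// first seq_len rows. Names are the sketch's own.

template<typename T>
void transpose_v_cache_reference(T* v_dst, const T* v_src, int batch_size, int head_n_rep,
                                 int kv_head_num, int size_per_head, int seq_len, int max_seq_len)
{
    for (int b = 0; b < batch_size; ++b) {
        for (int h = 0; h < kv_head_num; ++h) {
            const T* src = v_src
                           + ((size_t)b * kv_head_num * head_n_rep + (size_t)h * head_n_rep) * seq_len
                                 * size_per_head;
            T* dst = v_dst + ((size_t)b * kv_head_num + h) * (size_t)max_seq_len * size_per_head;
            for (size_t i = 0; i < (size_t)seq_len * size_per_head; ++i) {
                dst[i] = src[i];  // rows beyond seq_len are left untouched, as in the kernel
            }
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////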
4 : 8; int size = max_seq_len * size_per_head / x; - int head_n_rep = head_num / kv_head_num; + int head_n_rep = local_head_num / local_kv_head_num; dim3 grid((size + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); dim3 grid_v((seq_len * size_per_head / x + block_sz - 1) / block_sz, local_batch_size, local_kv_head_num); From 4b8303bd6f52ea569d1f9b5de85a5cff14f14385 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:45:51 -0700 Subject: [PATCH 129/135] commit --- .../decoder_masked_multihead_attention_128.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_144.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_160.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_192.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_224.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_256.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_32.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_48.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_64.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_80.cu | 16 ++++++++-------- .../decoder_masked_multihead_attention_96.cu | 16 ++++++++-------- 11 files changed, 88 insertions(+), 88 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu index 210fb8f77..e14d88312 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -48,24 +48,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -86,4 +86,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 
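////////////////////////////////////////////////////////////////////////////////////////////////////
// Sketch (illustrative): the selection logic that every MGQA_LAUNCH_KERNEL call site in these
// per-head-size files repeats, pulled into one helper so the heuristic is easier to read. The
// struct and function names are the sketch's own; the thresholds and block shapes are the ones
// used in the launchers above.

struct MgqaLaunchShape {
    int  threads_per_key;
    int  threads_per_block;
    bool has_beams;
};

inline MgqaLaunchShape pick_mgqa_launch_shape(int tlength, bool has_cache_indir)
{
    MgqaLaunchShape shape;
    shape.has_beams = has_cache_indir;  // cache_indir != nullptr means beam-search indirection
    if (tlength < 32) {
        shape.threads_per_key   = 4;
        shape.threads_per_block = 64;
    }
    else if (tlength < 2048) {
        shape.threads_per_key   = 2;
        shape.threads_per_block = 128;
    }
    else {
        shape.threads_per_key   = 1;
        shape.threads_per_block = 256;
    }
    return shape;
}

////////////////////////////////////////////////////////////////////////////////////////////////////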
128, 128, GroupedQuery_attention #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu index d4922de73..18cdba3e4 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu index 9bd6da2a0..10d081fed 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, 
const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, GroupedQuery_attention const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu index 651cc5c6c..222722a74 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, GroupedQuery_attention const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu index 0bf1e9aaf..757f2bb51 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, GroupedQuery_attention const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu index a8bcd0180..ea820b5cc 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, GroupedQuery_attention const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu index e8d24cf5a..891badf0c 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -48,24 +48,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -85,4 +85,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, GroupedQuery_attention_p const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu index c2c65e6b0..9aabcf0b7 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, GroupedQuery_attention_p const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu index e8a6e9410..ec31f7257 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, GroupedQuery_attention_p const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu index ecf73b9ed..d24011427 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, GroupedQuery_attention_ const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu index c6b52930f..8ac5171a9 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu @@ -24,7 +24,7 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -#define MMHA_LAUNCH_KERNEL( \ +#define MGQA_LAUNCH_KERNEL( \ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ @@ -47,24 +47,24 @@ void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); } } else { if (tlength < 32) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); } else if (tlength < 2048) { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); } else { - MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); } } } @@ -84,4 +84,4 @@ template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, GroupedQuery_attention_ const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); #endif -#undef MMHA_LAUNCH_KERNEL +#undef MGQA_LAUNCH_KERNEL From 0f7b36340aed46f5fa3a7c3a55070484fdd0a031 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:48:24 -0700 Subject: [PATCH 130/135] commit --- .../decoder_masked_multihead_attention_128.cu | 2 +- .../decoder_masked_multihead_attention_144.cu | 2 +- .../decoder_masked_multihead_attention_160.cu | 2 +- .../decoder_masked_multihead_attention_192.cu | 2 +- .../decoder_masked_multihead_attention_224.cu | 2 +- .../decoder_masked_multihead_attention_256.cu | 2 +- .../decoder_masked_multihead_attention_32.cu | 2 +- .../decoder_masked_multihead_attention_48.cu | 2 +- .../decoder_masked_multihead_attention_64.cu | 2 +- .../decoder_masked_multihead_attention_80.cu | 2 +- .../decoder_masked_multihead_attention_96.cu | 2 +- .../decoder_masked_multihead_attention_template.hpp | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu index e14d88312..24d8a3e91 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu @@ -28,7 +28,7 @@ T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - 
mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_multihead_attention_kernel -__global__ void masked_multihead_attention_kernel(GroupedQuery_attention_params params) +__global__ void masked_groupedquery_attention_kernel(GroupedQuery_attention_params params) { using Tk = typename kernel_type_t::Type; #ifdef ENABLE_FP8 From fa3fa51c7b2c26510c7a572d8d0e65449ac31022 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:51:05 -0700 Subject: [PATCH 131/135] commit --- .../decoder_masked_multihead_attention_128.cu | 89 - .../decoder_masked_multihead_attention_144.cu | 87 - .../decoder_masked_multihead_attention_160.cu | 87 - .../decoder_masked_multihead_attention_192.cu | 87 - .../decoder_masked_multihead_attention_224.cu | 87 - .../decoder_masked_multihead_attention_256.cu | 87 - .../decoder_masked_multihead_attention_32.cu | 88 - .../decoder_masked_multihead_attention_48.cu | 87 - .../decoder_masked_multihead_attention_64.cu | 87 - .../decoder_masked_multihead_attention_80.cu | 87 - .../decoder_masked_multihead_attention_96.cu | 87 - ...er_masked_multihead_attention_template.hpp | 1878 ----------------- 12 files changed, 2838 deletions(-) delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu delete mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu delete mode 100644 
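////////////////////////////////////////////////////////////////////////////////////////////////////
// Sketch (illustrative): the renamed masked_groupedquery_attention_kernel is still launched with
// dim3 grid(params.num_heads, params.batch_size), i.e. one block per (query head, batch) pair, so
// each block has to pick the shared KV-cache row for its group. Only the cache addressing
// v_cache[bkvhi * memory_max_len * Dh + vi] is taken from the kernel itself; the helper and the
// num_kv_heads parameter name below are assumptions used to show how bkvhi can be derived.

__host__ __device__ inline int flatten_batch_kv_head(int batch_id, int q_head_id, int num_heads,
                                                     int num_kv_heads)
{
    const int head_n_rep = num_heads / num_kv_heads;  // query heads per KV head
    const int kv_head_id = q_head_id / head_n_rep;
    return batch_id * num_kv_heads + kv_head_id;      // row index into the per-(batch, KV head) caches
}

////////////////////////////////////////////////////////////////////////////////////////////////////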
src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu deleted file mode 100644 index 24d8a3e91..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_128.cu +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu deleted file mode 100644 index 350499c5b..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_144.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! 
Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu deleted file mode 100644 index 8f392cf77..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_160.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
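////////////////////////////////////////////////////////////////////////////////////////////////////
// Sketch (illustrative): the launchers being deleted here compute THREADS_PER_VALUE as
// Dh_MAX * sizeof(T) / 16, i.e. one thread per 16-byte slice of a value row, so a full value of
// padded head size Dh_MAX is covered by one 16-byte access per thread. A hypothetical constexpr
// equivalent:

template<typename T, int Dh_MAX>
constexpr int threads_per_value()
{
    static_assert(Dh_MAX * sizeof(T) % 16 == 0, "a value row must be a whole number of 16-byte slices");
    return Dh_MAX * sizeof(T) / 16;  // e.g. Dh_MAX = 128 with fp16 -> 16 threads per value
}

////////////////////////////////////////////////////////////////////////////////////////////////////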
- */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu deleted file mode 100644 index ff68387ea..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_192.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
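////////////////////////////////////////////////////////////////////////////////////////////////////
// Sketch (illustrative): across the instantiations in these deleted launchers, the second template
// argument Dh_MAX is the head size Dh rounded up to a power of two (32 -> 32, 48 -> 64, 80 -> 128,
// 96 -> 128, 144/160/192/224 -> 256), presumably to keep the per-thread tiling sizes powers of two
// even for irregular head sizes. A hypothetical helper that reproduces the mapping:

constexpr int dh_max_for(int dh)
{
    int padded = 32;  // smallest head size these launchers handle
    while (padded < dh) {
        padded *= 2;
    }
    return padded;  // dh_max_for(96) == 128, dh_max_for(224) == 256
}

////////////////////////////////////////////////////////////////////////////////////////////////////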
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu deleted file mode 100644 index ea66caf2d..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_224.cu +++ /dev/null @@ 
-1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu deleted file mode 100644 index 12876a48e..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_256.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! 
Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu deleted file mode 100644 index 3f877d3dc..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_32.cu +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - //constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - constexpr bool DO_CROSS_ATTENTION = false; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu deleted file mode 100644 index 243886074..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_48.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu deleted file mode 100644 index b30030d71..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_64.cu +++ /dev/null @@ 
-1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu deleted file mode 100644 index c34a8546e..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_80.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! 
Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu deleted file mode 100644 index ac3b6369b..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_96.cu +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "decoder_masked_multihead_attention_template.hpp" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define MGQA_LAUNCH_KERNEL( \ - T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ - size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ - dim3 grid(params.num_heads, params.batch_size); \ - mmha::masked_groupedquery_attention_kernel<<>>(params) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// !!! Specialize the launcher for Cross attention -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) -{ - constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - if (params.cache_indir == nullptr) { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); - } - } - else { - if (tlength < 32) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); - } - else if (tlength < 2048) { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); - } - else { - MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); - } - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -template void mgqa_launch_kernel>( - const GroupedQuery_attention_params& params, const cudaStream_t& stream); -#ifdef ENABLE_BF16 -template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, GroupedQuery_attention_params<__nv_bfloat16>>( - const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); -#endif -#ifdef ENABLE_FP8 -template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( - const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); -#endif - -#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp deleted file mode 100644 index 581d566ca..000000000 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp +++ /dev/null @@ -1,1878 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" -#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" -#include "src/fastertransformer/utils/cuda_fp8_utils.h" -#include "src/fastertransformer/utils/cuda_type_utils.cuh" -#include -#include -#include - -// #define MMHA_USE_HMMA_FOR_REDUCTION - -// Below are knobs to extend FP32 accumulation for higher FP16 accuracy - -// Does not seem to affect the accuracy that much -// #define MMHA_USE_FP32_ACUM_FOR_FMA - -// Seems to slightly improve the accuracy -#define MMHA_USE_FP32_ACUM_FOR_OUT - -#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) - // Does not seem to improve the accuracy - //#define MMHA_USE_FP32_ACUM_FOR_LOGITS -#endif - -namespace mmha { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// -// We use the following terminology to describe the different dimensions. -// -// B: Batch size (number of sequences), -// L: Sequence length, -// D: Hidden dimension, -// H: Number of heads, -// Dh: Hidden dimension per head - Dh = D / H. -// -// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use -// 64, 128 and 256 threads per block. -// -// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to -// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The -// cache buffer helps with memory accesses and contains keys with bias. -// -// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and -// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The -// values for x are chosen to create chunks of 16 bytes. -// -// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs -// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At -// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an -// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. -// -// After that loop, a parallel softmax is computed across the different Q * K^T values stored in -// shared memory. -// -// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many -// timesteps are computed by loop iteration. As with the keys, the values are read from a cache -// except for the current timestep. The layout of the cache buffer for the values is much simpler -// as it is [B, H, L, Dh]. 
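
The preceding comment describes the key-cache layout [B, H, Dh/x, L, x] and the value-cache layout [B, H, L, Dh]. For reference, here is a minimal CUDA sketch (not part of the patch) of the linear element offsets those layouts imply; the helper names are illustrative assumptions, and in this grouped-query variant the cache's H dimension is the number of KV heads.

#include <cstddef>

// Sketch only: linear offsets implied by the cache layouts described above.
// x is the number of T elements per 16-byte chunk (8 for FP16, 4 for FP32).
template<typename T>
__host__ __device__ inline size_t k_cache_offset(size_t bi, size_t hi, size_t ci, size_t ti,
                                                 size_t num_heads, size_t Dh, size_t max_seq_len)
{
    const size_t x = 16 / sizeof(T);
    // [B, H, Dh/x, L, x]: channel ci splits into a chunk index (ci / x) and a position in the chunk (ci % x).
    return (((bi * num_heads + hi) * (Dh / x) + ci / x) * max_seq_len + ti) * x + ci % x;
}

template<typename T>
__host__ __device__ inline size_t v_cache_offset(size_t bi, size_t hi, size_t ci, size_t ti,
                                                 size_t num_heads, size_t Dh, size_t max_seq_len)
{
    // [B, H, L, Dh]: one contiguous row of Dh values per (batch, head, timestep).
    return ((bi * num_heads + hi) * max_seq_len + ti) * Dh + ci;
}
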
-// - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_m_ { -}; - -template<> -struct Qk_vec_m_ { - using Type = float; -}; -template<> -struct Qk_vec_m_ { - using Type = float2; -}; -template<> -struct Qk_vec_m_ { - using Type = float4; -}; -template<> -struct Qk_vec_m_ { - using Type = float4; -}; -template<> -struct Qk_vec_m_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_m_ { - using Type = uint32_t; -}; -template<> -struct Qk_vec_m_ { - using Type = uint2; -}; -template<> -struct Qk_vec_m_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct Qk_vec_m_<__nv_bfloat16, 32> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_m_<__nv_bfloat16, 64> { - using Type = __nv_bfloat162; -}; -template<> -struct Qk_vec_m_<__nv_bfloat16, 128> { - using Type = bf16_4_t; -}; -template<> -struct Qk_vec_m_<__nv_bfloat16, 256> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 - -#ifdef ENABLE_FP8 -template<> -struct Qk_vec_m_<__nv_fp8_e4m3, 32> { - using Type = fp8_4_t; -}; -template<> -struct Qk_vec_m_<__nv_fp8_e4m3, 64> { - using Type = fp8_4_t; -}; -template<> -struct Qk_vec_m_<__nv_fp8_e4m3, 128> { - using Type = fp8_4_t; -}; -template<> -struct Qk_vec_m_<__nv_fp8_e4m3, 256> { - using Type = fp8_4_t; -}; -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_vec_k_ { - using Type = typename Qk_vec_m_::Type; -}; -#ifdef ENABLE_FP8 -template<> -struct Qk_vec_k_<__nv_fp8_e4m3, 32> { - using Type = float4; -}; -template<> -struct Qk_vec_k_<__nv_fp8_e4m3, 64> { - using Type = float4; -}; -template<> -struct Qk_vec_k_<__nv_fp8_e4m3, 128> { - using Type = float4; -}; -template<> -struct Qk_vec_k_<__nv_fp8_e4m3, 256> { - using Type = float4; -}; -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_m_ { -}; - -template<> -struct K_vec_m_ { - using Type = float; -}; -template<> -struct K_vec_m_ { - using Type = float2; -}; -template<> -struct K_vec_m_ { - using Type = float4; -}; -template<> -struct K_vec_m_ { - using Type = uint32_t; -}; -template<> -struct K_vec_m_ { - using Type = uint2; -}; -template<> -struct K_vec_m_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct K_vec_m_<__nv_bfloat16, 4> { - using Type = __nv_bfloat162; -}; -template<> -struct K_vec_m_<__nv_bfloat16, 2> { - using Type = bf16_4_t; -}; -template<> -struct K_vec_m_<__nv_bfloat16, 1> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 - -// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes -#ifdef ENABLE_FP8 -template<> -struct K_vec_m_<__nv_fp8_e4m3, 4> { - using Type = fp8_4_t; -}; -template<> -struct K_vec_m_<__nv_fp8_e4m3, 2> { - using Type = fp8_4_t; -}; // Defined for compilation-purpose only, do not use -template<> -struct K_vec_m_<__nv_fp8_e4m3, 1> { - using Type = fp8_4_t; -}; // Defined for compilation-purpose only, do not use -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_k_ { - using Type = typename K_vec_m_::Type; -}; -#ifdef ENABLE_FP8 -template<> -struct K_vec_k_<__nv_fp8_e4m3, 4> { - using Type = float4; -}; -template<> -struct K_vec_k_<__nv_fp8_e4m3, 2> { - using Type = float4; -}; // Defined for compilation-purpose only, do not use -template<> -struct K_vec_k_<__nv_fp8_e4m3, 1> { - 
using Type = float4; -}; // Defined for compilation-purpose only, do not use -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_m_ { -}; - -template<> -struct V_vec_m_ { - using Type = float; -}; -template<> -struct V_vec_m_ { - using Type = float2; -}; -template<> -struct V_vec_m_ { - using Type = float4; -}; -template<> -struct V_vec_m_ { - using Type = uint32_t; -}; -template<> -struct V_vec_m_ { - using Type = uint2; -}; -template<> -struct V_vec_m_ { - using Type = uint4; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_m_<__nv_bfloat16, 2> { - using Type = __nv_bfloat162; -}; -template<> -struct V_vec_m_<__nv_bfloat16, 4> { - using Type = bf16_4_t; -}; -template<> -struct V_vec_m_<__nv_bfloat16, 8> { - using Type = bf16_8_t; -}; -#endif // ENABLE_BF16 -#ifdef ENABLE_FP8 -template<> -struct V_vec_m_<__nv_fp8_e4m3, 4> { - using Type = fp8_4_t; -}; -template<> -struct V_vec_m_<__nv_fp8_e4m3, 8> { - using Type = fp8_4_t; -}; -template<> -struct V_vec_m_<__nv_fp8_e4m3, 16> { - using Type = fp8_4_t; -}; -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct V_vec_k_ { - using Type = typename V_vec_m_::Type; -}; -#ifdef ENABLE_FP8 -template<> -struct V_vec_k_<__nv_fp8_e4m3, 4> { - using Type = float4; -}; -template<> -struct V_vec_k_<__nv_fp8_e4m3, 8> { - using Type = float4; -}; -template<> -struct V_vec_k_<__nv_fp8_e4m3, 16> { - using Type = float4; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA -template -struct Qk_vec_acum_fp32_ { -}; - -template<> -struct Qk_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float4; -}; -// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; - -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct Qk_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_FP8 -// template<> -// struct Qk_vec_acum_fp32_ { -// using Type = float2; -// }; -template<> -struct Qk_vec_acum_fp32_ { - using Type = Float4_; -}; -// template<> -// struct Qk_vec_acum_fp32_ { -// using Type = Float4_; -// }; -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct K_vec_acum_fp32_ { -}; - -template<> -struct K_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct K_vec_acum_fp32_ { - 
using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat16> { - using Type = float; -}; -template<> -struct K_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct K_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_FP8 -// template<> -// struct K_vec_acum_fp32_ { -// using Type = float2; -// }; -template<> -struct K_vec_acum_fp32_ { - using Type = Float4_; -}; -// template<> -// struct K_vec_acum_fp32_ { -// using Type = Float4_; -// }; -#endif // ENABLE_FP8 -#endif // MMHA_USE_FP32_ACUM_FOR_FMA - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT -template -struct V_vec_acum_fp32_ { -}; - -template<> -struct V_vec_acum_fp32_ { - using Type = float; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float4; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#ifdef ENABLE_BF16 -template<> -struct V_vec_acum_fp32_<__nv_bfloat162> { - using Type = float2; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -template<> -struct V_vec_acum_fp32_ { - using Type = Float8_; -}; -#endif // ENABLE_BF16 -#ifdef ENABLE_FP8 -// template<> -// struct V_vec_acum_fp32_ { -// using Type = float2; -// }; -template<> -struct V_vec_acum_fp32_ { - using Type = Float4_; -}; -// template<> -// struct V_vec_acum_fp32_ { -// using Type = Float4_; -// }; -#endif // ENABLE_FP8 -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -__inline__ __device__ Tout vec_conversion(const Tin& x) -{ - return x; -} -#ifdef ENABLE_FP8 -// fp8_t -template<> -__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) -{ - return float(a); -} -template<> -__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) -{ - return __nv_fp8_e4m3(a); -} -// fp8_2_t -template<> -__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) -{ - return float2(a); -} -template<> -__inline__ __device__ fp8_2_t vec_conversion(const float2& a) -{ - return fp8_2_t(a); -} -// fp8_4_t -template<> -__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) -{ - return float4(a); -} -template<> -__inline__ __device__ fp8_4_t vec_conversion(const float4& a) -{ - return fp8_4_t(a); -} -#endif // ENABLE_FP8 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) -{ -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = K_vec; -#endif - // Compute the parallel products for Q*K^T (treat vector lanes separately). - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } - - // Finalize the reduction across lanes. 
- float qk = sum(qk_vec); -#pragma unroll - for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); - } - return qk; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct Qk_dot { - template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) - { - return qk_dot_(q, k); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) -{ - float4 c; - float zero = 0.f; - asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" - " {%0, %1, %2, %3}, \n" - " {%4, %5}, \n" - " {%6}, \n" - " {%7, %7, %7, %7}; \n" - - : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) - : "r"(a.x), "r"(a.y), "r"(b), "f"(zero)); - return c; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using K_vec_acum = typename K_vec_acum_fp32_::Type; -#else - using K_vec_acum = uint32_t; -#endif - K_vec_acum qk_vec = mul(q[0], k[0]); -#pragma unroll - for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); - } -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - uint32_t qk_vec_ = float2_to_half2(qk_vec); - return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; -#else - return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; -#endif -#else - return 0.f; -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template<> -struct Qk_dot { - template - static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) - { -#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) - return qk_hmma_dot_(q, k); -#else - return qk_dot_<4>(q, k); -#endif // defined MMHA_USE_HMMA_FOR_REDUCTION - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ float block_sum(float* red_smem, float sum) -{ - - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - -// Compute the sum per warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < WARPS_PER_BLOCK) { - sum = red_smem[lane]; - } - -// Parallel reduction inside the warp. -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); - } - - // Broadcast to other threads. 
- return __shfl_sync(uint32_t(-1), sum, 0); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float& dst, float src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint16_t& dst, float src) -{ - dst = float_to_half(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint32_t& dst, float2 src) -{ - dst = float2_to_half2(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) -{ - dst = __float2bfloat16(src); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst = __float22bfloat162_rn(src); -#else - dst = __floats2bfloat162_rn(src.x, src.y); -#endif -} -#endif // ENABLE_BF16 -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, Float4_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint2& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(uint4& dst, Float8_ src) -{ - dst.x = float2_to_half2(src.x); - dst.y = float2_to_half2(src.y); - dst.z = float2_to_half2(src.z); - dst.w = float2_to_half2(src.w); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef ENABLE_BF16 -inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); -#endif -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) -{ - convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) -{ -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - dst.x = __float22bfloat162_rn(src.x); - dst.y = __float22bfloat162_rn(src.y); - dst.z = __float22bfloat162_rn(src.z); - dst.w = __float22bfloat162_rn(src.w); -#else - dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); - dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); - dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); - dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); -#endif -} -#endif // ENABLE_BF16 - -//////////////////////////////////////////////////////////////////////////////////////////////////// - 
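
block_sum above reduces a per-thread value to a block-wide sum in two stages: an intra-warp butterfly reduction with __shfl_xor_sync, one partial result per warp staged in shared memory, and a final reduction followed by a broadcast. A self-contained sketch of the same pattern (not from the file; names and the launch assumption blockDim.x == WARPS_PER_BLOCK * 32, WARPS_PER_BLOCK <= 32, are illustrative) is shown below; unlike block_sum it writes the block total to global memory instead of broadcasting it to every thread.

// Sketch only: the two-level block reduction pattern used by block_sum.
template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
__global__ void block_sum_demo(const float* in, float* out)
{
    __shared__ float red_smem[WARPS_PER_BLOCK];
    const int tid  = threadIdx.x;
    const int warp = tid / WARP_SIZE;
    const int lane = tid % WARP_SIZE;

    // Stage 1: every warp reduces its 32 values with butterfly shuffles.
    float sum = in[blockIdx.x * blockDim.x + tid];
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        sum += __shfl_xor_sync(0xffffffffu, sum, mask);
    }
    if (lane == 0) {
        red_smem[warp] = sum;  // one partial per warp
    }
    __syncthreads();

    // Stage 2: the first warp reduces the per-warp partials.
    if (warp == 0) {
        sum = (lane < WARPS_PER_BLOCK) ? red_smem[lane] : 0.f;
#pragma unroll
        for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
            sum += __shfl_xor_sync(0xffffffffu, sum, mask);
        }
        if (lane == 0) {
            out[blockIdx.x] = sum;
        }
    }
}
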
-#ifdef ENABLE_FP8 -inline __device__ void convert_from_float(fp8_4_t& dst, float4 src) -{ - dst = fp8_4_t(src); -} -inline __device__ void convert_from_float(fp8_2_t& dst, float2 src) -{ - dst = fp8_2_t(src); -} -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float2& dst, float2 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ void convert_from_float(float4& dst, float4 src) -{ - dst = src; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(float4 u) -{ - return u.x; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float convert_to_float(uint4 u) -{ - float2 tmp = half2_to_float2(u.x); - return tmp.x; -} - -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float cast_to_float(float u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(float2 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 cast_to_float(float4 u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(Float4_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(Float8_ u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 cast_to_float(uint32_t u) -{ - return half2_to_float2(u); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float4_ cast_to_float(uint2 u) -{ - Float4_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - return tmp; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ Float8_ cast_to_float(uint4 u) -{ - Float8_ tmp; - tmp.x = half2_to_float2(u.x); - tmp.y = half2_to_float2(u.y); - tmp.z = half2_to_float2(u.z); - tmp.w = half2_to_float2(u.w); - return tmp; -} - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float float_from_int8(int8_t u) -{ - return u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float2 float_from_int8(int16_t u) -{ - union { - int16_t int16; - int8_t int8[2]; - }; - int16 = u; - return make_float2(int8[0], int8[1]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ float4 float_from_int8(int32_t u) -{ - union { - int32_t int32; - int8_t int8[4]; - }; - int32 = u; - return make_float4(int8[0], int8[1], int8[2], int8[3]); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -// clang-format off 
-inline __device__ Float8_ float_from_int8(int64_t u) -{ - union { - int64_t int64; - int16_t int16[4]; - }; - int64 = u; - return Float8_ {float_from_int8(int16[0]), - float_from_int8(int16[1]), - float_from_int8(int16[2]), - float_from_int8(int16[3])}; -} -// clang-format on - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int8_t cast_to_int8(float val) -{ - union { - int8_t int8[2]; - int16_t int16; - }; - asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); - return int8[0]; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int32_t cast_to_int8(float4 val) -{ - union { - int8_t int8[4]; - int32_t int32; - }; - int8[0] = cast_to_int8(val.x); - int8[1] = cast_to_int8(val.y); - int8[2] = cast_to_int8(val.z); - int8[3] = cast_to_int8(val.w); - return int32; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ int64_t cast_to_int8(Float8_ val) -{ - union { - int8_t int8[8]; - int64_t int64; - }; - int8[0] = cast_to_int8(val.x.x); - int8[1] = cast_to_int8(val.x.y); - int8[2] = cast_to_int8(val.y.x); - int8[3] = cast_to_int8(val.y.y); - int8[4] = cast_to_int8(val.z.x); - int8[5] = cast_to_int8(val.z.y); - int8[6] = cast_to_int8(val.w.x); - int8[7] = cast_to_int8(val.w.y); - return int64; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline __device__ __host__ T div_up(T m, T n) -{ - return (m + n - 1) / n; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct kernel_type_t { - using Type = T; -}; - -#ifdef ENABLE_FP8 -template<> -struct kernel_type_t<__nv_fp8_e4m3> { - using Type = float; -}; -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& params, - int threads_per_value, - int threads_per_block) -{ - using Tk = typename kernel_type_t::Type; - // The amount of shared memory needed to store the Q*K^T values in float. - const int max_timesteps = min(params.timestep, params.memory_max_len); - size_t qk_sz = div_up(max_timesteps + 1, 4) * 16; - - // The extra memory needed if we are not using floats for the final logits. - size_t logits_sz = 0; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(Tk) != 4) { - // TDOD - logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); - } -#endif - - // The total size needed during softmax. - size_t softmax_sz = qk_sz + logits_sz; - - // The number of partial rows to reduce in the final reduction. - int rows_per_red = threads_per_block / threads_per_value; - // The amount of storage needed to finalize the outputs. - size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2; - - size_t transpose_rotary_size = 0; - if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk); - } - - // The max. - return max(max(softmax_sz, red_sz), transpose_rotary_size); -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -inline __device__ constexpr uint32_t shfl_mask(int threads) -{ - return threads == 32 ? 
uint32_t(-1) : (1u << threads) - 1u; -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template< - // The type of the inputs. Supported types: float and half. - typename T, - // The hidden dimension per head. - int Dh, - int Dh_MAX, - // The number of threads per key. - int THREADS_PER_KEY, - // The number of threads per value. - int THREADS_PER_VALUE, - // The number of threads in a threadblock. - int THREADS_PER_BLOCK, - bool HAS_BEAMS> -__global__ void masked_groupedquery_attention_kernel(GroupedQuery_attention_params params) -{ - using Tk = typename kernel_type_t::Type; -#ifdef ENABLE_FP8 - // FP8 MHA Scales - constexpr bool FP8_MHA_KERNEL = std::is_same::value; -#else - constexpr bool FP8_MHA_KERNEL = false; -#endif - // Make sure the hidden dimension per head is a multiple of the number of threads per key. - static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); - // Make sure the hidden dimension per head is a multiple of the number of threads per value. - static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); - - // The size of a warp. - constexpr int WARP_SIZE = 32; - // The number of warps in a threadblock. - constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; - - // Use smem_size_in_bytes (above) to determine the amount of shared memory. - extern __shared__ char smem_[]; - - // The shared memory for the Q*K^T values and partial logits in softmax. - float* qk_smem = reinterpret_cast(smem_); - - // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. - char* logits_smem_ = smem_; -#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS - if (sizeof(Tk) != 4) { - // TODO - change to tlength - const int max_timesteps = min(params.timestep, params.memory_max_len); - logits_smem_ += div_up(max_timesteps + 1, 4) * 16; - } - Tk* logits_smem = reinterpret_cast(logits_smem_); -#else - float* logits_smem = reinterpret_cast(logits_smem_); -#endif - - // The shared memory to do the final reduction for the output values. Reuse qk_smem. - Tk* out_smem = reinterpret_cast(smem_); - - // The shared memory buffers for the block-wide reductions. One for max, one for sum. - __shared__ float red_smem[WARPS_PER_BLOCK * 2]; - - // A vector of Q or K elements for the current timestep. - using Qk_vec_k = typename Qk_vec_k_::Type; // with kernel-used precision - using Qk_vec_m = typename Qk_vec_m_::Type; // with memory-used precision - - // Use alignment for safely casting the shared buffers as Qk_vec_k. - // Shared memory to store Q inputs. - __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX]; - - // The number of elements per vector. - constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); - // We will use block wide reduction if needed - // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); - // The number of vectors per warp. - constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; - - // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread - // owns x elements, we have to decompose the linear index into chunks of x values and the posi- - // tion of the thread in that chunk. - - // The number of elements in a chunk of 16B (that's the x in the above formula). - constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); - // The number of K vectors in 16B. 
- constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m); - - // The batch/beam idx - const int bi = blockIdx.y; - if (params.finished != nullptr && params.finished[bi] == true) { - return; - } - // The beam idx - const int beami = bi % params.beam_width; - // The "beam-aware" batch idx - const int bbi = bi / params.beam_width; - const int head_n_rep = params.num_heads / params.num_kv_heads; - // const int head_n_rep = 1; - // The head. - const int hi = blockIdx.x; - const int kvhi = hi / head_n_rep; - // Combine the batch and the head indices. - const int bhi = bi * params.num_heads + hi; - const int bkvhi = bi * params.num_kv_heads + kvhi; - // Combine the "beam-aware" batch idx and the head indices. - const int bbhi = bbi * params.beam_width * params.num_heads + hi; - const int bbkvhi = bbi * params.beam_width * params.num_kv_heads + kvhi; - // The thread in the block. - const int tidx = threadIdx.x; - - constexpr bool handle_kv = true; - - // here. - - // While doing the product Q*K^T for the different keys we track the max. - float qk_max = -FLT_MAX; - - float qk = 0.0F; - - int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; - - const size_t bi_seq_len_offset = bi * params.memory_max_len; - - int tlength = (params.length_per_sample == nullptr) ? - params.timestep : - params.length_per_sample[bi] + params.max_prefix_prompt_length; - const int first_step = max(0, tlength + 1 - params.memory_max_len); - const int tlength_circ = tlength % params.memory_max_len; - - // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. - const bool is_masked = tidx >= QK_VECS_PER_WARP; - - // The offset in the Q and K buffer also accounts for the batch. - int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; - // The offset in the bias buffer. - int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; - - const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; - const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; - - // Trigger the loads from the Q and K buffers. - Qk_vec_k q; - zero(q); - if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto q_scaling = params.qkv_scale_out[0]; - const auto q_quant = - *reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]); - - convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); - } - else { - q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); - } - } - - Qk_vec_k k; - zero(k); - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto k_scaling = params.qkv_scale_out[1]; - const auto k_quant = - *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); - - convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); - } - else { - k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? - vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : - k; - } - - // Trigger the loads from the Q and K bias buffers. - Qk_vec_k q_bias; - zero(q_bias); - q_bias = - (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? - vec_conversion(*reinterpret_cast(¶ms.q_bias[qk_bias_offset])) : - q_bias; - - Qk_vec_k k_bias; - zero(k_bias); - if (handle_kv) { - k_bias = - !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? 
- vec_conversion(*reinterpret_cast(¶ms.k_bias[qk_bias_offset])) : - k_bias; - } - - // Computes the Q/K values with bias. - q = add(q, q_bias); - if (handle_kv) { - k = add(k, k_bias); - } - if (do_ia3 && !is_masked) { - k = mul( - k, - vec_conversion(*reinterpret_cast( - ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]))); - } - - // Padded len - const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; - if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { - if (handle_kv) { - apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); - } - else { - apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); - } - } - else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { - const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; - - T* q_smem = reinterpret_cast(smem_); - T* k_smem = q_smem + params.rotary_embedding_dim; - - const int half_rotary_dim = params.rotary_embedding_dim / 2; - const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; - const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; - const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts - - assert(half_rotary_dim % QK_VEC_SIZE == 0); - - if (do_rotary) { - *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; - - if (handle_kv) { - *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; - } - } - - __syncthreads(); - - const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; - constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; - if (do_rotary) { - mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - - if (handle_kv) { - mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - - mmha::apply_rotary_embedding( - q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len); - - mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); - } - else { - mmha::apply_rotary_embedding( - q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep); - } - mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); - } - - __syncthreads(); - - if (do_rotary) { - q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); - if (handle_kv) { - k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); - } - } - - __syncthreads(); - } - - if (!is_masked) { - // Store the Q values to shared memory. - *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; - - // Write the K values to the global memory cache. - // - // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory - // system. We designed it this way as it allows much better memory loads (and there are many - // more loads) + the stores are really "write and forget" since we won't need the ack before - // the end of the kernel. There's plenty of time for the transactions to complete. - - // The 16B chunk written by the thread. - int co = tidx / QK_VECS_IN_16B; - // The position of the thread in that 16B chunk. - int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; - - // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. 
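
As a reference for the rotary step applied to q and k above, here is a minimal scalar sketch of rotating one (even, odd) element pair. The base of 10000 is an assumption borrowed from the usual LLaMA configuration, and the kernel's own apply_rotary_embedding works on packed vector types rather than single floats.

#include <cmath>
#include <cstdio>

// Rotate one (even, odd) element pair of q or k by the rotary angle for this position.
void rotary_pair(float& x0, float& x1, int pair_idx, int rotary_dim, int timestep,
                 float base = 10000.0f)
{
    const float inv_freq = std::pow(base, -2.0f * pair_idx / rotary_dim);
    const float c  = std::cos(timestep * inv_freq);
    const float s  = std::sin(timestep * inv_freq);
    const float y0 = x0 * c - x1 * s;
    const float y1 = x0 * s + x1 * c;
    x0 = y0;
    x1 = y1;
}

int main()
{
    float q0 = 1.0f, q1 = 0.0f;
    rotary_pair(q0, q1, /*pair_idx=*/0, /*rotary_dim=*/128, /*timestep=*/3);
    std::printf("rotated pair: %f %f\n", q0, q1);   // cos(3), sin(3) for pair 0
    return 0;
}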
- // int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // // params.timestep*QK_ELTS_IN_16B + - // tlength_circ * QK_ELTS_IN_16B + ci; - int offset = bkvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + - // params.timestep*QK_ELTS_IN_16B + - tlength_circ * QK_ELTS_IN_16B + ci; - - if (handle_kv && bhi%head_n_rep==0) { - // Trigger the stores to global memory. - if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { - *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); - } - } - - // Compute \sum_i Q[i] * K^T[i] for the current timestep. -#ifdef MMHA_USE_FP32_ACUM_FOR_FMA - using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; -#else - using Qk_vec_acum = Qk_vec_k; -#endif - qk = dot(q, k); - if (QK_VECS_PER_WARP <= WARP_SIZE) { -#pragma unroll - for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); - } - } - } - - if (QK_VECS_PER_WARP > WARP_SIZE) { - constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; - qk = block_sum(&red_smem[WARPS_PER_RED], qk); - } - - // Store that value in shared memory. Keep the Q*K^T value in register for softmax. - if (tidx == 0) { - // Normalize qk. - qk *= params.inv_sqrt_dh; - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + (tlength - padd_len) * params.relative_attention_bias_stride - + (tlength - padd_len)]); - } - // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. - - qk_max = qk; - qk_smem[tlength - first_step] = qk; - // qk_smem[params.timestep] = qk; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The type of queries and keys for the math in the Q*K^T product. - using K_vec_k = typename K_vec_k_::Type; - using K_vec_m = typename K_vec_m_::Type; - // The number of elements per vector. - constexpr int K_VEC_SIZE = sizeof(K_vec_m) / sizeof(T); - // Make sure the hidden size per head is a multiple of the vector size. - static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); - // The number of elements per thread. - constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; - // The number of vectors per thread. - constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; - - // The position the first key loaded by each thread from the cache buffer (for this B * H). - int ko = tidx / THREADS_PER_KEY; - // The position of the thread in the chunk of keys. - int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; - - static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); - - // Load the Q values from shared memory. The values are reused during the loop on K. - K_vec_k q_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); - } - - // The number of timesteps loaded per iteration. - constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; - // The number of keys per warp. - constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; - - // The base pointer for the key in the cache buffer. 
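
The grouped-query bookkeeping above boils down to a simple mapping: head_n_rep = num_heads / num_kv_heads query heads share one KV head, and only one query head per group stores into the K/V cache (the bhi % head_n_rep == 0 test is equivalent to hi % head_n_rep == 0, since bhi = bi * num_heads + hi and num_heads is a multiple of head_n_rep). A standalone sketch with illustrative head counts:

#include <cassert>
#include <cstdio>

int main()
{
    const int num_heads    = 32;   // example values, not taken from the patch
    const int num_kv_heads = 8;
    const int head_n_rep   = num_heads / num_kv_heads;
    assert(num_heads % head_n_rep == 0);   // why bhi % head_n_rep == hi % head_n_rep

    for (int hi = 0; hi < num_heads; ++hi) {
        const int  kvhi   = hi / head_n_rep;           // KV head this query head reads
        const bool writer = (hi % head_n_rep == 0);    // one cache writer per KV group
        std::printf("q head %2d -> kv head %d%s\n", hi, kvhi, writer ? " (writes cache)" : "");
    }
    return 0;
}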
- // T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* k_cache_batch = ¶ms.k_cache[bbkvhi * params.memory_max_len * Dh + ki]; - - // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). - // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; - int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; - - // prefix prompt length if has - const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; - - // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. - const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; - - for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { - const int ti_circ = ti % params.memory_max_len; - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; - - // The keys loaded from the key cache. - K_vec_k k[K_VECS_PER_THREAD]; - K_vec_k k_vec_zero; - zero(k_vec_zero); -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.memory_max_len + ti_circ; - // if( ti < params.timestep ) { - const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); - if (ti < tlength) { - if (!within_bounds) { - k[ii] = k_vec_zero; - } - else { - if (HAS_BEAMS) { - // const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; - const int beam_offset = beam_indices[ti_circ] * params.num_kv_heads * params.memory_max_len * Dh; - k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); - } - else { - k[ii] = vec_conversion( - (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); - } - } - } - } - - // Perform the dot product and normalize qk. - // - // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! - float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; - - // Store the product to shared memory. There's one qk value per timestep. Update the max. - // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { - if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - if (params.relative_attention_bias != nullptr) { - qk = add(qk, - params.relative_attention_bias[hi * params.relative_attention_bias_stride - * params.relative_attention_bias_stride - + tlength * params.relative_attention_bias_stride + ti]); - } - if (params.linear_bias_slopes != nullptr) { - // Apply the linear position bias: (ki - qi) * slope[hi]. - // The padding token locates between the input context and the generated tokens. - // We need to remove the number of padding tokens in the distance computation. - // ti : 0 1 2 3 4 5 6 7 8 9(tlength) - // token: i i i i p p p o o o where i=input, p=pad, o=output. - // e.g. ti = 2, dist = (9 - 3) - 2 = 4. - int max_context_length = params.max_prefix_prompt_length + params.max_input_length; - float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; - - qk += mul(params.linear_bias_slopes[hi], dist); - } - qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); - qk_smem[ti - first_step] = qk; - } - } - -// Perform the final reduction to compute the max inside each warp. -// -// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the -// group so it's not needed to run the reduction inside the group (again). 
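
The loop above indexes the key cache cyclically: once the sequence is longer than memory_max_len, only the most recent memory_max_len steps are retained and logical step ti lands in slot ti % memory_max_len. A small host sketch of that mapping (capacity and length are made-up values):

#include <algorithm>
#include <cstdio>

int main()
{
    const int memory_max_len = 8;    // illustrative cache capacity
    const int tlength        = 11;   // current logical sequence length

    const int first_step = std::max(0, tlength + 1 - memory_max_len);
    for (int ti = first_step; ti < tlength; ++ti) {
        std::printf("logical step %2d -> cache slot %d\n", ti, ti % memory_max_len);
    }
    return 0;
}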
-#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Decompose the thread index into warp and lane. - const int warp = tidx / WARP_SIZE; - const int lane = tidx % WARP_SIZE; - - // The warp leader writes the max to shared memory. - if (lane == 0) { - red_smem[warp] = qk_max; - } - - // Make sure the products are in shared memory. - __syncthreads(); - - // The warps finalize the reduction. - qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); - } - - // Broadcast to all the threads in the warp. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); - - // Compute the logits and start the sum. - float sum = 0.f; - // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; -#ifdef FP8_MHA - float logit = 0.f; - if (FP8_MHA_KERNEL) { - logit = is_mask ? 0.f : - __expf((qk_smem[ti - first_step] - qk_max) * params.query_weight_output_scale[0] - * params.query_weight_output_scale[0]); - } - else { - logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); - } -#else - float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); -#endif - sum += logit; - qk_smem[ti - first_step] = logit; - } - - // Compute the sum. - sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); - - // Normalize the logits. - float inv_sum = __fdividef(1.f, sum + 1.e-6f); - for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - float logit = qk_smem[ti - first_step] * inv_sum; - convert_from_float(logits_smem[ti - first_step], logit); - } - - // Put Values part below so we leverage __syncthreads - // from the previous step - - // The number of elements per vector. - constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; - // A vector of V elements for the current timestep. - using V_vec_k = typename V_vec_k_::Type; - using V_vec_m = typename V_vec_m_::Type; - - // The value computed by this thread. - int vo = tidx / THREADS_PER_VALUE; - // The hidden dimensions computed by this particular thread. - int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; - // The base pointer for the value in the cache buffer. - // if (bkvhi == 63) { - // printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); - // } - T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; - // Base pointer for the beam's batch, before offsetting with indirection buffer - T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; - - // The number of values processed per iteration of the loop. - constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; - - // One group of threads computes the product(s) for the current timestep. - V_vec_k v_bias; - zero(v_bias); - // if( vo == params.timestep % V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - if (vo == tlength % V_PER_ITER) { - // Trigger the loads from the V bias buffer. - if (params.v_bias != nullptr) { - v_bias = vec_conversion( - *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); - } - } - } - - // From previous, before values, step - // Also make sure the logits are in shared memory. 
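
The reductions above implement the usual numerically stable softmax: take the block-wide maximum of the Q*K^T scores, subtract it before exponentiating, then normalize by the sum with a small epsilon. A host-side sketch with example scores:

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> qk = {1.0f, 2.5f, 0.5f, 3.0f};   // example Q*K^T scores

    float qk_max = -1e30f;
    for (float v : qk) qk_max = std::fmax(qk_max, v);   // block-wide max in the kernel

    float sum = 0.f;
    for (float& v : qk) { v = std::exp(v - qk_max); sum += v; }

    const float inv_sum = 1.f / (sum + 1.e-6f);          // same epsilon guard as above
    for (float& v : qk) v *= inv_sum;

    for (float v : qk) std::printf("%f\n", v);
    return 0;
}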
- __syncthreads(); - - // Values continued -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - using V_vec_acum = typename V_vec_acum_fp32_::Type; -#else - using V_vec_acum = V_vec_k; -#endif - // The partial outputs computed by each thread. - V_vec_acum out; - zero(out); - - // Loop over the timesteps to compute the partial outputs. - // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { - if (Dh == Dh_MAX || vi < Dh) { - - // Separate the ti < memory_max_len and ti > memory_max_len - // to prevent ti % memory_len when ti < memory_len, and - // the compiler cannot optimize the codes automatically. - const int min_length = min(tlength, params.memory_max_len); - for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) { - // Fetch offset based on cache_indir when beam sampling - const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; - // const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; - const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; - // Load the values from the cache. - V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else // MMHA_USE_FP32_ACUM_FOR_LOGITS -#ifdef FP8_MHA - Tk logit; - if (FP8_MHA_KERNEL) { - // NOTE: fake quantization - // logit = vec_conversion(vec_conversion(mul(1.0f / - // params.attention_qk_scale[0], logits_smem[ti]))); - logit = logits_smem[ti - first_step]; - } - else { - logit = logits_smem[ti - first_step]; - } - out = fma(logit, v, out); -#else // FP8_MHA - Tk logit = logits_smem[ti - first_step]; - out = fma(logit, v, out); -#endif // FP8_MHA -#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS - } - for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { - if (ti < params.memory_max_len) { - // handled by previous loop - continue; - } - const int ti_circ = ti % params.memory_max_len; - - // Fetch offset based on cache_indir when beam sampling - const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; - const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; - // Load the values from the cache. - V_vec_k v = vec_conversion( - *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); - // Load the logits from shared memory. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - float logit = logits_smem[ti - first_step]; - out = fma(logit, cast_to_float(v), out); -#else // MMHA_USE_FP32_ACUM_FOR_LOGITS -#ifdef FP8_MHA - Tk logit; - if (FP8_MHA_KERNEL) { - // NOTE: fake quantization - // logit = vec_conversion(vec_conversion(mul(1.0f / - // params.attention_qk_scale[0], logits_smem[ti]))); - logit = logits_smem[ti - first_step]; - } - else { - logit = logits_smem[ti - first_step]; - } - out = fma(logit, v, out); -#else // FP8_MHA - Tk logit = logits_smem[ti - first_step]; - out = fma(logit, v, out); -#endif // FP8_MHA -#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS - } - } - - // One group of threads computes the product(s) for the current timestep. - // if( vo == params.timestep % V_PER_ITER ) { - if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { - - V_vec_k v; - // Trigger the loads from the V buffer. 
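
The two loops above accumulate out = sum over ti of logit[ti] * V[ti]; the split at memory_max_len keeps the modulo out of the common, non-wrapped path. A plain C++ sketch of the same accumulation with made-up logits and values:

#include <cstdio>
#include <vector>

int main()
{
    const int Dh = 4;
    std::vector<float> logits = {0.1f, 0.2f, 0.7f};      // softmax weights per timestep
    std::vector<std::vector<float>> V = {{1, 0, 0, 0},   // cached value rows
                                         {0, 1, 0, 0},
                                         {0, 0, 1, 0}};

    std::vector<float> out(Dh, 0.f);
    for (size_t ti = 0; ti < logits.size(); ++ti)
        for (int d = 0; d < Dh; ++d)
            out[d] += logits[ti] * V[ti][d];             // the fma(logit, v, out) above

    for (float v : out) std::printf("%f ", v);
    std::printf("\n");
    return 0;
}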
- const auto v_offset = qkv_base_offset + vi; - if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - using Packed_Float_t = typename packed_type::value>::type; - const auto v_scaling = params.qkv_scale_out[2]; - const auto v_quant = - *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); - - convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); - } - else { - v = vec_conversion(*reinterpret_cast(¶ms.v[v_offset])); - } - // Trigger the loads from the V bias buffer. - // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); - - // Compute the V values with bias. - v = add(v, v_bias); - - if (do_ia3) { - v = mul( - v, - *reinterpret_cast( - ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); - } - if (bhi % head_n_rep == 0) { - // Store the values with bias back to global memory in the cache for V. - //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; - *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v); - } - - // Initialize the output value with the current timestep. -#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) - // out = fma(logits_smem[params.timestep], cast_to_float(v), out); - out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); -#else // MMHA_USE_FP32_ACUM_FOR_LOGITS - // out = fma(logits_smem[params.timestep], v, out); -#ifdef FP8_MHA - Tk logit; - if (FP8_MHA_KERNEL) { - // NOTE: fake quantization - // logit = mul(1.0f / params.attention_qk_scale[0], logits_smem[tlength]); - logit = logits_smem[tlength - first_step]; - } - else { - logit = logits_smem[tlength - first_step]; - } - out = fma(logit, v, out); -#else // FP8_MHA - out = fma(logits_smem[tlength - first_step], v, out); -#endif // FP8_MHA -#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS - } - - // Make sure we can start writing to shared memory. - __syncthreads(); - - // Run the final reduction amongst the different groups computing different partial outputs. - if (Dh == Dh_MAX || vi < Dh) { -#pragma unroll - for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { - - // The midpoint in the number of active groups. - int midpoint = active_groups / 2; - - // The upper part of active threads store to shared memory. - if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); -#else - *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; -#endif - } - __syncthreads(); - - // The bottom warps update their values. - if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { - out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); - } - __syncthreads(); - } - } - - // Output the final values. 
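
The final reduction above combines the V_PER_ITER partial outputs by repeatedly halving the number of active groups, the upper half adding its partial sums into the lower half through shared memory, with __syncthreads() between the store and add phases. A host-side sketch of that halving schedule (group count and data are illustrative):

#include <cstdio>
#include <vector>

int main()
{
    const int V_PER_ITER = 8;                    // illustrative number of partial-output groups
    std::vector<float> partial(V_PER_ITER);
    for (int g = 0; g < V_PER_ITER; ++g) partial[g] = float(g + 1);   // 1..8, total 36

    for (int active = V_PER_ITER; active >= 2; active /= 2) {
        const int midpoint = active / 2;
        // Upper half folds into the lower half; the kernel separates these steps with barriers.
        for (int g = midpoint; g < active; ++g)
            partial[g - midpoint] += partial[g];
    }
    std::printf("reduced = %f\n", partial[0]);   // 36
    return 0;
}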
- if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { -#ifdef MMHA_USE_FP32_ACUM_FOR_OUT - if (FP8_MHA_KERNEL) { -#ifdef FP8_MHA - // float result_scale = params.attention_qk_scale[0] * params.query_weight_output_scale[0] * - // params.attention_output_weight_input_scale_inv[0]; - float result_scale = - params.query_weight_output_scale[0] * params.attention_output_weight_input_scale_inv[0]; - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), - mul(result_scale, out)); -#endif // FP8_MHA - } - else if (params.int8_mode == 2) { - using Packed_Int8_t = typename packed_type::value>::type; - out = mul(*params.attention_out_scale, out); - *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = - cast_to_int8(out); - } - else { - convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); - } -#else // MMHA_USE_FP32_ACUM_FOR_OUT - // TODO: support int8_mode? - *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = vec_conversion(out); -#endif // MMHA_USE_FP32_ACUM_FOR_OUT - } -} - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace mmha - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct threads_per_value_t { - static const int value = Dh_MAX * sizeof(T) / 16; -}; -#ifdef ENABLE_FP8 -template -struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> { - static const int value = Dh_MAX * 4 / 16; // DEBUG: float v -}; -#endif - -template -void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); From 0edb1335eef0d3fa6dc019c65a8b6062b11b1bfb Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:51:09 -0700 Subject: [PATCH 132/135] commit --- ...coder_masked_groupedquery_attention_128.cu | 89 + ...coder_masked_groupedquery_attention_144.cu | 87 + ...coder_masked_groupedquery_attention_160.cu | 87 + ...coder_masked_groupedquery_attention_192.cu | 87 + ...coder_masked_groupedquery_attention_224.cu | 87 + ...coder_masked_groupedquery_attention_256.cu | 87 + ...ecoder_masked_groupedquery_attention_32.cu | 88 + ...ecoder_masked_groupedquery_attention_48.cu | 87 + ...ecoder_masked_groupedquery_attention_64.cu | 87 + ...ecoder_masked_groupedquery_attention_80.cu | 87 + ...ecoder_masked_groupedquery_attention_96.cu | 87 + ...masked_groupedquery_attention_template.hpp | 1878 +++++++++++++++++ 12 files changed, 2838 insertions(+) create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu create mode 
100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu create mode 100644 src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu new file mode 100644 index 000000000..24d8a3e91 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 128, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 128, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu new file mode 100644 index 000000000..350499c5b --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 144, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 144, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu new file mode 100644 index 000000000..8f392cf77 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 160, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 160, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu new file mode 100644 index 000000000..ff68387ea --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 192, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 192, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu new file mode 100644 index 000000000..ea66caf2d --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu 
@@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 224, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 224, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu new file mode 100644 index 000000000..12876a48e --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 256, 256, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 256, 256, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu new file mode 100644 index 000000000..3f877d3dc --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + //constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + constexpr bool DO_CROSS_ATTENTION = false; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 32, 32, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 32, 32, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu new file mode 100644 index 000000000..243886074 --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 48, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 48, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu new file mode 100644 index 000000000..b30030d71 --- /dev/null +++ 
b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 64, 64, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 64, 64, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu new file mode 100644 index 000000000..c34a8546e --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! 
Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 80, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 80, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu new file mode 100644 index 000000000..ac3b6369b --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu @@ -0,0 +1,87 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MGQA_LAUNCH_KERNEL( \ + T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + dim3 grid(params.num_heads, params.batch_size); \ + mmha::masked_groupedquery_attention_kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = threads_per_value_t::value; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + if (params.cache_indir == nullptr) { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream); + } + } + else { + if (tlength < 32) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, true, stream); + } + else if (tlength < 2048) { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, true, stream); + } + else { + MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, true, stream); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +template void mgqa_launch_kernel>( + const GroupedQuery_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +template void mgqa_launch_kernel<__nv_bfloat16, 96, 128, GroupedQuery_attention_params<__nv_bfloat16>>( + const GroupedQuery_attention_params<__nv_bfloat16>& params, const cudaStream_t& stream); +#endif +#ifdef ENABLE_FP8 +template void mgqa_launch_kernel<__nv_fp8_e4m3, 96, 128, GroupedQuery_attention_params<__nv_fp8_e4m3>>( + const GroupedQuery_attention_params<__nv_fp8_e4m3>& params, const cudaStream_t& stream); +#endif + +#undef MGQA_LAUNCH_KERNEL diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp new file mode 100644 index 000000000..581d566ca --- /dev/null +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp @@ -0,0 +1,1878 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" +#include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" +#include "src/fastertransformer/utils/cuda_bf16_wrapper.h" +#include "src/fastertransformer/utils/cuda_fp8_utils.h" +#include "src/fastertransformer/utils/cuda_type_utils.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +// #define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. 
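+//
+// A worked example of the two layouts above (illustrative only, assuming the row-major packing
+// just described; for the caches in this grouped-query variant, H is the number of KV heads):
+//
+//   with x = 16 / sizeof(T), element (b, h, d, l) of the K-cache [B, H, Dh/x, L, x] lives at
+//       (((b * H + h) * (Dh / x) + d / x) * L + l) * x + (d % x)
+//   and element (b, h, l, d) of the V-cache [B, H, L, Dh] lives at
+//       ((b * H + h) * L + l) * Dh + d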
+// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_m_ { +}; + +template<> +struct Qk_vec_m_ { + using Type = float; +}; +template<> +struct Qk_vec_m_ { + using Type = float2; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = float4; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_m_ { + using Type = uint2; +}; +template<> +struct Qk_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_m_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_m_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 32> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 64> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 128> { + using Type = fp8_4_t; +}; +template<> +struct Qk_vec_m_<__nv_fp8_e4m3, 256> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_k_ { + using Type = typename Qk_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 32> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 64> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 128> { + using Type = float4; +}; +template<> +struct Qk_vec_k_<__nv_fp8_e4m3, 256> { + using Type = float4; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_m_ { +}; + +template<> +struct K_vec_m_ { + using Type = float; +}; +template<> +struct K_vec_m_ { + using Type = float2; +}; +template<> +struct K_vec_m_ { + using Type = float4; +}; +template<> +struct K_vec_m_ { + using Type = uint32_t; +}; +template<> +struct K_vec_m_ { + using Type = uint2; +}; +template<> +struct K_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_m_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_m_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 + +// NOTE: THREADS_PER_KEY * sizeof(K_vec_m_) = 128 bytes +#ifdef ENABLE_FP8 +template<> +struct K_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct K_vec_m_<__nv_fp8_e4m3, 2> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_m_<__nv_fp8_e4m3, 1> { + using Type = fp8_4_t; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_k_ { + using Type = typename K_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct K_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct K_vec_k_<__nv_fp8_e4m3, 2> { + using Type = float4; +}; // Defined for compilation-purpose only, do not use +template<> +struct K_vec_k_<__nv_fp8_e4m3, 1> { + 
using Type = float4; +}; // Defined for compilation-purpose only, do not use +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_m_ { +}; + +template<> +struct V_vec_m_ { + using Type = float; +}; +template<> +struct V_vec_m_ { + using Type = float2; +}; +template<> +struct V_vec_m_ { + using Type = float4; +}; +template<> +struct V_vec_m_ { + using Type = uint32_t; +}; +template<> +struct V_vec_m_ { + using Type = uint2; +}; +template<> +struct V_vec_m_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_m_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_m_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +template<> +struct V_vec_m_<__nv_fp8_e4m3, 4> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 8> { + using Type = fp8_4_t; +}; +template<> +struct V_vec_m_<__nv_fp8_e4m3, 16> { + using Type = fp8_4_t; +}; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_k_ { + using Type = typename V_vec_m_::Type; +}; +#ifdef ENABLE_FP8 +template<> +struct V_vec_k_<__nv_fp8_e4m3, 4> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 8> { + using Type = float4; +}; +template<> +struct V_vec_k_<__nv_fp8_e4m3, 16> { + using Type = float4; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct Qk_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + 
using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_FP8 +// template<> +// struct K_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct K_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif // MMHA_USE_FP32_ACUM_FOR_FMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#ifdef ENABLE_FP8 +// template<> +// struct V_vec_acum_fp32_ { +// using Type = float2; +// }; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +// template<> +// struct V_vec_acum_fp32_ { +// using Type = Float4_; +// }; +#endif // ENABLE_FP8 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) +{ + return x; +} +#ifdef ENABLE_FP8 +// fp8_t +template<> +__inline__ __device__ float vec_conversion(const __nv_fp8_e4m3& a) +{ + return float(a); +} +template<> +__inline__ __device__ __nv_fp8_e4m3 vec_conversion<__nv_fp8_e4m3, float>(const float& a) +{ + return __nv_fp8_e4m3(a); +} +// fp8_2_t +template<> +__inline__ __device__ float2 vec_conversion(const fp8_2_t& a) +{ + return float2(a); +} +template<> +__inline__ __device__ fp8_2_t vec_conversion(const float2& a) +{ + return fp8_2_t(a); +} +// fp8_4_t +template<> +__inline__ __device__ float4 vec_conversion(const fp8_4_t& a) +{ + return float4(a); +} +template<> +__inline__ __device__ fp8_4_t vec_conversion(const float4& a) +{ + return fp8_4_t(a); +} +#endif // ENABLE_FP8 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. 
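+    // (For example, with THREADS_PER_KEY = 4 the loop below runs with mask = 2 and then 1, so the
+    // four lanes of a key-group exchange partial sums in a butterfly pattern and every lane ends
+    // up holding the full Q*K^T value for its timestep.)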
+ float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x), "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. 
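+    // (block_sum is a two-level reduction: shfl_xor within each warp, one partial per warp staged
+    // in red_smem, a final shuffle among the first WARPS_PER_BLOCK lanes, and then lane 0's total
+    // is broadcast below so every thread of the block returns the same sum.)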
+ return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+#ifdef ENABLE_FP8 +inline __device__ void convert_from_float(fp8_4_t& dst, float4 src) +{ + dst = fp8_4_t(src); +} +inline __device__ void convert_from_float(fp8_2_t& dst, float2 src) +{ + dst = fp8_2_t(src); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off 
+inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct kernel_type_t { + using Type = T; +}; + +#ifdef ENABLE_FP8 +template<> +struct kernel_type_t<__nv_fp8_e4m3> { + using Type = float; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const GroupedQuery_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + using Tk = typename kernel_type_t::Type; + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TDOD + logits_sz = div_up(max_timesteps + 1, 4) * 4 * sizeof(Tk); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(Tk) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk); + } + + // The max. + return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? 
uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool HAS_BEAMS> +__global__ void masked_groupedquery_attention_kernel(GroupedQuery_attention_params params) +{ + using Tk = typename kernel_type_t::Type; +#ifdef ENABLE_FP8 + // FP8 MHA Scales + constexpr bool FP8_MHA_KERNEL = std::is_same::value; +#else + constexpr bool FP8_MHA_KERNEL = false; +#endif + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(Tk) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += div_up(max_timesteps + 1, 4) * 16; + } + Tk* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + Tk* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec_k = typename Qk_vec_k_::Type; // with kernel-used precision + using Qk_vec_m = typename Qk_vec_m_::Type; // with memory-used precision + + // Use alignment for safely casting the shared buffers as Qk_vec_k. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec_k)) Tk q_smem[Dh_MAX]; + + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8/16 for FP32/FP16/FP8. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. 
+ constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec_m); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + const int head_n_rep = params.num_heads / params.num_kv_heads; + // const int head_n_rep = 1; + // The head. + const int hi = blockIdx.x; + const int kvhi = hi / head_n_rep; + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + const int bkvhi = bi * params.num_kv_heads + kvhi; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbkvhi = bbi * params.beam_width * params.num_kv_heads + kvhi; + // The thread in the block. + const int tidx = threadIdx.x; + + constexpr bool handle_kv = true; + + // here. + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + int tlength = (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + // The offset in the bias buffer. + int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + + const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; + const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + + // Trigger the loads from the Q and K buffers. + Qk_vec_k q; + zero(q); + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto q_scaling = params.qkv_scale_out[0]; + const auto q_quant = + *reinterpret_cast(&reinterpret_cast(params.q)[qk_offset]); + + convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); + } + else { + q = vec_conversion(*reinterpret_cast(¶ms.q[qk_offset])); + } + } + + Qk_vec_k k; + zero(k); + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[qk_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); + } + else { + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + vec_conversion(*reinterpret_cast(¶ms.k[qk_offset])) : + k; + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec_k q_bias; + zero(q_bias); + q_bias = + (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + vec_conversion(*reinterpret_cast(¶ms.q_bias[qk_bias_offset])) : + q_bias; + + Qk_vec_k k_bias; + zero(k_bias); + if (handle_kv) { + k_bias = + !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? 
+ vec_conversion(*reinterpret_cast(¶ms.k_bias[qk_bias_offset])) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (handle_kv) { + k = add(k, k_bias); + } + if (do_ia3 && !is_masked) { + k = mul( + k, + vec_conversion(*reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]))); + } + + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_kv) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.timestep - padd_len); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_kv) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_kv) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep - padd_len); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, params.timestep); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_kv) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. 
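+        // (In other words, the offset computed below indexes the [B, num_kv_heads, Dh/x, L, x]
+        //  K-cache: a head slab of memory_max_len * Dh elements selected by bkvhi, the 16B chunk
+        //  co of size memory_max_len * QK_ELTS_IN_16B, the circular timestep tlength_circ, and
+        //  the intra-chunk position ci.)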
+ // int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // // params.timestep*QK_ELTS_IN_16B + + // tlength_circ * QK_ELTS_IN_16B + ci; + int offset = bkvhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (handle_kv && bhi%head_n_rep==0) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = vec_conversion(k); + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec_k; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. + + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec_k = typename K_vec_k_::Type; + using K_vec_m = typename K_vec_m_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec_m) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec_k q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. 
+ // T* k_cache = ¶ms.k_cache[bkvhi * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbkvhi * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const int* beam_indices = HAS_BEAMS ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // The keys loaded from the key cache. + K_vec_k k[K_VECS_PER_THREAD]; + K_vec_k k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (HAS_BEAMS) { + // const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + const int beam_offset = beam_indices[ti_circ] * params.num_kv_heads * params.memory_max_len * Dh; + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); + } + else { + k[ii] = vec_conversion( + (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + if (params.linear_bias_slopes != nullptr) { + // Apply the linear position bias: (ki - qi) * slope[hi]. + // The padding token locates between the input context and the generated tokens. + // We need to remove the number of padding tokens in the distance computation. + // ti : 0 1 2 3 4 5 6 7 8 9(tlength) + // token: i i i i p p p o o o where i=input, p=pad, o=output. + // e.g. ti = 2, dist = (9 - 3) - 2 = 4. + int max_context_length = params.max_prefix_prompt_length + params.max_input_length; + float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; + + qk += mul(params.linear_bias_slopes[hi], dist); + } + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). 
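+// (For example, with THREADS_PER_KEY = 4 and WARP_SIZE = 32 the loop below uses mask = 16, 8, 4:
+//  that is enough to combine the per-group maxima already held by the group leaders, while masks
+//  2 and 1 are skipped because lanes inside a key-group need no further reduction.)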
+#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; +#ifdef FP8_MHA + float logit = 0.f; + if (FP8_MHA_KERNEL) { + logit = is_mask ? 0.f : + __expf((qk_smem[ti - first_step] - qk_max) * params.query_weight_output_scale[0] + * params.query_weight_output_scale[0]); + } + else { + logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + } +#else + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); +#endif + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec_k = typename V_vec_k_::Type; + using V_vec_m = typename V_vec_m_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + // The base pointer for the value in the cache buffer. + // if (bkvhi == 63) { + // printf("%d %d %d %d %d\n", bkvhi, params.memory_max_len, Dh, vi, (bkvhi * params.memory_max_len * Dh + vi)); + // } + T* v_cache = ¶ms.v_cache[bkvhi * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbkvhi * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec_k v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = vec_conversion( + *reinterpret_cast(¶ms.v_bias[hi * Dh + vi])); + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. 
+ __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec_k; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + + // Separate the ti < memory_max_len and ti > memory_max_len + // to prevent ti % memory_len when ti < memory_len, and + // the compiler cannot optimize the codes automatically. + const int min_length = min(tlength, params.memory_max_len); + for (int ti = first_step + vo; ti < min_length; ti += V_PER_ITER) { + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti] : 0; + // const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti * Dh])); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + if (ti < params.memory_max_len) { + // handled by previous loop + continue; + } + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = HAS_BEAMS ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = HAS_BEAMS ? beam_src * params.num_kv_heads * params.memory_max_len * Dh : 0; + // Load the values from the cache. + V_vec_k v = vec_conversion( + *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh])); + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else // MMHA_USE_FP32_ACUM_FOR_LOGITS +#ifdef FP8_MHA + Tk logit; + if (FP8_MHA_KERNEL) { + // NOTE: fake quantization + // logit = vec_conversion(vec_conversion(mul(1.0f / + // params.attention_qk_scale[0], logits_smem[ti]))); + logit = logits_smem[ti - first_step]; + } + else { + logit = logits_smem[ti - first_step]; + } + out = fma(logit, v, out); +#else // FP8_MHA + Tk logit = logits_smem[ti - first_step]; + out = fma(logit, v, out); +#endif // FP8_MHA +#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS + } + } + + // One group of threads computes the product(s) for the current timestep. + // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec_k v; + // Trigger the loads from the V buffer. 
+        const auto v_offset = qkv_base_offset + vi;
+        if (params.int8_mode == 2) {
+            using Packed_Int8_t = typename packed_type::value>::type;
+            using Packed_Float_t = typename packed_type::value>::type;
+            const auto v_scaling = params.qkv_scale_out[2];
+            const auto v_quant =
+                *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]);
+
+            convert_from_float(v, mul(v_scaling, float_from_int8(v_quant)));
+        }
+        else {
+            v = vec_conversion(*reinterpret_cast(&params.v[v_offset]));
+        }
+        // Trigger the loads from the V bias buffer.
+        // V_vec v_bias = *reinterpret_cast(&params.v_bias[hi*Dh + vi]);
+
+        // Compute the V values with bias.
+        v = add(v, v_bias);
+
+        if (do_ia3) {
+            v = mul(
+                v,
+                *reinterpret_cast(
+                    &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
+        }
+        if (bhi % head_n_rep == 0) {
+            // Store the values with bias back to global memory in the cache for V.
+            //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v;
+            *reinterpret_cast(&v_cache[tlength_circ * Dh]) = vec_conversion(v);
+        }
+
+        // Initialize the output value with the current timestep.
+#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
+        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
+        out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
+#else // MMHA_USE_FP32_ACUM_FOR_LOGITS
+        // out = fma(logits_smem[params.timestep], v, out);
+#ifdef FP8_MHA
+        Tk logit;
+        if (FP8_MHA_KERNEL) {
+            // NOTE: fake quantization
+            // logit = mul(1.0f / params.attention_qk_scale[0], logits_smem[tlength]);
+            logit = logits_smem[tlength - first_step];
+        }
+        else {
+            logit = logits_smem[tlength - first_step];
+        }
+        out = fma(logit, v, out);
+#else // FP8_MHA
+        out = fma(logits_smem[tlength - first_step], v, out);
+#endif // FP8_MHA
+#endif // MMHA_USE_FP32_ACUM_FOR_LOGITS
+    }
+
+    // Make sure we can start writing to shared memory.
+    __syncthreads();
+
+    // Run the final reduction amongst the different groups computing different partial outputs.
+    if (Dh == Dh_MAX || vi < Dh) {
+#pragma unroll
+        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {
+
+            // The midpoint in the number of active groups.
+            int midpoint = active_groups / 2;
+
+            // The upper part of active threads store to shared memory.
+            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+                convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out);
+#else
+                *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out;
+#endif
+            }
+            __syncthreads();
+
+            // The bottom warps update their values.
+            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
+                out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out);
+            }
+            __syncthreads();
+        }
+    }
+
+    // Output the final values.
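+    // After the tree reduction above, the vo == 0 thread group holds the fully accumulated value
+    // vectors; each query head writes its own slice at bhi * Dh + vi, so outputs stay per query head
+    // even though groups of query heads shared a single KV-cache head for the reads.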
+    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
+#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
+        if (FP8_MHA_KERNEL) {
+#ifdef FP8_MHA
+            // float result_scale = params.attention_qk_scale[0] * params.query_weight_output_scale[0] *
+            // params.attention_output_weight_input_scale_inv[0];
+            float result_scale =
+                params.query_weight_output_scale[0] * params.attention_output_weight_input_scale_inv[0];
+            convert_from_float(*reinterpret_cast(&params.out[bhi * Dh + vi]),
+                               mul(result_scale, out));
+#endif // FP8_MHA
+        }
+        else if (params.int8_mode == 2) {
+            using Packed_Int8_t = typename packed_type::value>::type;
+            out = mul(*params.attention_out_scale, out);
+            *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) =
+                cast_to_int8(out);
+        }
+        else {
+            convert_from_float(*reinterpret_cast(&params.out[bhi * Dh + vi]), out);
+        }
+#else // MMHA_USE_FP32_ACUM_FOR_OUT
+        // TODO: support int8_mode?
+        *reinterpret_cast(&params.out[bhi * Dh + vi]) = vec_conversion(out);
+#endif // MMHA_USE_FP32_ACUM_FOR_OUT
+    }
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace mmha
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct threads_per_value_t {
+    static const int value = Dh_MAX * sizeof(T) / 16;
+};
+#ifdef ENABLE_FP8
+template
+struct threads_per_value_t<__nv_fp8_e4m3, Dh_MAX> {
+    static const int value = Dh_MAX * 4 / 16; // DEBUG: float v
+};
+#endif
+
+template
+void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);

From 56ad958309b5f4b734e6d216b5dbdd198fb09c92 Mon Sep 17 00:00:00 2001
From: sfc-gh-zhwang
Date: Mon, 11 Sep 2023 00:53:11 -0700
Subject: [PATCH 133/135] commit

---
 .../decoder_masked_groupedquery_attention_128.cu | 2 +-
 .../decoder_masked_groupedquery_attention_144.cu | 2 +-
 .../decoder_masked_groupedquery_attention_160.cu | 2 +-
 .../decoder_masked_groupedquery_attention_192.cu | 2 +-
 .../decoder_masked_groupedquery_attention_224.cu | 2 +-
 .../decoder_masked_groupedquery_attention_256.cu | 2 +-
 .../decoder_masked_groupedquery_attention_32.cu | 2 +-
 .../decoder_masked_groupedquery_attention_48.cu | 2 +-
 .../decoder_masked_groupedquery_attention_64.cu | 2 +-
 .../decoder_masked_groupedquery_attention_80.cu | 2 +-
 .../decoder_masked_groupedquery_attention_96.cu | 2 +-
 11 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu
index 24d8a3e91..9f9f7ca3f 100644
--- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu
+++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu
@@ -14,7 +14,7 @@
 * limitations under the License.
*/ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu index 350499c5b..6da6da083 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu index 8f392cf77..bde08b41d 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu index ff68387ea..7fa77808f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu index ea66caf2d..8fdf2e1a5 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu index 12876a48e..359bd9214 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu index 3f877d3dc..827efd738 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu index 243886074..cb7abfbcc 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu index b30030d71..4f3105526 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu index c34a8546e..81645f4fd 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu index ac3b6369b..c8a978952 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "decoder_masked_multihead_attention_template.hpp" +#include "decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" From 7caf88b79f075eeec47f213b283f1757db7b63b9 Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 00:53:40 -0700 Subject: [PATCH 134/135] commit --- .../kernels/llama/decoder_masked_groupedquery_attention.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu index ec6bb68e8..1ec0b3d53 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.cu @@ -15,7 +15,7 @@ */ #include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention.h" -#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_multihead_attention_template.hpp" +#include "src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_template.hpp" #include "src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h" #include "src/fastertransformer/utils/cuda_bf16_wrapper.h" #include From a4b6dd938d8eeb5b805748cb547460dc9ba97fdc Mon Sep 17 00:00:00 2001 From: sfc-gh-zhwang Date: Mon, 11 Sep 2023 19:22:02 -0700 Subject: [PATCH 135/135] commit --- .vscode/settings.json | 3 ++- .../decoder_masked_groupedquery_attention_128.cu | 5 +---- .../decoder_masked_groupedquery_attention_144.cu | 4 +--- .../decoder_masked_groupedquery_attention_160.cu | 4 +--- .../decoder_masked_groupedquery_attention_192.cu | 4 +--- .../decoder_masked_groupedquery_attention_224.cu | 4 +--- .../decoder_masked_groupedquery_attention_256.cu | 4 +--- .../decoder_masked_groupedquery_attention_32.cu | 5 +---- .../decoder_masked_groupedquery_attention_48.cu | 4 +--- .../decoder_masked_groupedquery_attention_64.cu | 4 +--- .../decoder_masked_groupedquery_attention_80.cu | 4 +--- .../decoder_masked_groupedquery_attention_96.cu | 3 +-- 12 files changed, 13 insertions(+), 35 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 4b17335e2..bb1913955 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -91,6 +91,7 @@ "__string": "cpp", "compare": 
"cpp", "concepts": "cpp", - "filesystem": "cpp" + "filesystem": "cpp", + "__memory": "cpp" } } diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu index 9f9f7ca3f..3c96b2ce5 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_128.cu @@ -38,14 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; - // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength); + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu index 6da6da083..7e20bdccc 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_144.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu index bde08b41d..57c6dd1aa 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_160.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu index 7fa77808f..d8c349cad 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_192.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu index 8fdf2e1a5..03ff2cadd 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_224.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu index 359bd9214..fe496d4a7 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_256.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? 
params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu index 827efd738..ceeb96484 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_32.cu @@ -38,14 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - //constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - constexpr bool DO_CROSS_ATTENTION = false; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu index cb7abfbcc..f225bef82 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_48.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu index 4f3105526..7a9679952 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_64.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! 
Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = false; // std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu index 81645f4fd..8af12155f 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_80.cu @@ -38,13 +38,11 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; - constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream); diff --git a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu index c8a978952..f91209194 100644 --- a/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu +++ b/src/fastertransformer/kernels/llama/decoder_masked_groupedquery_attention/decoder_masked_groupedquery_attention_96.cu @@ -38,13 +38,12 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// -// !!! Specialize the launcher for Cross attention template void mgqa_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) { constexpr int THREADS_PER_VALUE = threads_per_value_t::value; constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; - int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + int tlength = params.timestep; if (params.cache_indir == nullptr) { if (tlength < 32) { MGQA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
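
Note on the grouped-query head mapping these kernels rely on: every group of head_n_rep = num_heads / num_kv_heads query heads appears to share one KV-cache head, which is why the cache pointers above are built from a KV-head index (bkvhi), beam offsets are scaled by params.num_kv_heads, and only one head per group (the bhi % head_n_rep == 0 guard) appends to the cache, while attention outputs remain per query head. The snippet below is a minimal, self-contained sketch of that arithmetic for illustration only; it is not part of the patches, and the head counts (32 query heads, 8 KV heads) are assumed placeholder values rather than anything read from a model config.

#include <cassert>
#include <cstdio>

int main() {
    const int num_heads    = 32;  // query heads (assumed example value)
    const int num_kv_heads = 8;   // grouped KV heads (assumed example value)
    assert(num_heads % num_kv_heads == 0);
    const int head_n_rep = num_heads / num_kv_heads;

    for (int hi = 0; hi < num_heads; ++hi) {
        const int  kv_hi        = hi / head_n_rep;         // KV-cache head read by query head hi
        const bool writes_cache = (hi % head_n_rep == 0);  // mirrors the bhi % head_n_rep == 0 guard
        printf("query head %2d -> kv head %d%s\n", hi, kv_hi, writes_cache ? " (cache writer)" : "");
    }
    return 0;
}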