diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu
index fe3291d6eb..321f02c951 100644
--- a/custom_ops/gpu_ops/append_attention.cu
+++ b/custom_ops/gpu_ops/append_attention.cu
@@ -73,6 +73,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const paddle::optional<paddle::Tensor>& out_linear_shifts,
     const paddle::optional<paddle::Tensor>& out_linear_smooths,
     const paddle::optional<paddle::Tensor>& kv_signal_data,
+    paddle::Tensor& fmha_out,
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const bool rope_3d,
@@ -118,27 +119,6 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   } else {
     qkv_out = qkv;
   }
-  paddle::Tensor fmha_out;
-  if (out_linear_in_scale > 0.0) {
-    if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-          paddle::DataType::INT8,
-          qkv.place());
-    } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-      fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-          paddle::DataType::FLOAT8_E4M3FN,
-          qkv.place());
-    }else{
-      PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
-    }
-  } else {
-    fmha_out = GetEmptyTensor(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
-        D,
-        qkv.place());
-  }
 
   auto dispatch_CascadeAppendAttentionKernel = [&](auto temp_args,
                                                    const paddle::Tensor& lambda_batch_ids,
@@ -393,7 +373,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     }
   }
 
-  return {fmha_out, qkv_out};
+  return {fmha_out};
 }
 
 std::vector<paddle::Tensor> AppendAttention(
@@ -464,6 +444,53 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.block_size = key_cache.dims()[2];
   meta_data.batch_size = cum_offsets.dims()[0];
+  // Resolve the dtype used for template dispatch and for allocating fmha_out.
+  phi::DataType dtype_id;
+  switch (qkv.dtype()) {
+    case paddle::DataType::FLOAT16: {dtype_id = phi::DataType::FLOAT16; break;}
+    case paddle::DataType::BFLOAT16: {dtype_id = phi::DataType::BFLOAT16; break;}
+    case paddle::DataType::INT32: {
+      if (compute_dtype == "bf16") {
+        dtype_id = phi::DataType::BFLOAT16;
+        break;
+      } else if (compute_dtype == "fp16") {
+        dtype_id = phi::DataType::FLOAT16;
+        break;
+      } else {
+        PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
+        break;
+      }
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16 and bfloat16 are supported. ");
+      break;
+    }
+  }
+  // Allocate fmha_out up front so the kernel writes into it directly.
+  paddle::Tensor fmha_out;
+  if (out_linear_in_scale > 0.0) {
+    if (fabs(quant_max_bound - 127.0f) < 0.000001) {
+      fmha_out = GetEmptyTensor(
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          paddle::DataType::INT8,
+          qkv.place());
+    } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
+      fmha_out = GetEmptyTensor(
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          paddle::DataType::FLOAT8_E4M3FN,
+          qkv.place());
+    } else {
+      PD_THROW("Only supported attr of quant_max_bound in ['127', '448'].");
+    }
+  } else {
+    fmha_out = GetEmptyTensor(
+        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+        dtype_id,
+        qkv.place());
+  }
+
   auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
     return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
         meta_data,
         qkv,
@@ -500,6 +527,143 @@ std::vector<paddle::Tensor> AppendAttention(
         out_linear_shifts,
         out_linear_smooths,
         kv_signal_data,
+        fmha_out,
+        cache_quant_type_str,
+        use_neox_rotary_style,
+        rope_3d,
+        max_input_length,
+        quant_max_bound,
+        quant_min_bound,
+        out_linear_in_scale,
+        encoder_block_shape_q,
+        decoder_block_shape_q,
+        max_partition_size,
+        encoder_max_partition_size,
+        speculate_max_draft_token_num,
+        causal,
+        speculate_decoder);
+  };
+
+  phi::dtype::float16 fp16_dtype;
+  phi::dtype::bfloat16 bp16_dtype;
+  switch (dtype_id) {
+    case phi::DataType::FLOAT16: {
+      return dispatch_by_template(fp16_dtype);
+    }
+    case phi::DataType::BFLOAT16: {
+      return dispatch_by_template(bp16_dtype);
+    }
+    default:
+      break;
+  }
+  return {paddle::Tensor()};  // Unreachable: dtype_id is always FLOAT16 or BFLOAT16 here.
+}
+
+std::vector<paddle::Tensor> AppendAttentionWithOutput(
+    const paddle::Tensor& qkv,
+    const paddle::Tensor& key_cache,
+    const paddle::Tensor& value_cache,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& block_tables,
+    const paddle::Tensor& encoder_batch_ids,
+    const paddle::Tensor& encoder_tile_ids_per_batch,
+    const paddle::Tensor& encoder_num_blocks,
+    const paddle::Tensor& kv_batch_ids,
+    const paddle::Tensor& kv_tile_ids_per_batch,
+    const paddle::Tensor& kv_num_blocks,
+    const paddle::Tensor& decoder_batch_ids,
+    const paddle::Tensor& decoder_tile_ids_per_batch,
+    const paddle::Tensor& decoder_num_blocks,
+    const paddle::Tensor& set_max_lengths,
+    const paddle::Tensor& max_len_kv,
+    const paddle::optional<paddle::Tensor>& rotary_embs,
+    const paddle::optional<paddle::Tensor>& attn_mask,
+    const paddle::optional<paddle::Tensor>& qkv_bias,
+    const paddle::optional<paddle::Tensor>& qkv_out_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_zp,
+    const paddle::optional<paddle::Tensor>& cache_v_zp,
+    const paddle::optional<paddle::Tensor>& out_linear_shifts,
+    const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const paddle::optional<paddle::Tensor>& kv_signal_data,
+    paddle::Tensor& fmha_out,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const bool rope_3d,
+    const int max_input_length,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  AppendAttnMetaData meta_data;
+
+  const auto& qkv_dims = qkv.dims();
+  const auto& key_cache_dims = key_cache.dims();
+  meta_data.token_nums = qkv_dims[0];
+  meta_data.kv_num_heads = key_cache_dims[1];
+  meta_data.head_dims = key_cache_dims[3];
+  // TODO: trick method support c4, add attr head_dims in the future
+  if (cache_quant_type_str == "cache_int4_zp") {
+    meta_data.head_dims *= 2;
+  }
+  const int total_num_head =
+      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
+  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+
+  meta_data.max_blocks_per_seq = block_tables.dims()[1];
+  meta_data.block_size = key_cache.dims()[2];
+  meta_data.batch_size = cum_offsets.dims()[0];
+
+  auto dispatch_by_template = [&](auto temp_args) -> std::vector<paddle::Tensor> {
+    return AppendAttentionKernel<type2value<decltype(temp_args)>::value>(
+        meta_data,
+        qkv,
+        key_cache,
+        value_cache,
+        seq_lens_encoder,
+        seq_lens_decoder,
+        seq_lens_this_time,
+        padding_offsets,
+        cum_offsets,
+        block_tables,
+        encoder_batch_ids,
+        encoder_tile_ids_per_batch,
+        encoder_num_blocks,
+        kv_batch_ids,
+        kv_tile_ids_per_batch,
+        kv_num_blocks,
+        decoder_batch_ids,
+        decoder_tile_ids_per_batch,
+        decoder_num_blocks,
+        set_max_lengths,
+        max_len_kv,
+        rotary_embs,
+        attn_mask,
+        qkv_bias,
+        qkv_out_scales,
+        cache_k_quant_scales,
+        cache_v_quant_scales,
+        cache_k_dequant_scales,
+        cache_v_dequant_scales,
+        cache_k_zp,
+        cache_v_zp,
+        out_linear_shifts,
+        out_linear_smooths,
+        kv_signal_data,
+        fmha_out,
         cache_quant_type_str,
         use_neox_rotary_style,
         rope_3d,
@@ -521,11 +685,16 @@ std::vector<paddle::Tensor> AppendAttention(
   phi::dtype::float16 fp16_dtype;
   phi::dtype::bfloat16 bp16_dtype;
   switch (qkv.dtype()) {
-    case paddle::DataType::FLOAT16: return dispatch_by_template(fp16_dtype);
-    case paddle::DataType::BFLOAT16: return dispatch_by_template(bp16_dtype);
+    case paddle::DataType::FLOAT16: {
+      return dispatch_by_template(fp16_dtype);
+    }
+    case paddle::DataType::BFLOAT16: {
+      return dispatch_by_template(bp16_dtype);
+    }
     case paddle::DataType::INT32: {
       if (compute_dtype == "bf16") {
-        return dispatch_by_template(bp16_dtype);
+        return dispatch_by_template(bp16_dtype);
       } else if (compute_dtype == "fp16") {
         return dispatch_by_template(fp16_dtype);
       } else {
@@ -540,7 +709,7 @@ std::vector<paddle::Tensor> AppendAttention(
       break;
     }
   }
-  return {paddle::Tensor{}};
+  return {paddle::Tensor()};
 }
 
 std::vector<std::vector<int64_t>> AppendAttentionInferShape(
@@ -600,7 +769,7 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
   }
   const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
   const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  return {{token_num, num_heads * head_dim}};
 }
 
 std::vector<paddle::DataType> AppendAttentionInferDtype(
@@ -655,32 +824,139 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
   if (compute_dtype == "bf16") {
     if (out_linear_in_scale > 0.0) {
       if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-        return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
+        return {paddle::DataType::INT8};
       } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-        return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
+        return {paddle::DataType::FLOAT8_E4M3FN};
       }else{
         PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
       }
     } else {
-      return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
+      return {paddle::DataType::BFLOAT16};
     }
   } else if (compute_dtype == "fp16") {
     if (out_linear_in_scale > 0.0) {
       if (fabs(quant_max_bound - 127.0f) < 0.000001) {
-        return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
+        return {paddle::DataType::INT8};
      } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
-        return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
+        return {paddle::DataType::FLOAT8_E4M3FN};
      }else{
        PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
      }
    } else {
-      return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
+      return {paddle::DataType::FLOAT16};
    }
  } else {
    PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
  }
 }
 
+std::vector<std::vector<int64_t>> AppendAttentionWithOutputInferShape(
+    const std::vector<int64_t>& qkv_shape,
+    const std::vector<int64_t>& key_cache_shape,
+    const std::vector<int64_t>& value_cache_shape,
+    const std::vector<int64_t>& seq_lens_encoder_shape,
+    const std::vector<int64_t>& seq_lens_decoder_shape,
+    const std::vector<int64_t>& seq_lens_this_time_shape,
+    const std::vector<int64_t>& padding_offsets_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& block_tables_shape,
+    const std::vector<int64_t>& encoder_batch_ids_shape,
+    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& encoder_num_blocks_shape,
+    const std::vector<int64_t>& kv_batch_ids_shape,
+    const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& kv_num_blocks_shape,
+    const std::vector<int64_t>& decoder_batch_ids_shape,
+    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& decoder_num_blocks_shape,
+    const std::vector<int64_t>& set_max_lengths_shape,
+    const std::vector<int64_t>& max_len_kv_shape,
+    const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
+    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
+    const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
+    const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
+    const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape,
+    const paddle::optional<std::vector<int64_t>>& kv_signal_data_shape,
+    const std::vector<int64_t>& fmha_out_shape,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const bool rope_3d,
+    const int max_input_length,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  return {fmha_out_shape};
+}
+
+std::vector<paddle::DataType> AppendAttentionWithOutputInferDtype(
+    const paddle::DataType& qkv_dtype,
+    const paddle::DataType& key_cache_dtype,
+    const paddle::DataType& value_cache_dtype,
+    const paddle::DataType& seq_lens_encoder_dtype,
+    const paddle::DataType& seq_lens_decoder_dtype,
+    const paddle::DataType& seq_lens_this_time_dtype,
+    const paddle::DataType& padding_offsets_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& block_tables_dtype,
+    const paddle::DataType& encoder_batch_ids_dtype,
+    const paddle::DataType& encoder_tile_ids_per_batch_dtype,
+    const paddle::DataType& encoder_num_blocks_dtype,
+    const paddle::DataType& kv_batch_ids_dtype,
+    const paddle::DataType& kv_tile_ids_per_batch_dtype,
+    const paddle::DataType& kv_num_blocks_dtype,
+    const paddle::DataType& decoder_batch_ids_dtype,
+    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
+    const paddle::DataType& decoder_num_blocks_dtype,
+    const paddle::DataType& set_max_lengths_dtype,
+    const paddle::DataType& max_len_kv_dtype,
+    const paddle::optional<paddle::DataType>& rotary_embs_dtype,
+    const paddle::optional<paddle::DataType>& attn_mask_dtype,
+    const paddle::optional<paddle::DataType>& qkv_bias_dtype,
+    const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
+    const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
+    const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
+    const paddle::optional<paddle::DataType>& kv_signal_data_dtype,
+    const paddle::DataType& fmha_out_dtype,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const bool rope_3d,
+    const int max_input_length,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int encoder_block_shape_q,
+    const int decoder_block_shape_q,
+    const int max_partition_size,
+    const int encoder_max_partition_size,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  return {fmha_out_dtype};
+}
+
 PD_BUILD_STATIC_OP(append_attention)
     .Inputs({"qkv",
              "key_cache",
@@ -715,7 +991,7 @@ PD_BUILD_STATIC_OP(append_attention)
              paddle::Optional("out_linear_shifts"),
              paddle::Optional("out_linear_smooths"),
              paddle::Optional("kv_signal_data")})
-    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
+    .Outputs({"fmha_out", "key_cache_out", "value_cache_out"})
     .SetInplaceMap({{"key_cache", "key_cache_out"},
                     {"value_cache", "value_cache_out"}})
     .Attrs({"compute_type: std::string",
@@ -736,3 +1012,61 @@ PD_BUILD_STATIC_OP(append_attention)
     .SetKernelFn(PD_KERNEL(AppendAttention))
     .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionInferDtype));
+
+PD_BUILD_STATIC_OP(append_attention_with_output)
+    .Inputs({"qkv",
+             "key_cache",
+             "value_cache",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "seq_lens_this_time",
+             "padding_offsets",
+             "cum_offsets",
+             "block_tables",
+             "encoder_batch_ids",
+             "encoder_tile_ids_per_batch",
+             "encoder_num_blocks",
+             "kv_batch_ids",
+             "kv_tile_ids_per_batch",
+             "kv_num_blocks",
+             "decoder_batch_ids",
+             "decoder_tile_ids_per_batch",
+             "decoder_num_blocks",
+             "set_max_lengths",
+             "max_len_kv",
+             paddle::Optional("rotary_embs"),
+             paddle::Optional("attn_mask"),
+             paddle::Optional("qkv_bias"),
+             paddle::Optional("qkv_out_scales"),
+             paddle::Optional("cache_k_quant_scales"),
+             paddle::Optional("cache_v_quant_scales"),
+             paddle::Optional("cache_k_dequant_scales"),
+             paddle::Optional("cache_v_dequant_scales"),
+             paddle::Optional("cache_k_zp"),
+             paddle::Optional("cache_v_zp"),
+             paddle::Optional("out_linear_shifts"),
+             paddle::Optional("out_linear_smooths"),
+             paddle::Optional("kv_signal_data"),
+             "fmha_out"})
+    .Outputs({"key_cache_out", "value_cache_out", "fmha_out_out"})
+    .SetInplaceMap({{"key_cache", "key_cache_out"},
+                    {"value_cache", "value_cache_out"},
+                    {"fmha_out", "fmha_out_out"}})
+    .Attrs({"compute_type: std::string",
+            "cache_quant_type: std::string",
+            "use_neox_rotary_style: bool",
+            "rope_3d: bool",
+            "max_input_length: int",
+            "quant_max_bound: float",
+            "quant_min_bound: float",
+            "out_linear_in_scale: float",
+            "encoder_block_shape_q: int",
+            "decoder_block_shape_q: int",
+            "max_partition_size: int",
"encoder_max_partition_size: int", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool"}) + .SetKernelFn(PD_KERNEL(AppendAttentionWithOutput)) + .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionWithOutputInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(AppendAttentionWithOuputInferDtype)); \ No newline at end of file diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 5eb56c14f4..46620c7ff9 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -50,7 +50,8 @@ void cuda_host_free(uintptr_t ptr) { } std::vector AppendAttention( - const paddle::Tensor &qkv, const paddle::Tensor &key_cache, + const paddle::Tensor &qkv, + const paddle::Tensor &key_cache, const paddle::Tensor &value_cache, const paddle::Tensor &seq_lens_encoder, const paddle::Tensor &seq_lens_decoder, const paddle::Tensor &seq_lens_this_time, @@ -87,6 +88,57 @@ std::vector AppendAttention( const int speculate_max_draft_token_num, const bool causal, const bool speculate_decoder); +std::vector AppendAttentionWithOutput( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& set_max_lengths, + const paddle::Tensor& max_len_kv, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const paddle::optional& kv_signal_data, + paddle::Tensor& fmha_out, + const std::string& compute_dtype, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const bool rope_3d, + const int max_input_length, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int encoder_block_shape_q, + const int decoder_block_shape_q, + const int max_partition_size, + const int encoder_max_partition_size, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder); + std::vector GQARopeWriteCacheKernel( const paddle::Tensor &qkv, const paddle::Tensor &key_cache, const paddle::Tensor &value_cache, const paddle::Tensor &cu_seqlens_q, @@ -521,6 +573,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { * append_attention */ m.def("append_attention", &AppendAttention, "append attention function"); + m.def("append_attention_with_output", &AppendAttentionWithOutput, "append attention function with output"); /** * gqa_rope_write_cache.cu * gqa_rope_write_cache diff --git 
a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index eb82e0bf98..0174f52549 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -23,8 +23,9 @@ import paddle from fastdeploy.model_executor.layers.attention.ops import ( - append_attention, get_block_shape_and_split_kv_block, - init_signal_layerwise, open_shm_and_get_meta_signal) + append_attention, append_attention_with_output, + get_block_shape_and_split_kv_block, init_signal_layerwise, + open_shm_and_get_meta_signal) if TYPE_CHECKING: from paddle._typing.dtype_like import _DTypeLiteral @@ -74,7 +75,7 @@ class AppendAttentionBackend(AttentionBackend): """ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, - head_dim: int) -> None: + head_dim: int, use_output: bool = True) -> None: """ AppendAttentionBackend __init__ """ @@ -106,6 +107,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, self.start_layer_index: int = fd_config.model_config.start_layer_index self.device_id: int = os.getenv("CUDA_VISIBLE_DEVICES", None) + self.use_output = use_output if fd_config.parallel_config.expert_parallel_rank is None: fd_config.parallel_config.expert_parallel_rank = 0 device_id = self.rank + fd_config.parallel_config.tensor_parallel_degree * \ @@ -202,55 +204,146 @@ def forward_mixed( layer.layer_id] = init_signal_layerwise( metadata.kv_signal_metadata, layer.layer_id + self.start_layer_index) - - res = append_attention( - qkv, - forward_meta.caches[2 * layer.layer_id], - forward_meta.caches[2 * layer.layer_id + 1], - forward_meta.seq_lens_encoder, - forward_meta.seq_lens_decoder, - forward_meta.seq_lens_this_time, - forward_meta.padding_offset, - forward_meta.cum_offsets, - metadata.block_tables, - metadata.encoder_batch_ids, - metadata.encoder_tile_ids_per_batch, - metadata.encoder_num_blocks, - metadata.kv_batch_ids, - metadata.kv_tile_ids_per_batch, - metadata.kv_num_blocks, - forward_meta.decoder_batch_ids, # from buffer - forward_meta.decoder_tile_ids_per_batch, # from buffer - metadata.decoder_num_blocks, - metadata.set_max_lengths, - metadata.max_len_kv, - metadata.rotary_embs, - metadata.attn_mask, - layer.qkv_bias, - layer.qkv_scale, - getattr(layer, "cache_k_scale", None), - getattr(layer, "cache_v_scale", None), - getattr(layer, "cache_k_out_scale", None), - getattr(layer, "cache_v_out_scale", None), - getattr(layer, "cache_k_zp", None), - getattr(layer, "cache_v_zp", None), - layer.linear_shift, - layer.linear_smooth, - metadata.kv_signal_data_list[layer.layer_id], - metadata._fuse_kernel_compute_dtype, - getattr(layer, "cache_quant_type_str", "none"), - layer.use_neox_rotary_style, - self.rope_3d, - self.max_seq_len, - getattr(layer, "quant_max_bound", 0.0), - getattr(layer, "quant_min_bound", 0.0), - getattr(layer, "out_scale", -1.0), - metadata.encoder_block_shape_q, - metadata.decoder_block_shape_q, - metadata.max_partition_size, - metadata.encoder_max_partition_size, - self.speculate_max_draft_token_num + 1, - self.causal, - self.speculative_method is not None, - )[0] - return res + quant_max_bound = getattr(layer, "quant_max_bound", 0.0) + quant_min_bound = getattr(layer, "quant_min_bound", 0.0) + cache_quant_type = getattr(layer, "cache_quant_type_str", "none") + if self.use_output: + compute_type = metadata._fuse_kernel_compute_dtype + ## 1. 
get output datatype + qkv_dtype = qkv.dtype + if qkv_dtype == paddle.float16: + D_type = paddle.float16 + elif qkv_dtype == paddle.bfloat16: + D_type = paddle.bfloat16 + elif qkv_dtype == paddle.int32: + if compute_type == "bf16": + D_type = paddle.bfloat16 + elif compute_type == "fp16": + D_type = paddle.float16 + else: + raise NotImplementedError( + "Only supported attr of qkv_type in ['float16', 'bfloat16'].") + else: + raise NotImplementedError( + "Only supported attr of qkv_type in ['float16', 'bfloat16', 'int32'].") + ## 2.Extract related parameters + out_linear_in_scale = getattr(layer, "out_linear_in_scale", -1.0) + token_nums = qkv.shape[0] + head_dims = self.head_dim if cache_quant_type != "cache_int4_zp" else self.head_dim * 2 + q_num_heads = self.num_heads + # 3. generate output tensor of different dtypes + if out_linear_in_scale > 0.0: + if abs(quant_max_bound - 127) < 0.000001: + output = paddle.empty([token_nums, q_num_heads * head_dims], + dtype='int8').to(qkv.place) + elif abs(quant_max_bound - 448) < 0.000001: + output = paddle.empty([token_nums, q_num_heads * head_dims], + dtype='float8_e4m3fn').to(qkv.place) + else: + raise NotImplementedError( + "Only supported attr of quant_max_bound in ['127', '448'].") + else: + output = paddle.empty([token_nums, q_num_heads * head_dims], dtype=D_type).to(qkv.place) + + append_attention_with_output( + qkv, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.padding_offset, + forward_meta.cum_offsets, + metadata.block_tables, + metadata.encoder_batch_ids, + metadata.encoder_tile_ids_per_batch, + metadata.encoder_num_blocks, + metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + forward_meta.decoder_batch_ids, # from buffer + forward_meta.decoder_tile_ids_per_batch, # from buffer + metadata.decoder_num_blocks, + metadata.set_max_lengths, + metadata.max_len_kv, + metadata.rotary_embs, + metadata.attn_mask, + layer.qkv_bias, + layer.qkv_scale, + getattr(layer, "cache_k_scale", None), + getattr(layer, "cache_v_scale", None), + getattr(layer, "cache_k_out_scale", None), + getattr(layer, "cache_v_out_scale", None), + getattr(layer, "cache_k_zp", None), + getattr(layer, "cache_v_zp", None), + layer.linear_shift, + layer.linear_smooth, + metadata.kv_signal_data_list[layer.layer_id], + output, + metadata._fuse_kernel_compute_dtype, + cache_quant_type, + layer.use_neox_rotary_style, + self.rope_3d, + self.max_seq_len, + quant_max_bound, + quant_min_bound, + layer.out_scale, + metadata.encoder_block_shape_q, + metadata.decoder_block_shape_q, + metadata.max_partition_size, + metadata.encoder_max_partition_size, + self.speculate_max_draft_token_num + 1, + self.causal, + self.speculative_method is not None) + else: + output = append_attention( + qkv, + forward_meta.caches[2 * layer.layer_id], + forward_meta.caches[2 * layer.layer_id + 1], + forward_meta.seq_lens_encoder, + forward_meta.seq_lens_decoder, + forward_meta.seq_lens_this_time, + forward_meta.padding_offset, + forward_meta.cum_offsets, + metadata.block_tables, + metadata.encoder_batch_ids, + metadata.encoder_tile_ids_per_batch, + metadata.encoder_num_blocks, + metadata.kv_batch_ids, + metadata.kv_tile_ids_per_batch, + metadata.kv_num_blocks, + forward_meta.decoder_batch_ids, # from buffer + forward_meta.decoder_tile_ids_per_batch, # from buffer + metadata.decoder_num_blocks, + metadata.set_max_lengths, + 
+                metadata.max_len_kv,
+                metadata.rotary_embs,
+                metadata.attn_mask,
+                layer.qkv_bias,
+                layer.qkv_scale,
+                getattr(layer, "cache_k_scale", None),
+                getattr(layer, "cache_v_scale", None),
+                getattr(layer, "cache_k_out_scale", None),
+                getattr(layer, "cache_v_out_scale", None),
+                getattr(layer, "cache_k_zp", None),
+                getattr(layer, "cache_v_zp", None),
+                layer.linear_shift,
+                layer.linear_smooth,
+                metadata.kv_signal_data_list[layer.layer_id],
+                metadata._fuse_kernel_compute_dtype,
+                cache_quant_type,
+                layer.use_neox_rotary_style,
+                self.rope_3d,
+                self.max_seq_len,
+                quant_max_bound,
+                quant_min_bound,
+                out_scale,
+                metadata.encoder_block_shape_q,
+                metadata.decoder_block_shape_q,
+                metadata.max_partition_size,
+                metadata.encoder_max_partition_size,
+                self.speculate_max_draft_token_num + 1,
+                self.causal,
+                self.speculative_method is not None,
+            )[0]
+        return output
diff --git a/fastdeploy/model_executor/layers/attention/ops/__init__.py b/fastdeploy/model_executor/layers/attention/ops/__init__.py
index 8b75ce6f0e..43c00c3774 100644
--- a/fastdeploy/model_executor/layers/attention/ops/__init__.py
+++ b/fastdeploy/model_executor/layers/attention/ops/__init__.py
@@ -14,13 +14,14 @@
 # limitations under the License.
 """
 
-from .append_attention import append_attention
+from .append_attention import append_attention, append_attention_with_output
 from .get_block_shape_and_split_kv_block import \
     get_block_shape_and_split_kv_block
 from .init_signal_layerwise import init_signal_layerwise
 from .open_shm_and_get_meta_signal import open_shm_and_get_meta_signal
 
 __all__ = [
-    "get_block_shape_and_split_kv_block", "append_attention",
+    "get_block_shape_and_split_kv_block",
+    "append_attention", "append_attention_with_output",
     "open_shm_and_get_meta_signal", "init_signal_layerwise"
 ]
diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py
index b488451a9f..8ae4eb0d16 100644
--- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py
+++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py
@@ -23,6 +23,8 @@
 if current_platform.is_cuda():
     from fastdeploy.model_executor.ops.gpu import \
         append_attention as append_attention_gpu
+    from fastdeploy.model_executor.ops.gpu import \
+        append_attention_with_output as append_attention_with_output_gpu
 
 
 def append_attention(
@@ -78,8 +80,124 @@ def append_attention(
     """
     append_attention
     """
+
+    if current_platform.is_cuda():
+        fmha_out = append_attention_gpu(
+            qkv,
+            key_cache,
+            value_cache,
+            seq_lens_encoder,
+            seq_lens_decoder,
+            seq_lens_this_time,
+            padding_offsets,
+            cum_offsets,
+            block_tables,
+            encoder_batch_ids,
+            encoder_tile_ids_per_batch,
+            encoder_num_blocks,
+            kv_batch_ids,
+            kv_tile_ids_per_batch,
+            kv_num_blocks,
+            decoder_batch_ids,
+            decoder_tile_ids_per_batch,
+            decoder_num_blocks,
+            set_max_lengths,
+            max_len_kv,
+            rotary_embs,
+            attn_mask,
+            qkv_bias,
+            qkv_scale,
+            k_quant_scale,
+            v_quant_scale,
+            k_dequant_scale,
+            v_dequant_scale,
+            cache_k_zp,
+            cache_v_zp,
+            linear_shift,
+            linear_smooth,
+            kv_signal_data,
+            compute_type,
+            cache_quant_type,
+            use_neox_rotary_style,
+            rope_3d,
+            max_input_length,
+            quant_max_bound,
+            quant_min_bound,
+            out_linear_in_scale,
+            encoder_block_shape_q,
+            decoder_block_shape_q,
+            max_partition_size,
+            encoder_max_partition_size,
+            speculate_max_draft_token_num,
+            causal,
+            speculate_decoder,
+        )
+        return fmha_out
+    else:
+        raise NotImplementedError()
+
+
+def append_attention_with_output(
+    qkv: paddle.Tensor,
+    key_cache: paddle.Tensor,
+    value_cache: paddle.Tensor,
+    seq_lens_encoder: paddle.Tensor,
+    seq_lens_decoder: paddle.Tensor,
+    seq_lens_this_time: paddle.Tensor,
+    padding_offsets: paddle.Tensor,
+    cum_offsets: paddle.Tensor,
+    block_tables: paddle.Tensor,
+    encoder_batch_ids: paddle.Tensor,
+    encoder_tile_ids_per_batch: paddle.Tensor,
+    encoder_num_blocks: paddle.Tensor,
+    kv_batch_ids: paddle.Tensor,
+    kv_tile_ids_per_batch: paddle.Tensor,
+    kv_num_blocks: paddle.Tensor,
+    decoder_batch_ids: paddle.Tensor,
+    decoder_tile_ids_per_batch: paddle.Tensor,
+    decoder_num_blocks: paddle.Tensor,
+    set_max_lengths: paddle.Tensor,
+    max_len_kv: paddle.Tensor,
+    rotary_embs: Optional[paddle.Tensor] = None,
+    attn_mask: Optional[paddle.Tensor] = None,
+    qkv_bias: Optional[paddle.Tensor] = None,
+    qkv_scale: Optional[paddle.Tensor] = None,
+    k_quant_scale: Optional[paddle.Tensor] = None,
+    v_quant_scale: Optional[paddle.Tensor] = None,
+    k_dequant_scale: Optional[paddle.Tensor] = None,
+    v_dequant_scale: Optional[paddle.Tensor] = None,
+    cache_k_zp: Optional[paddle.Tensor] = None,
+    cache_v_zp: Optional[paddle.Tensor] = None,
+    linear_shift: Optional[paddle.Tensor] = None,
+    linear_smooth: Optional[paddle.Tensor] = None,
+    kv_signal_data: Optional[paddle.Tensor] = None,
+    fmha_out: Optional[paddle.Tensor] = None,
+    compute_type: str = "bf16",
+    cache_quant_type: str = "none",
+    use_neox_rotary_style: bool = False,
+    rope_3d: bool = False,
+    max_input_length: int = 0,
+    quant_max_bound: float = 0.0,
+    quant_min_bound: float = 0.0,
+    out_linear_in_scale: float = -1.0,
+    encoder_block_shape_q: int = 64,
+    decoder_block_shape_q: int = 16,
+    max_partition_size: int = 32768,
+    encoder_max_partition_size: int = 32768,
+    speculate_max_draft_token_num: int = 1,
+    causal: bool = True,
+    speculate_decoder: bool = False,
+):
+    """
+    append_attention_with_output
+    """
+    # fmha_out must be pre-allocated by the caller on this path.
+    if fmha_out is None:
+        raise ValueError(
+            "fmha_out must not be None when append_attention_with_output is selected.")
+
     if current_platform.is_cuda():
-        out = append_attention_gpu(
+        append_attention_with_output_gpu(
             qkv,
             key_cache,
             value_cache,
@@ -113,6 +231,7 @@ def append_attention(
             linear_shift,
             linear_smooth,
             kv_signal_data,
+            fmha_out,
             compute_type,
             cache_quant_type,
             use_neox_rotary_style,
@@ -129,6 +248,6 @@ def append_attention(
             causal,
             speculate_decoder,
         )
-        return out
+        return [fmha_out]
     else:
         raise NotImplementedError()
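
Usage note (not part of the patch): a minimal sketch of how a caller can pre-allocate fmha_out and drive the new append_attention_with_output wrapper. The dtype selection mirrors the allocation logic in append_attn_backend.py above. The helper name and the attn_kwargs bundle are illustrative placeholders, not FastDeploy APIs; attn_kwargs is assumed to carry the remaining required tensors (key_cache, value_cache, seq_lens_*, block tables, scheduling metadata, ...) prepared exactly as the backend prepares them.

    import paddle

    from fastdeploy.model_executor.layers.attention.ops import \
        append_attention_with_output


    def run_with_preallocated_output(qkv, num_heads, head_dim, compute_dtype="bf16",
                                     out_scale=-1.0, quant_max_bound=0.0, **attn_kwargs):
        # Pick the dtype the kernel will write: follow qkv's dtype, or fall back to
        # compute_dtype when qkv arrives in the quantized int32 layout.
        if qkv.dtype == paddle.int32:
            out_dtype = paddle.bfloat16 if compute_dtype == "bf16" else paddle.float16
        else:
            out_dtype = qkv.dtype
        # A quantized attention output overrides the dtype: int8 for a 127 bound,
        # float8_e4m3fn for a 448 bound.
        if out_scale > 0.0:
            out_dtype = 'int8' if abs(quant_max_bound - 127.0) < 1e-6 else 'float8_e4m3fn'

        token_nums = qkv.shape[0]
        fmha_out = paddle.empty([token_nums, num_heads * head_dim], dtype=out_dtype)

        # The op writes the attention result into fmha_out in place (it is inplace-mapped
        # to the op output), so the wrapper raises if fmha_out is None.
        append_attention_with_output(qkv, fmha_out=fmha_out, **attn_kwargs)
        return fmha_out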