From 07ea00e1b940e8e5406fdae143dcc7fdea45e2dc Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 9 Jul 2025 15:50:56 +0800 Subject: [PATCH 1/4] [SOT] Mark dynamic dims by type annotations --- .../forward_meta.py | 22 +-- .../graph_optimization/dynamic_dims_marker.py | 155 ++++++++++++++++++ .../graph_optimization_backend.py | 66 +++++++- .../layers/attention/__init__.py | 3 +- .../layers/attention/append_attn_backend.py | 35 ++-- .../layers/attention/attention.py | 13 +- .../attention/base_attention_backend.py | 14 +- .../layers/attention/flash_attn_backend.py | 33 ++-- .../layers/attention/mla_attention_backend.py | 41 +++-- .../layers/attention/native_paddle_backend.py | 12 +- .../layers/attention/xpu_attn_backend.py | 37 +++-- .../model_executor/models/deepseek_v3.py | 2 +- .../model_executor/models/ernie4_5_moe.py | 10 +- .../model_executor/models/ernie4_5_mtp.py | 2 +- .../models/ernie4_5_vl/ernie4_5_vl_moe.py | 2 +- fastdeploy/model_executor/models/qwen2.py | 2 +- fastdeploy/model_executor/models/qwen3.py | 2 +- fastdeploy/model_executor/models/qwen3moe.py | 2 +- fastdeploy/spec_decode/mtp.py | 2 +- fastdeploy/worker/gpu_model_runner.py | 2 +- fastdeploy/worker/vl_gpu_model_runner.py | 2 +- fastdeploy/worker/xpu_model_runner.py | 2 +- test/layers/test_attention.py | 2 +- test/worker/test_cuda_graph.py | 2 +- 24 files changed, 352 insertions(+), 113 deletions(-) rename fastdeploy/{worker => model_executor}/forward_meta.py (94%) create mode 100644 fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py diff --git a/fastdeploy/worker/forward_meta.py b/fastdeploy/model_executor/forward_meta.py similarity index 94% rename from fastdeploy/worker/forward_meta.py rename to fastdeploy/model_executor/forward_meta.py index a1007f4e11..f03ecc40d5 100644 --- a/fastdeploy/worker/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -18,14 +18,16 @@ import logging from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Annotated +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import ( + DynamicDims +) +from fastdeploy.model_executor.layers.attention import (Attention, + AttentionBackend) import numpy as np import paddle -if TYPE_CHECKING: - from fastdeploy.model_executor.layers.attention import (Attention, - AttentionBackend) logger = logging.getLogger(__name__) @@ -282,7 +284,7 @@ class ForwardMeta(): forward_mode: ForwardMode = ForwardMode.MIXED # - ids_remove_padding: paddle.Tensor = None + ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)] = None # seq_lens_encoder: Optional[paddle.Tensor] = None @@ -300,13 +302,13 @@ class ForwardMeta(): block_tables: Optional[paddle.Tensor] = None # - attn_backend: 'AttentionBackend' = None + attn_backend: AttentionBackend = None # rotary_embs: Optional[paddle.Tensor] = None # - padding_offset: Optional[paddle.Tensor] = None + padding_offset: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None # cu_seqlens_q: Optional[paddle.Tensor] = None @@ -315,7 +317,7 @@ class ForwardMeta(): cu_seqlens_k: Optional[paddle.Tensor] = None # - caches: Optional[paddle.Tensor] = None + caches: Optional[list[Annotated[paddle.Tensor, DynamicDims(0)]]] = None # attn_mask: Optional[paddle.Tensor] = None @@ -327,9 +329,9 @@ class ForwardMeta(): step_use_cudagraph: bool = False # for attention backend - decoder_batch_ids: Optional[paddle.Tensor] = None + decoder_batch_ids: 
Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None # for attention backend - decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + decoder_tile_ids_per_batch: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None # is_decode_batch or not is_decode_batch: bool = False diff --git a/fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py b/fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py new file mode 100644 index 0000000000..b9d37b2934 --- /dev/null +++ b/fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import dataclasses +import inspect +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import partial +from typing import (Annotated, Any, Optional, TypeVar, Union, get_origin, + get_type_hints) + +import paddle +from paddle import Tensor +from typing_extensions import TypeAlias + +T = TypeVar("T") +U = TypeVar("U") + +Accessor: TypeAlias = Callable[[T], U] + + +class DynamicDims: + def __init__(self, dims: int | tuple[int]): + self.dims = dims if isinstance(dims, tuple) else (dims,) + + def __repr__(self): + return f"DynamicDims({self.dims})" + + +class DynamicDimTypeResolver: + ALL_DYNAMIC_DIM_TYPE_RESOLVERS = [] + + @classmethod + def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]): + cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls()) + return resolver_cls + + @abstractmethod + def check(self, tp) -> bool: + raise NotImplementedError + + @abstractmethod + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + raise NotImplementedError + + def resolve(self, data, data_name, tp) -> None: + inner_types = self.extract_inner_types(data, data_name, tp) + for accessor, inner_data_name, inner_type in inner_types: + self.generic_resolve(accessor(data), inner_data_name, inner_type) + + def generic_resolve(self, data, data_name, tp) -> None: + # assert isinstance(data, tp), f"Expected {data_name} has type {tp}, but got {type(data)}" + for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS: + if resolver.check(tp): + return resolver.resolve(data, data_name, tp) + runtime_tp = type(data) + if runtime_tp is not tp and resolver.check(runtime_tp): + return resolver.resolve(data, data_name, runtime_tp) + else: + print(f"No resolver found for type {tp} and data {data_name}") + + +@DynamicDimTypeResolver.register_resolver +class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return dataclasses.is_dataclass(tp) and isinstance(tp, type) + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + type_hints = get_type_hints(tp, include_extras=True) + return [ # type: ignore + ( + # bind name by partial to avoid capture wrong free vars + partial(lambda name, dt: getattr(dt, name), field.name), + f"{data_name}.{field.name}", + type_hints[field.name], + ) + for field in dataclasses.fields(tp) + ] + + +@DynamicDimTypeResolver.register_resolver +class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return ( + get_origin(tp) is Union + and len(tp.__args__) == 2 + and tp.__args__[1] is type(None) + ) + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + if data is None: + return [] + inner_type = tp.__args__[0] + return [(lambda x: x, data_name, inner_type)] # No 
accessor needed for Optional + + +@DynamicDimTypeResolver.register_resolver +class ListDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return get_origin(tp) is list + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + if not data: + return [] + inner_type = typing.get_args(tp)[0] if tp.__args__ else Any + return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))] # type: ignore + +@DynamicDimTypeResolver.register_resolver +class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver): + INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__" + def check(self, tp) -> bool: + return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME) + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME) + if isinstance(fields, str): + raise TypeError(f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}") + inner_types_dict = typing.get_type_hints(tp) + return [(partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type) for field_name, inner_type in inner_types_dict.items()] + +@DynamicDimTypeResolver.register_resolver +class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor + + def resolve(self, data, data_name, tp) -> None: + base_type, *metadata = typing.get_args(tp) + # Filter out DynamicDims instances + dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)] + if not dynamic_dims: + return + if len(dynamic_dims) > 1: + raise ValueError( + "Multiple DynamicDims annotations found. Only one is allowed." + ) + dynamic_dims = dynamic_dims[0].dims + if not isinstance(data, Tensor): + raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}") + print(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}") + paddle.jit.marker.dynamic_dims( + data, dynamic_dims + ) + +def resolve_dynamic_dims(arg, arg_name, annotation): + DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation) diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py index 7189989dd0..ba79f351b2 100644 --- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py +++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py @@ -14,13 +14,66 @@ # limitations under the License. 
""" -from typing import Callable, Optional - -from paddle.jit.dy2static.utils import Backend +import functools +import inspect +import types +from typing import Callable, Optional, get_type_hints +import paddle from fastdeploy.config import FDConfig from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \ CudaGraphPiecewiseBackend +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + resolve_dynamic_dims +from paddle.jit import sot +from paddle.jit.dy2static.utils import Backend + + +def apply_to_static_optimization(fn): + forward_fn = fn + forward_sig = inspect.signature(forward_fn) + # forward_annotations = inspect.get_annotations(forward_fn) + forward_type_hints = get_type_hints(forward_fn) + static_forward_fn = sot.symbolic_translate( + forward_fn, training=False, backend=Backend.PHI + ) + unsafe_static_forward_fn = None + + @functools.wraps(forward_fn) + def static_forward(self, *args, **kwargs): + nonlocal unsafe_static_forward_fn + if unsafe_static_forward_fn is not None: + return unsafe_static_forward_fn(self, *args, **kwargs) + bound_args = forward_sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for name, arg in bound_args.arguments.items(): + if name not in forward_type_hints: + continue + annotation = forward_type_hints[name] + resolve_dynamic_dims(arg, name, annotation) + + # print(f"Processing argument '{name}' with annotation: {annotation}") + # if isinstance(arg, paddle.Tensor): + # print( + # f"Argument '{name}' is a Tensor with dynamic dims: {extract_dynamic_dims(annotation)}" + # ) + # paddle.jit.marker.dynamic_dims( + # arg, extract_dynamic_dims(annotation) + # ) + result = static_forward_fn(self, *args, **kwargs) + original_code = forward_fn.__code__ + (new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[original_code] + new_code = new_guarded_codes[0][0][0] + unsafe_static_forward_fn = types.FunctionType( + new_code, + forward_fn.__globals__, + forward_fn.__name__, + forward_fn.__defaults__, + forward_fn.__closure__, + ) + return result + + return static_forward class GraphOptBackend: @@ -43,9 +96,10 @@ def __init__(self, runnable: Callable, fd_config: FDConfig): backend = (Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI) - self.runnable = sot.symbolic_translate(self.runnable, - training=False, - backend=backend) + # self.runnable = sot.symbolic_translate(self.runnable, + # training=False, + # backend=backend) + self.runnable = apply_to_static_optimization(self.runnable.__func__).__get__(self.runnable.__self__) def __call__(self, **kwargs): if not self.fd_config.graph_opt_config.use_cudagraph: diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 6a1d0e1c12..40212e2604 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -19,9 +19,10 @@ from .mla_attention_backend import MLAAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend +from .attention import Attention __all__ = [ "AttentionBackend", "PaddleNativeAttnBackend", "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend", - "MLAAttentionBackend", "FlashAttentionBackend" + "MLAAttentionBackend", "FlashAttentionBackend", "Attention", ] diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py 
b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 5bc7f420aa..6d8598b129 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -18,22 +18,22 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Annotated, List, Optional, Tuple import paddle - from fastdeploy.model_executor.layers.attention.ops import ( append_attention, get_block_shape_and_split_kv_block, init_signal_layerwise, open_shm_and_get_meta_signal) if TYPE_CHECKING: - from paddle._typing.dtype_like import _DTypeLiteral + from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.config import FDConfig +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) -from fastdeploy.worker.forward_meta import ForwardMeta @dataclass @@ -43,17 +43,17 @@ class AppendAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None - _dtype: _DTypeLiteral = paddle.bfloat16 + _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 block_tables: Optional[paddle.Tensor] = None @@ -73,6 +73,9 @@ class AppendAttentionBackend(AttentionBackend): AppendAttentionBackend backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: AppendAttentionMetadata + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int) -> None: """ @@ -115,7 +118,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, else: self.device_id = self.device_id.split(",")[device_id] - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = AppendAttentionMetadata() metadata.encoder_block_shape_q = 64 @@ -190,7 +193,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ forward_mixed diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index 3f676f0317..555b6e6c13 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -14,17 +14,18 @@ # limitations under the License. """ -from typing import Dict, Optional +from typing import TYPE_CHECKING, Dict, Optional import numpy as np import paddle -from paddle import nn -from paddleformers.utils.log import logger - from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.quantization.quant_base import \ QuantMethodBase -from fastdeploy.worker.forward_meta import ForwardMeta +from paddle import nn +from paddleformers.utils.log import logger + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta class Attention(nn.Layer): @@ -113,7 +114,7 @@ def forward( qkv: paddle.Tensor = None, compressed_kv: paddle.Tensor = None, k_pe: paddle.Tensor = None, - forward_meta: ForwardMeta = None, + forward_meta: "ForwardMeta" = None, ) -> paddle.Tensor: """ The forward function of attention layer. diff --git a/fastdeploy/model_executor/layers/attention/base_attention_backend.py b/fastdeploy/model_executor/layers/attention/base_attention_backend.py index 02d1d65db6..7c0238dc5e 100644 --- a/fastdeploy/model_executor/layers/attention/base_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/base_attention_backend.py @@ -21,10 +21,12 @@ from abc import ABC, abstractmethod from dataclasses import dataclass +from typing import TYPE_CHECKING import paddle -from fastdeploy.worker.forward_meta import ForwardMeta +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta @dataclass @@ -36,7 +38,7 @@ class AttentionBackend(ABC): """The base class of attention backends""" @abstractmethod - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize the forward metadata.""" raise NotImplementedError() @@ -49,7 +51,7 @@ def forward( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Run a forward. 
@@ -105,7 +107,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """Run a forward for mix.""" raise NotImplementedError() @@ -119,7 +121,7 @@ def forward_decode( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """Run a forward for decode.""" raise NotImplementedError() @@ -133,7 +135,7 @@ def forward_extend( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """Run a forward for extend.""" raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index 74a234bd19..8d0a09dc8a 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field -from typing import List, Optional +from typing import List, Optional, Annotated, TYPE_CHECKING import paddle from paddle.nn.functional.flash_attention import flash_attention_v3_varlen @@ -27,10 +27,14 @@ from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, gqa_rope_write_cache, init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat) -from fastdeploy.worker.forward_meta import ForwardMeta + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta @dataclass @@ -42,15 +46,15 @@ class FlashAttentionMetadata(AttentionMetadata): set_max_lengths: int = -1 rotary_embs: Optional[paddle.Tensor] = None block_tables: Optional[paddle.Tensor] = None - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None encoder_block_shape_q: Optional[paddle.Tensor] = None decoder_block_shape_q: Optional[paddle.Tensor] = None @@ -75,6 +79,9 @@ class FlashAttentionBackend(AttentionBackend): FlashAttentionBackend backend implementation """ + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: FlashAttentionMetadata + def __init__(self, 
fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ @@ -127,7 +134,7 @@ def get_kv_cache_shape( return (max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim) - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): metadata = FlashAttentionMetadata() metadata.encoder_block_shape_q = 64 metadata.decoder_block_shape_q = 16 @@ -190,7 +197,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ): metadata = self.attention_metadata diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index 1d9c9773be..758e4fb69b 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -19,7 +19,7 @@ import math import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Annotated import paddle from paddle.nn.functional.flash_attention import flash_attn_unpadded @@ -35,13 +35,14 @@ prefill_mla_write_cache) if TYPE_CHECKING: - from paddle._typing.dtype_like import _DTypeLiteral + from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims def yarn_get_mscale(scale=1, mscale=1): @@ -59,17 +60,17 @@ class MLAAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None - - _dtype: _DTypeLiteral = paddle.bfloat16 + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + + _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 block_tables: Optional[paddle.Tensor] = None @@ -89,6 +90,10 @@ class MLAAttentionBackend(AttentionBackend): MLA Attention Backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: MLAAttentionMetadata + + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int) -> None: """ @@ -140,7 +145,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, else: self.device_id = self.device_id.split(",")[self.rank] - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize attention metadata hence all layers in the forward pass can reuse it.""" metadata = MLAAttentionMetadata() metadata.encoder_block_shape_q = 64 @@ -217,7 +222,7 @@ def forward_extend( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Prefill阶段的前向传播 @@ -272,7 +277,7 @@ def forward_decode( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Decode阶段的前向传播 @@ -368,7 +373,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Mixed模式的前向传播 diff --git a/fastdeploy/model_executor/layers/attention/native_paddle_backend.py b/fastdeploy/model_executor/layers/attention/native_paddle_backend.py index 8e8b9ce77b..b286bd98fa 100644 --- a/fastdeploy/model_executor/layers/attention/native_paddle_backend.py +++ b/fastdeploy/model_executor/layers/attention/native_paddle_backend.py @@ -17,12 +17,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING + import paddle from paddle.nn.functional import scaled_dot_product_attention from fastdeploy.model_executor.layers.attention.base_attention_backend import \ AttentionBackend -from fastdeploy.worker.forward_meta import ForwardMeta + +if TYPE_CHECKING: + from fastdeploy.model_executor.forward_meta import ForwardMeta class PaddleNativeAttnBackend(AttentionBackend): @@ -34,7 +38,7 @@ class PaddleNativeAttnBackend(AttentionBackend): def __init__(self) -> None: super().__init__() - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Init the metadata for a forward pass.""" pass @@ -212,7 +216,7 @@ def forward_extend( k: paddle.Tensor, v: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", save_kv_cache: bool = True, ) -> paddle.Tensor: """ @@ -253,7 +257,7 @@ def forward_decode( k: paddle.Tensor, v: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Run the decoding attention forward by using paddle native sdpa op. 
diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 9ecc01fb89..7141ed16e4 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Annotated import paddle @@ -26,14 +26,14 @@ init_signal_layerwise, open_shm_and_get_meta_signal) if TYPE_CHECKING: - from paddle._typing.dtype_like import _DTypeLiteral + from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) -from fastdeploy.worker.forward_meta import ForwardMeta - +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims @dataclass class XPUAttentionMetadata(AttentionMetadata): @@ -42,17 +42,17 @@ class XPUAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None - - _dtype: _DTypeLiteral = paddle.bfloat16 + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + + _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 max_partition_size: int = 32768 block_tables: Optional[paddle.Tensor] = None @@ -72,6 +72,9 @@ class XPUAttentionBackend(AttentionBackend): XPUAttentionBackend backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: XPUAttentionMetadata + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ @@ -103,7 +106,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, os.getenv("FLAGS_use_pd_disaggregation", 0)) self.start_layer_index: int = fd_config.model_config.start_layer_index - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = XPUAttentionMetadata() metadata.encoder_block_shape_q = 64 @@ -152,7 +155,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ forward_mixed diff --git a/fastdeploy/model_executor/models/deepseek_v3.py b/fastdeploy/model_executor/models/deepseek_v3.py index 73997c2acd..1966155a22 100644 --- a/fastdeploy/model_executor/models/deepseek_v3.py +++ b/fastdeploy/model_executor/models/deepseek_v3.py @@ -40,7 +40,7 @@ DeepseekScalingRotaryEmbedding from fastdeploy.model_executor.models.model_base import ModelForCasualLM from fastdeploy.platforms import current_platform -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta if current_platform.is_cuda(): from fastdeploy.model_executor.ops.gpu import \ diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index f6b73622a9..cab543b10f 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import partial -from typing import Dict, Union +from typing import Dict, Union, Annotated import numpy as np import paddle @@ -41,8 +41,10 @@ from fastdeploy.model_executor.models.utils import \ LayerIdPlaceholder as layerid from fastdeploy.model_executor.models.utils import WeightMeta -from fastdeploy.worker.forward_meta import ForwardMeta - +from fastdeploy.model_executor.forward_meta import ForwardMeta +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import ( + DynamicDims +) class Ernie4_5_MLP(nn.Layer): @@ -387,7 +389,7 @@ def load_state_dict(self, state_dict): def forward( self, - ids_remove_padding: paddle.Tensor, + ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)], forward_meta: ForwardMeta, ): hidden_states = self.embeddings(ids_remove_padding=ids_remove_padding) diff --git a/fastdeploy/model_executor/models/ernie4_5_mtp.py b/fastdeploy/model_executor/models/ernie4_5_mtp.py index 029becc1e4..60c467c158 100644 --- a/fastdeploy/model_executor/models/ernie4_5_mtp.py +++ b/fastdeploy/model_executor/models/ernie4_5_mtp.py @@ -30,7 +30,7 @@ from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.ernie4_5_moe import Ernie4_5_DecoderLayer from fastdeploy.model_executor.models.model_base import ModelForCasualLM -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta class Ernie4_5_MTPPretrainedModel(PretrainedModel): diff --git a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py index a08433a570..c7e558309f 100644 --- 
a/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py @@ -42,7 +42,7 @@ text_image_gather_scatter, text_image_index_out) -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta class Ernie4_5_VLMLP(Ernie4_5_MLP): diff --git a/fastdeploy/model_executor/models/qwen2.py b/fastdeploy/model_executor/models/qwen2.py index 0a5912afb0..eae554d377 100644 --- a/fastdeploy/model_executor/models/qwen2.py +++ b/fastdeploy/model_executor/models/qwen2.py @@ -34,7 +34,7 @@ from fastdeploy.model_executor.layers.lm_head import ParallelLMHead from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.model_base import ModelForCasualLM -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta class Qwen2MLP(nn.Layer): diff --git a/fastdeploy/model_executor/models/qwen3.py b/fastdeploy/model_executor/models/qwen3.py index c1654f4144..86244a6e9f 100644 --- a/fastdeploy/model_executor/models/qwen3.py +++ b/fastdeploy/model_executor/models/qwen3.py @@ -34,7 +34,7 @@ from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.model_base import ModelForCasualLM from fastdeploy.model_executor.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta class Qwen3MLP(Qwen2MLP): diff --git a/fastdeploy/model_executor/models/qwen3moe.py b/fastdeploy/model_executor/models/qwen3moe.py index c4d01ef6ea..7777f2150e 100644 --- a/fastdeploy/model_executor/models/qwen3moe.py +++ b/fastdeploy/model_executor/models/qwen3moe.py @@ -35,7 +35,7 @@ from fastdeploy.model_executor.layers.moe.moe import FusedMoE from fastdeploy.model_executor.layers.normalization import RMSNorm from fastdeploy.model_executor.models.model_base import ModelForCasualLM -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta class Qwen3MLP(nn.Layer): diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 97e8364451..3036578d6c 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -36,7 +36,7 @@ share_external_data) from fastdeploy.model_executor.pre_and_post_process import (pre_process, rebuild_padding) -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta from .base import Proposer diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 8d6ca79a1b..a40980a557 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -42,7 +42,7 @@ rebuild_padding, step_cuda) from fastdeploy.spec_decode import MTPProposer, NgramProposer -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput diff --git a/fastdeploy/worker/vl_gpu_model_runner.py b/fastdeploy/worker/vl_gpu_model_runner.py index f48cefe8f3..27be4fb2d6 100644 --- a/fastdeploy/worker/vl_gpu_model_runner.py +++ b/fastdeploy/worker/vl_gpu_model_runner.py @@ -41,7 +41,7 @@ from fastdeploy.model_executor.models.ernie4_5_vl.modeling_resampler import ( 
ScatterOp, VariableResolutionResamplerModel) from fastdeploy.platforms import current_platform -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.worker.utils import check_safetensors_model from fastdeploy.worker.vl_model_runner_base import VLModelRunnerBase from fastdeploy.config import (DeviceConfig, FDConfig, KVCacheConfig, diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index b075356f99..60164bc1ec 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -31,7 +31,7 @@ from fastdeploy.model_executor.layers.sample.sampler import Sampler from fastdeploy.model_executor.model_loader import get_model_from_loader from fastdeploy.utils import get_logger -from fastdeploy.worker.forward_meta import ForwardMeta, XPUForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta, XPUForwardMeta from fastdeploy.worker.model_runner_base import ModelRunnerBase from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput diff --git a/test/layers/test_attention.py b/test/layers/test_attention.py index 9d4b096798..6ed70932ec 100644 --- a/test/layers/test_attention.py +++ b/test/layers/test_attention.py @@ -21,7 +21,7 @@ from fastdeploy.model_executor.layers.attention import ( Attention, PaddleNativeAttnBackend) -from fastdeploy.worker.forward_meta import (ForwardMeta, ForwardMode, +from fastdeploy.model_executor.forward_meta import (ForwardMeta, ForwardMode, MHATokenToKVPool) diff --git a/test/worker/test_cuda_graph.py b/test/worker/test_cuda_graph.py index f00b129c5a..30c0dca1e3 100644 --- a/test/worker/test_cuda_graph.py +++ b/test/worker/test_cuda_graph.py @@ -18,7 +18,7 @@ from fastdeploy.config import FDConfig, GraphOptimizationConfig from fastdeploy.model_executor.graph_optimization.decorator import \ support_graph_optimization -from fastdeploy.worker.forward_meta import ForwardMeta +from fastdeploy.model_executor.forward_meta import ForwardMeta @support_graph_optimization From ac71beec357819c7fad69c685d3e41f1ffe5cc37 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Fri, 11 Jul 2025 13:18:45 +0800 Subject: [PATCH 2/4] fix conflict of forward_meta --- fastdeploy/model_executor/forward_meta.py | 358 +++------------------- 1 file changed, 43 insertions(+), 315 deletions(-) diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index f03ecc40d5..421f0a17f1 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -14,20 +14,16 @@ # limitations under the License. """ -import abc import logging from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, Annotated +from typing import TYPE_CHECKING, Optional, Annotated +from fastdeploy.model_executor.layers.attention import AttentionBackend from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import ( DynamicDims ) -from fastdeploy.model_executor.layers.attention import (Attention, - AttentionBackend) - -import numpy as np import paddle - + logger = logging.getLogger(__name__) @@ -36,333 +32,79 @@ class ForwardMode(IntEnum): """ Forward mode used during attention. 
""" - - # for prefill and extend + # Prefill and Extend mode EXTEND = auto() - # for generation + # Decode mode DECODE = auto() - + # Mixed mode MIXED = auto() def is_prefill(self): - """Whether it's a prefill forward""" + """ Is Extend mode """ return self == ForwardMode.EXTEND def is_decode(self): - """Whether it's a decode forward""" + """ Is Decode mode """ return self == ForwardMode.DECODE def is_mixed(self): - """Whether it's a decode forward""" + """ Is Mixed mode """ return self == ForwardMode.MIXED -class ReqToTokenPool: - """A memory pool that maps a request to its token locations.""" - - def __init__(self, size: int, max_context_len: int): - - self.size = size - self.max_context_len = max_context_len - self.req_to_token = paddle.zeros((size, max_context_len), - dtype=paddle.int32) - self.free_slots = list(range(size)) - - def write(self, indices, values): - """Write data into request buffer""" - self.req_to_token[indices] = values - - def available_size(self): - """Get number of slots left""" - return len(self.free_slots) - - def alloc(self, need_size: int) -> List[int]: - """Allocate `need_size` slots""" - if need_size > len(self.free_slots): - return None - - select_index = self.free_slots[:need_size] - self.free_slots = self.free_slots[need_size:] - - return select_index - - def free(self, free_index: Union[int, List[int]]): - """Free slot""" - if isinstance(free_index, (int, )): - self.free_slots.append(free_index) - else: - self.free_slots.extend(free_index) - - def clear(self): - """Clear all slots""" - self.free_slots = list(range(self.size)) - - -class KVCache(abc.ABC): - """Abstract base class representing a key value cache""" - - @abc.abstractmethod - def get_kv_buffer(self, - layer_id: int) -> Tuple[paddle.Tensor, paddle.Tensor]: - """ - Return cached keys and values given layer id. - Args: - layer_id: int - Returns: - tuple: (keys, values) - """ - raise NotImplementedError() - - @abc.abstractmethod - def set_kv_buffer( - self, - layer: 'Attention', - loc: paddle.Tensor, - cache_k: paddle.Tensor, - cache_v: paddle.Tensor, - ) -> None: - """ - Set cached keys and values given layer id. - Args: - layer: Attention - loc: paddle.Tensor - cache_k: paddle.Tensor - cache_v: paddle.Tensor - """ - raise NotImplementedError() - - @abc.abstractmethod - def transfer(self, indices, flat_data): - """Transfer kv_data between devices""" - raise NotImplementedError() - - @abc.abstractmethod - def transfer_per_layer(self, indices, flat_data, layer_id): - """Not used yet""" - raise NotImplementedError() - - def register_layer_transfer_counter(self, layer_transfer_counter): - """Not used yet""" - self.layer_transfer_counter = layer_transfer_counter - - -class MHATokenToKVPool(KVCache): - """Token To Key Value Pool for MultiHeadAttention""" - - def __init__( - self, - max_block_num: int, - block_size: int, - dtype: paddle.dtype, - head_num: int, - head_dim: int, - layer_num: int, - device: str, - ): - self.max_block_num = max_block_num - self.block_size = block_size - self.dtype = dtype - self.device = device - if dtype in (paddle.int8, paddle.float8_e4m3fn): - # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2 - self.store_dtype = paddle.uint8 - else: - self.store_dtype = dtype - - self.head_num = head_num - self.head_dim = head_dim - self.layer_num = layer_num - self._create_buffers() - - k_size, v_size = self.get_kv_size_bytes() - GB = 1024 * 1024 * 1024 - logger.info( - f"KV Cache is allocated. 
#tokens: {self.size}, K size: {k_size / GB:.2f} GB, V size: {v_size / GB:.2f} GB" - ) - - def _create_buffers(self): - # [size, head_num, head_dim] for each layer - # The padded slot 0 is used for writing dummy outputs from padded tokens. - self.k_buffer = [ - paddle.zeros( - (self.max_block_num, self.head_num, self.block_size, - self.head_dim), - dtype=self.store_dtype, - ) for _ in range(self.layer_num) - ] - self.v_buffer = [ - paddle.zeros( - (self.max_block_num, self.head_num, self.block_size, - self.head_dim), - dtype=self.store_dtype, - ) for _ in range(self.layer_num) - ] - - def _clear_buffers(self): - del self.k_buffer - del self.v_buffer - - def get_kv_size_bytes(self): - """for debugging purpose""" - assert hasattr(self, "k_buffer") - assert hasattr(self, "v_buffer") - k_size_bytes = 0 - for k_cache in self.k_buffer: - k_size_bytes += np.prod(k_cache.shape) * 4 - v_size_bytes = 0 - for v_cache in self.v_buffer: - v_size_bytes += np.prod(v_cache.shape) * 4 - return k_size_bytes, v_size_bytes - - def transfer(self, indices, flat_data): - # transfer prepared data from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - k_data, v_data = flat_data[0], flat_data[1] - for i in range(self.layer_num): - self.k_buffer[i][indices] = k_data[i] - self.v_buffer[i][indices] = v_data[i] - - def transfer_per_layer(self, indices, flat_data, layer_id): - # transfer prepared data for a specific layer from host to device - flat_data = flat_data.to(device=self.device, non_blocking=False) - k_data, v_data = flat_data[0], flat_data[1] - self.k_buffer[layer_id][indices] = k_data - self.v_buffer[layer_id][indices] = v_data - - def get_key_buffer(self, layer_id: int): - """Return cached keys given layer id.""" - if self.store_dtype != self.dtype: - return self.k_buffer[layer_id].view(self.dtype) - return self.k_buffer[layer_id] - - def get_value_buffer(self, layer_id: int): - """Return cached values given layer id.""" - if self.store_dtype != self.dtype: - return self.v_buffer[layer_id].view(self.dtype) - return self.v_buffer[layer_id] - - def get_kv_buffer(self, layer_id: int): - """Return cached keys and values given layer id.""" - return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) - - def set_kv_buffer( - self, - layer: 'Attention', - loc: paddle.Tensor, - cache_k: paddle.Tensor, - cache_v: paddle.Tensor, - k_scale: Optional[float] = None, - v_scale: Optional[float] = None, - ): - """Set cached keys and values given layer id.""" - layer_id = layer.layer_id - if cache_k.dtype != self.dtype: - if k_scale is not None: - cache_k.div_(k_scale) - if v_scale is not None: - cache_v.div_(v_scale) - cache_k = cache_k.to(self.dtype) - cache_v = cache_v.to(self.dtype) - - if self.store_dtype != self.dtype: - cache_k = cache_k.view(self.store_dtype) - cache_v = cache_v.view(self.store_dtype) - - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v - - @dataclass class ForwardMeta(): """ - ForwardMeta is used to store the global meta information of the forward. + ForwardMeta is used to store the global meta information of the model forward. """ - # + # Input tokens IDs input_ids: paddle.Tensor + # Input tokens IDs of removed padding + ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)] + # Rotation position embedding + rotary_embs: Optional[paddle.Tensor] = None - #attention meta - forward_mode: ForwardMode = ForwardMode.MIXED + # Use cuda graph in this step or not. Used to avoid run cuda graph when in dummy run or prefill stage. 
+ step_use_cudagraph: bool = False + # Batch type flag + is_decode_batch: bool = False - # - ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)] = None + # Attention backend object + attn_backend: AttentionBackend = None + # Forward mode used during attention + forward_mode: ForwardMode = ForwardMode.MIXED + # Attention mask + attn_mask: Optional[paddle.Tensor] = None + # Decoder batch id. Used by attention backend. + decoder_batch_ids: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None + # Tile ID for each batch of the decoder. Used by attention backend. + decoder_tile_ids_per_batch: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None - # + # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None - - # + # Sequence length of Encoder for ever batch seq_lens_decoder: Optional[paddle.Tensor] = None - - # + # The sequence length processed in the current step seq_lens_this_time: Optional[paddle.Tensor] = None - # + # Accumulated offset cum_offsets: Optional[paddle.Tensor] = None - - # - block_tables: Optional[paddle.Tensor] = None - - # - attn_backend: AttentionBackend = None - - # - rotary_embs: Optional[paddle.Tensor] = None - - # + # Offset tensor, used to restore the position of ids_remove_madding after padding removal to the original input_ids padding_offset: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None - - # + # Accumulated sequence length of query cu_seqlens_q: Optional[paddle.Tensor] = None - - # + # Accumulated sequence length of key cu_seqlens_k: Optional[paddle.Tensor] = None - # - caches: Optional[list[Annotated[paddle.Tensor, DynamicDims(0)]]] = None - - # - attn_mask: Optional[paddle.Tensor] = None - - # + # Pre-cache length pre_caches_length: int = 0 + # Block tables + block_tables: Optional[paddle.Tensor] = None + # KV caches + caches: Optional[list[Annotated[paddle.Tensor, DynamicDims(0)]]] = None - # Use cuda graph in this step. Used to avoid run cuda graph when in dummy run or prefill stage. 
- step_use_cudagraph: bool = False - - # for attention backend - decoder_batch_ids: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None - # for attention backend - decoder_tile_ids_per_batch: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None - # is_decode_batch or not - is_decode_batch: bool = False - - @classmethod - def init_forward_meta(cls, share_inputs: Dict, - attn_backend: "AttentionBackend"): - """ init forward meta """ - # TODO(gongshaotian): delete this func - ret = cls( - forward_mode=ForwardMode.MIXED, - input_ids=share_inputs["input_ids"], - ids_remove_padding=share_inputs["ids_remove_padding"], - seq_lens_encoder=share_inputs["seq_lens_encoder"], - seq_lens_decoder=share_inputs["seq_lens_decoder"], - seq_lens_this_time=share_inputs["seq_lens_this_time"], - cum_offsets=share_inputs["cum_offsets"], - block_tables=share_inputs["block_tables"], - attn_backend=attn_backend, - rotary_embs=share_inputs["rope_emb"], - padding_offset=share_inputs["padding_offset"], - cu_seqlens_q=share_inputs["cu_seqlens_q"], - cu_seqlens_k=share_inputs["cu_seqlens_k"], - caches=share_inputs["caches"], - decoder_batch_ids=share_inputs.get("decoder_batch_ids", None), - decoder_tile_ids_per_batch=share_inputs.get( - "decoder_tile_ids_per_batch", None), - ) - return ret - def clear_caches(self): - """safe clear caches""" + """ Safely clean up the caches """ if self.caches: del self.caches @@ -372,56 +114,42 @@ class XPUForwardMeta(ForwardMeta): """ XPUForwardMeta is used to store the global meta information of the forward, and some XPU specific meta info. """ + # TODO(wanghaitao): Supplementary notes # encoder_batch_map: Optional[paddle.Tensor] = None - # decoder_batch_map: Optional[paddle.Tensor] = None - # encoder_batch_idx: Optional[paddle.Tensor] = None - # decoder_batch_idx: Optional[paddle.Tensor] = None - # encoder_seq_lod: Optional[paddle.Tensor] = None - # decoder_context_len: Optional[paddle.Tensor] = None - # decoder_context_len_cache: Optional[paddle.Tensor] = None # encoder_batch_map_cpu: Optional[paddle.Tensor] = None - # decoder_batch_map_cpu: Optional[paddle.Tensor] = None - # encoder_batch_idx_cpu: Optional[paddle.Tensor] = None - # decoder_batch_idx_cpu: Optional[paddle.Tensor] = None - # encoder_seq_lod_cpu: Optional[paddle.Tensor] = None - # decoder_context_len_cpu: Optional[paddle.Tensor] = None - # decoder_context_len_cache_cpu: Optional[paddle.Tensor] = None # batch_tensor: Optional[paddle.Tensor] = None - # enc_batch: Optional[paddle.Tensor] = None - # dec_batch: Optional[paddle.Tensor] = None - # total_enc_len: Optional[paddle.Tensor] = None From 5002ccc85910c063e156b24876eee807dcecfcea Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 16 Jul 2025 10:32:52 +0800 Subject: [PATCH 3/4] mark more attn backend --- .../attention/block_multihead_attn_backend.py | 25 +++++++++++-------- .../gcu/attention/flash_attn_backend.py | 3 +++ 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py index 5d48f54779..1b37d500eb 100644 --- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Annotated import paddle @@ -29,6 +29,8 @@ from 
fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims @@ -38,15 +40,15 @@ class BlockAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -68,6 +70,9 @@ class BlockAttentionBackend(AttentionBackend): BlockAttentionBackend backend implementation. """ + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: BlockAttentionMetadata + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py index 00032e26fd..bfc0e4f721 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py @@ -70,6 +70,9 @@ class GCUFlashAttnBackend(AttentionBackend): GCUFlashAttnBackend backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: GCUFlashAttnBackend + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ From 00a57f60257090787c4c7d9f2fb08ae5e0eb8756 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Wed, 16 Jul 2025 15:38:48 +0800 Subject: [PATCH 4/4] fix missing annotated and add env SOT_SPECIALIZED_DIM_NUMBERS --- fastdeploy/engine/engine.py | 2 ++ fastdeploy/model_executor/forward_meta.py | 1 + 2 files changed, 3 insertions(+) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 1f9dc9278e..af9f088e2f 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -948,6 +948,8 @@ def _setting_environ_variables(self): os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"), "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"), + "SOT_SPECIALIZED_DIM_NUMBERS": + os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"), "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"), "FLAGS_enable_async_fast_gc": diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index c939d670af..fbe3c5aa8f 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -17,6 +17,7 @@ import logging from dataclasses import dataclass from enum import IntEnum, auto +from typing import Annotated, Optional import paddle