diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index 1f9dc9278e..af9f088e2f 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -948,6 +948,8 @@ def _setting_environ_variables(self): os.getenv("SOT_UNSAFE_CACHE_FASTPATH", default="1"), "SOT_ENABLE_0_SIZE_FALLBACK": os.getenv("SOT_ENABLE_0_SIZE_FALLBACK", default="0"), + "SOT_SPECIALIZED_DIM_NUMBERS": + os.getenv("SOT_SPECIALIZED_DIM_NUMBERS", default="no"), "FLAGS_specialize_device_in_dy2st": os.getenv("FLAGS_specialize_device_in_dy2st", default="1"), "FLAGS_enable_async_fast_gc": diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index 17ab2e9ade..fbe3c5aa8f 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -17,11 +17,14 @@ import logging from dataclasses import dataclass from enum import IntEnum, auto -from typing import Optional +from typing import Annotated, Optional import paddle from fastdeploy.model_executor.layers.attention import AttentionBackend +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import ( + DynamicDims +) logger = logging.getLogger(__name__) @@ -58,7 +61,7 @@ class ForwardMeta(): # Input tokens IDs input_ids: paddle.Tensor # Input tokens IDs of removed padding - ids_remove_padding: paddle.Tensor + ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)] # Rotation position embedding rotary_embs: Optional[paddle.Tensor] = None @@ -72,9 +75,9 @@ class ForwardMeta(): # Attention mask attn_mask: Optional[paddle.Tensor] = None # Decoder batch id. Used by attention backend. - decoder_batch_ids: Optional[paddle.Tensor] = None + decoder_batch_ids: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None # Tile ID for each batch of the decoder. Used by attention backend. 
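The ForwardMeta changes above rely on typing.Annotated carrying DynamicDims markers inside the dataclass type hints. Below is a minimal, self-contained sketch of that pattern; the Meta dataclass, the list placeholder standing in for paddle.Tensor, and the locally re-declared DynamicDims are illustrative stand-ins, not the FastDeploy definitions. The point is that the marker survives in get_type_hints(..., include_extras=True) and can be read back later by a resolver.

from dataclasses import dataclass
from typing import Annotated, Optional, get_args, get_origin, get_type_hints


class DynamicDims:
    # Marker metadata: which dims of a tensor may change between calls.
    def __init__(self, dims):
        self.dims = dims if isinstance(dims, tuple) else (dims,)


@dataclass
class Meta:
    # dim 0 (the flattened token dimension) is declared dynamic,
    # mirroring ids_remove_padding in ForwardMeta above.
    ids_remove_padding: Annotated[list, DynamicDims(0)]
    rotary_embs: Optional[list] = None  # unannotated fields are left alone


hints = get_type_hints(Meta, include_extras=True)
for name, tp in hints.items():
    if get_origin(tp) is Annotated:
        _, *extras = get_args(tp)
        dims = [m.dims for m in extras if isinstance(m, DynamicDims)]
        print(name, "->", dims)  # ids_remove_padding -> [(0,)]

Because the marker is ordinary metadata, annotating a field costs nothing at runtime until something like resolve_dynamic_dims inspects it.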
- decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None + decoder_tile_ids_per_batch: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None # Sequence length of encoder for ever batch seq_lens_encoder: Optional[paddle.Tensor] = None @@ -86,7 +89,7 @@ class ForwardMeta(): # Accumulated offset cum_offsets: Optional[paddle.Tensor] = None # Offset tensor, used to restore the position of ids_remove_madding after padding removal to the original input_ids - padding_offset: Optional[paddle.Tensor] = None + padding_offset: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None # Accumulated sequence length of query cu_seqlens_q: Optional[paddle.Tensor] = None # Accumulated sequence length of key @@ -97,7 +100,7 @@ class ForwardMeta(): # Block tables block_tables: Optional[paddle.Tensor] = None # KV caches - caches: Optional[list[paddle.Tensor]] = None + caches: Optional[list[Annotated[paddle.Tensor, DynamicDims(0)]]] = None def clear_caches(self): """ Safely clean up the caches """ diff --git a/fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py b/fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py new file mode 100644 index 0000000000..b9d37b2934 --- /dev/null +++ b/fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +import dataclasses +import inspect +import typing +from abc import ABC, abstractmethod +from collections.abc import Callable +from functools import partial +from typing import (Annotated, Any, Optional, TypeVar, Union, get_origin, + get_type_hints) + +import paddle +from paddle import Tensor +from typing_extensions import TypeAlias + +T = TypeVar("T") +U = TypeVar("U") + +Accessor: TypeAlias = Callable[[T], U] + + +class DynamicDims: + def __init__(self, dims: int | tuple[int]): + self.dims = dims if isinstance(dims, tuple) else (dims,) + + def __repr__(self): + return f"DynamicDims({self.dims})" + + +class DynamicDimTypeResolver: + ALL_DYNAMIC_DIM_TYPE_RESOLVERS = [] + + @classmethod + def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]): + cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls()) + return resolver_cls + + @abstractmethod + def check(self, tp) -> bool: + raise NotImplementedError + + @abstractmethod + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + raise NotImplementedError + + def resolve(self, data, data_name, tp) -> None: + inner_types = self.extract_inner_types(data, data_name, tp) + for accessor, inner_data_name, inner_type in inner_types: + self.generic_resolve(accessor(data), inner_data_name, inner_type) + + def generic_resolve(self, data, data_name, tp) -> None: + # assert isinstance(data, tp), f"Expected {data_name} has type {tp}, but got {type(data)}" + for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS: + if resolver.check(tp): + return resolver.resolve(data, data_name, tp) + runtime_tp = type(data) + if runtime_tp is not tp and resolver.check(runtime_tp): + return resolver.resolve(data, data_name, runtime_tp) + else: + print(f"No resolver found for type {tp} and data {data_name}") + + +@DynamicDimTypeResolver.register_resolver +class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return dataclasses.is_dataclass(tp) and isinstance(tp, type) + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + type_hints = get_type_hints(tp, include_extras=True) + return [ # 
type: ignore + ( + # bind name by partial to avoid capture wrong free vars + partial(lambda name, dt: getattr(dt, name), field.name), + f"{data_name}.{field.name}", + type_hints[field.name], + ) + for field in dataclasses.fields(tp) + ] + + +@DynamicDimTypeResolver.register_resolver +class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return ( + get_origin(tp) is Union + and len(tp.__args__) == 2 + and tp.__args__[1] is type(None) + ) + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + if data is None: + return [] + inner_type = tp.__args__[0] + return [(lambda x: x, data_name, inner_type)] # No accessor needed for Optional + + +@DynamicDimTypeResolver.register_resolver +class ListDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return get_origin(tp) is list + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + if not data: + return [] + inner_type = typing.get_args(tp)[0] if tp.__args__ else Any + return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))] # type: ignore + +@DynamicDimTypeResolver.register_resolver +class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver): + INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__" + def check(self, tp) -> bool: + return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME) + + def extract_inner_types( + self, data, data_name, tp + ) -> list[tuple[Accessor[Any, Any], str, type[Any]]]: + fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME) + if isinstance(fields, str): + raise TypeError(f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be tuple, but got {type(fields)}") + inner_types_dict = typing.get_type_hints(tp) + return [(partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type) for field_name, inner_type in inner_types_dict.items()] + +@DynamicDimTypeResolver.register_resolver +class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver): + def check(self, tp) -> bool: + return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor + + def resolve(self, data, data_name, tp) -> None: + base_type, *metadata = typing.get_args(tp) + # Filter out DynamicDims instances + dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)] + if not dynamic_dims: + return + if len(dynamic_dims) > 1: + raise ValueError( + "Multiple DynamicDims annotations found. Only one is allowed." 
+ ) + dynamic_dims = dynamic_dims[0].dims + if not isinstance(data, Tensor): + raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}") + print(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}") + paddle.jit.marker.dynamic_dims( + data, dynamic_dims + ) + +def resolve_dynamic_dims(arg, arg_name, annotation): + DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation) diff --git a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py index 9ce6f73729..87f1ed7ab1 100644 --- a/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py +++ b/fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py @@ -14,13 +14,66 @@ # limitations under the License. """ -from typing import Callable, Optional - -from paddle.jit.dy2static.utils import Backend +import functools +import inspect +import types +from typing import Callable, Optional, get_type_hints +import paddle from fastdeploy.config import FDConfig from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \ CudaGraphPiecewiseBackend +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + resolve_dynamic_dims +from paddle.jit import sot +from paddle.jit.dy2static.utils import Backend + + +def apply_to_static_optimization(fn): + forward_fn = fn + forward_sig = inspect.signature(forward_fn) + # forward_annotations = inspect.get_annotations(forward_fn) + forward_type_hints = get_type_hints(forward_fn) + static_forward_fn = sot.symbolic_translate( + forward_fn, training=False, backend=Backend.PHI + ) + unsafe_static_forward_fn = None + + @functools.wraps(forward_fn) + def static_forward(self, *args, **kwargs): + nonlocal unsafe_static_forward_fn + if unsafe_static_forward_fn is not None: + return unsafe_static_forward_fn(self, *args, **kwargs) + bound_args = forward_sig.bind(self, *args, **kwargs) + bound_args.apply_defaults() + for name, arg in bound_args.arguments.items(): + if name not in forward_type_hints: + continue + annotation = forward_type_hints[name] + resolve_dynamic_dims(arg, name, annotation) + + # print(f"Processing argument '{name}' with annotation: {annotation}") + # if isinstance(arg, paddle.Tensor): + # print( + # f"Argument '{name}' is a Tensor with dynamic dims: {extract_dynamic_dims(annotation)}" + # ) + # paddle.jit.marker.dynamic_dims( + # arg, extract_dynamic_dims(annotation) + # ) + result = static_forward_fn(self, *args, **kwargs) + original_code = forward_fn.__code__ + (new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[original_code] + new_code = new_guarded_codes[0][0][0] + unsafe_static_forward_fn = types.FunctionType( + new_code, + forward_fn.__globals__, + forward_fn.__name__, + forward_fn.__defaults__, + forward_fn.__closure__, + ) + return result + + return static_forward class GraphOptBackend: @@ -46,9 +99,10 @@ def __init__(self, runnable: Callable, fd_config: FDConfig): backend = (Backend.CINN if self.fd_config.graph_opt_config.graph_opt_level > 1 else Backend.PHI) - self.runnable = sot.symbolic_translate(self.runnable, - training=False, - backend=backend) + # self.runnable = sot.symbolic_translate(self.runnable, + # training=False, + # backend=backend) + self.runnable = apply_to_static_optimization(self.runnable.__func__).__get__(self.runnable.__self__) def __call__(self, **kwargs): if not 
self.fd_config.graph_opt_config.use_cudagraph: diff --git a/fastdeploy/model_executor/layers/attention/__init__.py b/fastdeploy/model_executor/layers/attention/__init__.py index 5557616f0a..078cd2282b 100644 --- a/fastdeploy/model_executor/layers/attention/__init__.py +++ b/fastdeploy/model_executor/layers/attention/__init__.py @@ -19,6 +19,7 @@ from .mla_attention_backend import MLAAttentionBackend from .native_paddle_backend import PaddleNativeAttnBackend from .xpu_attn_backend import XPUAttentionBackend +from .attention import Attention from .iluvatar_attn_backend import IluvatarAttnBackend from .block_multihead_attn_backend import BlockAttentionBackend @@ -26,5 +27,5 @@ "AttentionBackend", "PaddleNativeAttnBackend", "get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend", "MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend", - "BlockAttentionBackend" + "BlockAttentionBackend", "Attention" ] diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 9f57f41791..42348bed69 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -18,10 +18,9 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Annotated, List, Optional, Tuple import paddle - from fastdeploy.model_executor.layers.attention.ops import ( append_attention, get_block_shape_and_split_kv_block, init_signal_layerwise, open_shm_and_get_meta_signal) @@ -30,6 +29,8 @@ from fastdeploy.model_executor.forward_meta import ForwardMeta from fastdeploy.config import FDConfig +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) @@ -44,15 +45,15 @@ class AppendAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -74,6 +75,9 @@ class AppendAttentionBackend(AttentionBackend): AppendAttentionBackend backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: AppendAttentionMetadata + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int) -> None: """ @@ -110,7 +114,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, self.rank, self.device_id = init_rank_and_device_id(fd_config) - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = AppendAttentionMetadata() metadata.encoder_block_shape_q = 64 @@ -185,7 +189,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ forward_mixed diff --git a/fastdeploy/model_executor/layers/attention/attention.py b/fastdeploy/model_executor/layers/attention/attention.py index 457e5d5215..b740583079 100644 --- a/fastdeploy/model_executor/layers/attention/attention.py +++ b/fastdeploy/model_executor/layers/attention/attention.py @@ -20,12 +20,12 @@ import numpy as np import paddle -from paddle import nn -from paddleformers.utils.log import logger - from fastdeploy.config import FDConfig from fastdeploy.model_executor.layers.quantization.quant_base import \ QuantMethodBase +from paddle import nn +from paddleformers.utils.log import logger + if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -116,7 +116,7 @@ def forward( qkv: paddle.Tensor = None, compressed_kv: paddle.Tensor = None, k_pe: paddle.Tensor = None, - forward_meta: ForwardMeta = None, + forward_meta: "ForwardMeta" = None, ) -> paddle.Tensor: """ The forward function of attention layer. diff --git a/fastdeploy/model_executor/layers/attention/base_attention_backend.py b/fastdeploy/model_executor/layers/attention/base_attention_backend.py index 4a442e5c34..7c0238dc5e 100644 --- a/fastdeploy/model_executor/layers/attention/base_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/base_attention_backend.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING import paddle + if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -37,7 +38,7 @@ class AttentionBackend(ABC): """The base class of attention backends""" @abstractmethod - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize the forward metadata.""" raise NotImplementedError() @@ -50,7 +51,7 @@ def forward( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Run a forward. 
@@ -106,7 +107,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """Run a forward for mix.""" raise NotImplementedError() @@ -120,7 +121,7 @@ def forward_decode( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """Run a forward for decode.""" raise NotImplementedError() @@ -134,7 +135,7 @@ def forward_extend( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """Run a forward for extend.""" raise NotImplementedError() diff --git a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py index 5d48f54779..1b37d500eb 100644 --- a/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional +from typing import TYPE_CHECKING, List, Optional, Annotated import paddle @@ -29,6 +29,8 @@ from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims @dataclass @@ -38,15 +40,15 @@ class BlockAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -68,6 +70,9 @@ class BlockAttentionBackend(AttentionBackend): BlockAttentionBackend backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: BlockAttentionBackend + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index d78b444d21..208876fc2d 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field -from typing import List, Optional, TYPE_CHECKING +from typing import List, Optional, Annotated, TYPE_CHECKING import paddle @@ -31,10 +31,13 @@ from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims from fastdeploy.model_executor.layers.attention.ops import ( get_block_shape_and_split_kv_block, gqa_rope_write_cache, init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat) from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id + if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -48,15 +51,15 @@ class FlashAttentionMetadata(AttentionMetadata): set_max_lengths: int = -1 rotary_embs: Optional[paddle.Tensor] = None block_tables: Optional[paddle.Tensor] = None - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None encoder_block_shape_q: Optional[paddle.Tensor] = None decoder_block_shape_q: Optional[paddle.Tensor] = None @@ -81,6 +84,9 @@ class FlashAttentionBackend(AttentionBackend): FlashAttentionBackend backend implementation """ + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: FlashAttentionMetadata + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ @@ -127,7 +133,7 @@ def get_kv_cache_shape( return (max_num_blocks, self.kv_num_heads, self.block_size, self.head_dim) - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): metadata = FlashAttentionMetadata() metadata.encoder_block_shape_q = 64 metadata.decoder_block_shape_q = 16 @@ -190,7 +196,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ): 
metadata = self.attention_metadata diff --git a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py index a29d5fe68f..8651ab9ba8 100644 --- a/fastdeploy/model_executor/layers/attention/mla_attention_backend.py +++ b/fastdeploy/model_executor/layers/attention/mla_attention_backend.py @@ -19,7 +19,7 @@ import math import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Annotated import paddle from paddle.nn.functional.flash_attention import flash_attn_unpadded @@ -41,6 +41,8 @@ from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id @@ -59,15 +61,15 @@ class MLAAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -89,6 +91,10 @@ class MLAAttentionBackend(AttentionBackend): MLA Attention Backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: MLAAttentionMetadata + + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int) -> None: """ @@ -137,7 +143,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, self.rank, self.device_id = init_rank_and_device_id(fd_config) - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize attention metadata hence all layers in the forward pass can reuse it.""" metadata = MLAAttentionMetadata() metadata.encoder_block_shape_q = 64 @@ -216,7 +222,7 @@ def forward_extend( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Prefill阶段的前向传播 @@ -271,7 +277,7 @@ def forward_decode( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Decode阶段的前向传播 @@ -367,7 +373,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Mixed模式的前向传播 diff --git a/fastdeploy/model_executor/layers/attention/native_paddle_backend.py b/fastdeploy/model_executor/layers/attention/native_paddle_backend.py index b8f5db6a1d..b286bd98fa 100644 --- a/fastdeploy/model_executor/layers/attention/native_paddle_backend.py +++ b/fastdeploy/model_executor/layers/attention/native_paddle_backend.py @@ -18,11 +18,13 @@ from __future__ import annotations from typing import TYPE_CHECKING + import paddle from paddle.nn.functional import scaled_dot_product_attention from fastdeploy.model_executor.layers.attention.base_attention_backend import \ AttentionBackend + if TYPE_CHECKING: from fastdeploy.model_executor.forward_meta import ForwardMeta @@ -36,7 +38,7 @@ class PaddleNativeAttnBackend(AttentionBackend): def __init__(self) -> None: super().__init__() - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Init the metadata for a forward pass.""" pass @@ -214,7 +216,7 @@ def forward_extend( k: paddle.Tensor, v: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", save_kv_cache: bool = True, ) -> paddle.Tensor: """ @@ -255,7 +257,7 @@ def forward_decode( k: paddle.Tensor, v: paddle.Tensor, layer: paddle.nn.Layer, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ Run the decoding attention forward by using paddle native sdpa op. 
diff --git a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py index 6c3cade149..8b6d6f3e05 100644 --- a/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/xpu_attn_backend.py @@ -18,7 +18,7 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple, Annotated import paddle @@ -32,7 +32,8 @@ from fastdeploy.model_executor.layers.attention.attention import Attention from fastdeploy.model_executor.layers.attention.base_attention_backend import ( AttentionBackend, AttentionMetadata) - +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \ + DynamicDims @dataclass class XPUAttentionMetadata(AttentionMetadata): @@ -41,15 +42,15 @@ class XPUAttentionMetadata(AttentionMetadata): """ max_len_kv: paddle.Tensor = None set_max_lengths: int = -1 - encoder_batch_ids: paddle.Tensor = None - encoder_tile_ids_per_batch: paddle.Tensor = None - encoder_num_blocks: paddle.Tensor = None - kv_batch_ids: paddle.Tensor = None - kv_tile_ids_per_batch: paddle.Tensor = None - kv_num_blocks: paddle.Tensor = None - decoder_batch_ids: paddle.Tensor = None - decoder_tile_ids_per_batch: paddle.Tensor = None - decoder_num_blocks: paddle.Tensor = None + encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None + decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None _dtype: paddle.dtype = paddle.bfloat16 encoder_max_partition_size: int = 32768 @@ -71,6 +72,9 @@ class XPUAttentionBackend(AttentionBackend): XPUAttentionBackend backend implementation. 
""" + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: XPUAttentionMetadata + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ @@ -102,7 +106,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, os.getenv("FLAGS_use_pd_disaggregation", 0)) self.start_layer_index: int = fd_config.model_config.start_layer_index - def init_attention_metadata(self, forward_meta: ForwardMeta): + def init_attention_metadata(self, forward_meta: "ForwardMeta"): """Initialize attntion metadata hence all layers in the forward pass can reuse it.""" metadata = XPUAttentionMetadata() metadata.encoder_block_shape_q = 64 @@ -151,7 +155,7 @@ def forward_mixed( compressed_kv: paddle.Tensor, k_pe: paddle.Tensor, layer: Attention, - forward_meta: ForwardMeta, + forward_meta: "ForwardMeta", ) -> paddle.Tensor: """ forward_mixed diff --git a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py index 00032e26fd..bfc0e4f721 100644 --- a/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/backends/gcu/attention/flash_attn_backend.py @@ -70,6 +70,9 @@ class GCUFlashAttnBackend(AttentionBackend): GCUFlashAttnBackend backend implementation. """ + __infer_dynamic_dims_fields__ = ["attention_metadata"] + attention_metadata: GCUFlashAttnBackend + def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int, head_dim: int): """ diff --git a/fastdeploy/model_executor/models/ernie4_5_moe.py b/fastdeploy/model_executor/models/ernie4_5_moe.py index 3c8e0d8e5b..15337c0209 100644 --- a/fastdeploy/model_executor/models/ernie4_5_moe.py +++ b/fastdeploy/model_executor/models/ernie4_5_moe.py @@ -17,7 +17,7 @@ from __future__ import annotations from functools import partial -from typing import Dict, Union +from typing import Dict, Union, Annotated import numpy as np import paddle @@ -43,7 +43,9 @@ from fastdeploy.model_executor.models.utils import \ LayerIdPlaceholder as layerid from fastdeploy.model_executor.models.utils import WeightMeta - +from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import ( + DynamicDims +) class Ernie4_5_MLP(nn.Layer): @@ -388,7 +390,7 @@ def load_state_dict(self, state_dict): def forward( self, - ids_remove_padding: paddle.Tensor, + ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)], forward_meta: ForwardMeta, ): hidden_states = self.embed_tokens(ids_remove_padding=ids_remove_padding)