[SOT] Mark dynamic dims by type annotations #2771

Draft: wants to merge 3 commits into base: develop
16 changes: 9 additions & 7 deletions fastdeploy/model_executor/forward_meta.py
@@ -17,9 +17,11 @@
import logging
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Annotated
from fastdeploy.model_executor.layers.attention import AttentionBackend

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
DynamicDims
)
import paddle


@@ -58,7 +60,7 @@ class ForwardMeta():
# Input token IDs
input_ids: paddle.Tensor
# Input token IDs with padding removed
ids_remove_padding: paddle.Tensor
ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)]
# Rotary position embedding
rotary_embs: Optional[paddle.Tensor] = None

@@ -74,9 +76,9 @@ class ForwardMeta():
# Attention mask
attn_mask: Optional[paddle.Tensor] = None
# Decoder batch id. Used by attention backend.
decoder_batch_ids: Optional[paddle.Tensor] = None
decoder_batch_ids: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None
# Tile ID for each batch of the decoder. Used by attention backend.
decoder_tile_ids_per_batch: Optional[paddle.Tensor] = None
decoder_tile_ids_per_batch: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None

# Sequence length of the encoder for every batch
seq_lens_encoder: Optional[paddle.Tensor] = None
@@ -88,7 +90,7 @@ class ForwardMeta():
# Accumulated offset
cum_offsets: Optional[paddle.Tensor] = None
# Offset tensor, used to map positions in ids_remove_padding back to the original input_ids after padding removal
padding_offset: Optional[paddle.Tensor] = None
padding_offset: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None
# Accumulated sequence length of query
cu_seqlens_q: Optional[paddle.Tensor] = None
# Accumulated sequence length of key
@@ -99,7 +101,7 @@ class ForwardMeta():
# Block tables
block_tables: Optional[paddle.Tensor] = None
# KV caches
caches: Optional[list[paddle.Tensor]] = None
caches: Optional[list[Annotated[paddle.Tensor, DynamicDims(0)]]] = None

def clear_caches(self):
""" Safely clean up the caches """
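The pattern above attaches dynamic-dimension metadata to selected ForwardMeta fields through typing.Annotated, so the runtime type of each field is still paddle.Tensor and existing callers are unaffected; only the resolver added by this PR reads the metadata. A minimal sketch of the same pattern on a hypothetical dataclass (ToyMeta is illustrative, not part of this PR):

from dataclasses import dataclass
from typing import Annotated, Optional

import paddle

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import DynamicDims


@dataclass
class ToyMeta:
    # Dim 0 (the flattened token dimension) is declared dynamic via the annotation metadata.
    ids_remove_padding: Annotated[paddle.Tensor, DynamicDims(0)]
    # Fields without DynamicDims metadata are left untouched by the resolver.
    rotary_embs: Optional[paddle.Tensor] = None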
155 changes: 155 additions & 0 deletions fastdeploy/model_executor/graph_optimization/dynamic_dims_marker.py
@@ -0,0 +1,155 @@
from __future__ import annotations

import dataclasses
import inspect
import typing
from abc import ABC, abstractmethod
from collections.abc import Callable
from functools import partial
from typing import (Annotated, Any, Optional, TypeVar, Union, get_origin,
get_type_hints)

import paddle
from paddle import Tensor
from typing_extensions import TypeAlias

T = TypeVar("T")
U = TypeVar("U")

Accessor: TypeAlias = Callable[[T], U]


class DynamicDims:
def __init__(self, dims: int | tuple[int, ...]):
self.dims = dims if isinstance(dims, tuple) else (dims,)

def __repr__(self):
return f"DynamicDims({self.dims})"


class DynamicDimTypeResolver:
ALL_DYNAMIC_DIM_TYPE_RESOLVERS = []

@classmethod
def register_resolver(cls, resolver_cls: type[DynamicDimTypeResolver]):
cls.ALL_DYNAMIC_DIM_TYPE_RESOLVERS.append(resolver_cls())
return resolver_cls

@abstractmethod
def check(self, tp) -> bool:
raise NotImplementedError

@abstractmethod
def extract_inner_types(
self, data, data_name, tp
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
raise NotImplementedError

def resolve(self, data, data_name, tp) -> None:
inner_types = self.extract_inner_types(data, data_name, tp)
for accessor, inner_data_name, inner_type in inner_types:
self.generic_resolve(accessor(data), inner_data_name, inner_type)

def generic_resolve(self, data, data_name, tp) -> None:
# assert isinstance(data, tp), f"Expected {data_name} has type {tp}, but got {type(data)}"
for resolver in self.ALL_DYNAMIC_DIM_TYPE_RESOLVERS:
if resolver.check(tp):
return resolver.resolve(data, data_name, tp)
runtime_tp = type(data)
if runtime_tp is not tp and resolver.check(runtime_tp):
return resolver.resolve(data, data_name, runtime_tp)
else:
print(f"No resolver found for type {tp} and data {data_name}")


@DynamicDimTypeResolver.register_resolver
class DataClassDynamicDimTypeResolver(DynamicDimTypeResolver):
def check(self, tp) -> bool:
return dataclasses.is_dataclass(tp) and isinstance(tp, type)

def extract_inner_types(
self, data, data_name, tp
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
type_hints = get_type_hints(tp, include_extras=True)
return [ # type: ignore
(
# bind name by partial to avoid capture wrong free vars
partial(lambda name, dt: getattr(dt, name), field.name),
f"{data_name}.{field.name}",
type_hints[field.name],
)
for field in dataclasses.fields(tp)
]


@DynamicDimTypeResolver.register_resolver
class OptionalDynamicDimTypeResolver(DynamicDimTypeResolver):
def check(self, tp) -> bool:
return (
get_origin(tp) is Union
and len(tp.__args__) == 2
and tp.__args__[1] is type(None)
)

def extract_inner_types(
self, data, data_name, tp
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
if data is None:
return []
inner_type = tp.__args__[0]
return [(lambda x: x, data_name, inner_type)] # No accessor needed for Optional


@DynamicDimTypeResolver.register_resolver
class ListDynamicDimTypeResolver(DynamicDimTypeResolver):
def check(self, tp) -> bool:
return get_origin(tp) is list

def extract_inner_types(
self, data, data_name, tp
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
if not data:
return []
inner_type = typing.get_args(tp)[0] if tp.__args__ else Any
return [(partial(lambda i, x: x[i], i), f"{data_name}[{i}]", inner_type) for i in range(len(data))] # type: ignore

@DynamicDimTypeResolver.register_resolver
class ManualMarkedInnerFieldsDynamicDimTypeResolver(DynamicDimTypeResolver):
INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME = "__infer_dynamic_dims_fields__"
def check(self, tp) -> bool:
return hasattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)

def extract_inner_types(
self, data, data_name, tp
) -> list[tuple[Accessor[Any, Any], str, type[Any]]]:
fields = getattr(tp, ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME)
if isinstance(fields, str):
raise TypeError(f"{ManualMarkedInnerFieldsDynamicDimTypeResolver.INFER_DYNAMIC_DIMS_FIELDS_ATTR_NAME} should be a tuple or list of field names, not a bare string")
inner_types_dict = typing.get_type_hints(tp)
return [(partial(lambda name, x: getattr(x, name), field_name), f"{data_name}.{field_name}", inner_type) for field_name, inner_type in inner_types_dict.items()]

@DynamicDimTypeResolver.register_resolver
class AnnotatedTensorDynamicDimTypeResolver(DynamicDimTypeResolver):
def check(self, tp) -> bool:
return get_origin(tp) is Annotated and typing.get_args(tp)[0] is Tensor

def resolve(self, data, data_name, tp) -> None:
base_type, *metadata = typing.get_args(tp)
# Filter out DynamicDims instances
dynamic_dims = [m for m in metadata if isinstance(m, DynamicDims)]
if not dynamic_dims:
return
if len(dynamic_dims) > 1:
raise ValueError(
"Multiple DynamicDims annotations found. Only one is allowed."
)
dynamic_dims = dynamic_dims[0].dims
if not isinstance(data, Tensor):
raise TypeError(f"data {data_name} has type annotation Tensor but got type {type(data)}")
print(f"data {data_name} has dynamic dims {dynamic_dims} for type {tp}")
paddle.jit.marker.dynamic_dims(
data, dynamic_dims
)

def resolve_dynamic_dims(arg, arg_name, annotation):
DynamicDimTypeResolver().generic_resolve(arg, arg_name, annotation)
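A hedged usage sketch of the resolver chain above, assuming the module is importable as written (ToyMeta and the tensor shapes are illustrative): DataClassDynamicDimTypeResolver walks the dataclass fields, OptionalDynamicDimTypeResolver skips the field whose value is None, and AnnotatedTensorDynamicDimTypeResolver finally calls paddle.jit.marker.dynamic_dims on the annotated tensor.

from dataclasses import dataclass
from typing import Annotated, Optional

import paddle

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
    DynamicDims, resolve_dynamic_dims)


@dataclass
class ToyMeta:
    tokens: Annotated[paddle.Tensor, DynamicDims(0)]
    mask: Optional[paddle.Tensor] = None


meta = ToyMeta(tokens=paddle.zeros([8, 16], dtype="int64"))
# Expected effect: paddle.jit.marker.dynamic_dims(meta.tokens, (0,)) is invoked exactly once;
# `mask` is ignored because its value is None.
resolve_dynamic_dims(meta, "meta", ToyMeta)

Types with no matching resolver fall through to the print in generic_resolve, so scalar fields simply produce a "No resolver found" message rather than an error.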
fastdeploy/model_executor/graph_optimization/graph_optimization_backend.py
@@ -14,13 +14,66 @@
# limitations under the License.
"""

from typing import Callable, Optional

from paddle.jit.dy2static.utils import Backend
import functools
import inspect
import types
from typing import Callable, Optional, get_type_hints

import paddle
from fastdeploy.config import FDConfig
from fastdeploy.model_executor.graph_optimization.cudagraph_piecewise_backend import \
CudaGraphPiecewiseBackend
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \
resolve_dynamic_dims
from paddle.jit import sot
from paddle.jit.dy2static.utils import Backend


def apply_to_static_optimization(fn):
forward_fn = fn
forward_sig = inspect.signature(forward_fn)
# forward_annotations = inspect.get_annotations(forward_fn)
forward_type_hints = get_type_hints(forward_fn, include_extras=True)  # keep Annotated metadata
static_forward_fn = sot.symbolic_translate(
forward_fn, training=False, backend=Backend.PHI
)
unsafe_static_forward_fn = None

@functools.wraps(forward_fn)
def static_forward(self, *args, **kwargs):
nonlocal unsafe_static_forward_fn
if unsafe_static_forward_fn is not None:
return unsafe_static_forward_fn(self, *args, **kwargs)
bound_args = forward_sig.bind(self, *args, **kwargs)
bound_args.apply_defaults()
for name, arg in bound_args.arguments.items():
if name not in forward_type_hints:
continue
annotation = forward_type_hints[name]
resolve_dynamic_dims(arg, name, annotation)

# print(f"Processing argument '{name}' with annotation: {annotation}")
# if isinstance(arg, paddle.Tensor):
# print(
# f"Argument '{name}' is a Tensor with dynamic dims: {extract_dynamic_dims(annotation)}"
# )
# paddle.jit.marker.dynamic_dims(
# arg, extract_dynamic_dims(annotation)
# )
result = static_forward_fn(self, *args, **kwargs)
original_code = forward_fn.__code__
(new_guarded_codes, _) = sot.opcode_translator.executor.executor_cache.OpcodeExecutorCache().cache[original_code]
new_code = new_guarded_codes[0][0][0]
unsafe_static_forward_fn = types.FunctionType(
new_code,
forward_fn.__globals__,
forward_fn.__name__,
forward_fn.__defaults__,
forward_fn.__closure__,
)
return result

return static_forward


class GraphOptBackend:
@@ -46,9 +99,10 @@ def __init__(self, runnable: Callable, fd_config: FDConfig):
backend = (Backend.CINN
if self.fd_config.graph_opt_config.graph_opt_level > 1
else Backend.PHI)
self.runnable = sot.symbolic_translate(self.runnable,
training=False,
backend=backend)
# self.runnable = sot.symbolic_translate(self.runnable,
# training=False,
# backend=backend)
self.runnable = apply_to_static_optimization(self.runnable.__func__).__get__(self.runnable.__self__)

def __call__(self, **kwargs):
if not self.fd_config.graph_opt_config.use_cudagraph:
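A hedged sketch of the first-call path inside apply_to_static_optimization above: the call is bound against the forward signature, and every argument that has a type hint is handed to resolve_dynamic_dims before symbolic translation runs. toy_forward is an illustrative stand-in; only the binding-and-resolution pattern is mirrored here, not the SOT code-cache fast path.

import inspect
from typing import Annotated, get_type_hints

import paddle

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
    DynamicDims, resolve_dynamic_dims)


def toy_forward(ids: Annotated[paddle.Tensor, DynamicDims(0)]) -> paddle.Tensor:
    return ids + 1


sig = inspect.signature(toy_forward)
# include_extras=True keeps the Annotated metadata that DynamicDims rides on.
hints = get_type_hints(toy_forward, include_extras=True)

bound = sig.bind(paddle.arange(10))
bound.apply_defaults()
for name, arg in bound.arguments.items():
    if name in hints:
        # Marks dim 0 of `ids` as dynamic before the function is traced.
        resolve_dynamic_dims(arg, name, hints[name])

After the first real call, the decorator additionally pulls the translated code object out of SOT's OpcodeExecutorCache and invokes it directly on later calls, skipping guard checks; that unguarded fast path is omitted from the sketch.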
3 changes: 2 additions & 1 deletion fastdeploy/model_executor/layers/attention/__init__.py
@@ -19,12 +19,13 @@
from .mla_attention_backend import MLAAttentionBackend
from .native_paddle_backend import PaddleNativeAttnBackend
from .xpu_attn_backend import XPUAttentionBackend
from .attention import Attention
from .iluvatar_attn_backend import IluvatarAttnBackend
from .block_multihead_attn_backend import BlockAttentionBackend

__all__ = [
"AttentionBackend", "PaddleNativeAttnBackend",
"get_attention_backend", "AppendAttentionBackend", "XPUAttentionBackend",
"MLAAttentionBackend", "FlashAttentionBackend", "IluvatarAttnBackend",
"BlockAttentionBackend"
"BlockAttentionBackend", "Attention"
]
30 changes: 17 additions & 13 deletions fastdeploy/model_executor/layers/attention/append_attn_backend.py
@@ -18,10 +18,9 @@

import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, List, Optional, Tuple
from typing import TYPE_CHECKING, Annotated, List, Optional, Tuple

import paddle

from fastdeploy.model_executor.layers.attention.ops import (
append_attention, get_block_shape_and_split_kv_block,
init_signal_layerwise, open_shm_and_get_meta_signal)
@@ -30,6 +29,8 @@
from fastdeploy.model_executor.forward_meta import ForwardMeta

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import \
DynamicDims
from fastdeploy.model_executor.layers.attention.attention import Attention
from fastdeploy.model_executor.layers.attention.base_attention_backend import (
AttentionBackend, AttentionMetadata)
@@ -43,15 +44,15 @@ class AppendAttentionMetadata(AttentionMetadata):
"""
max_len_kv: paddle.Tensor = None
set_max_lengths: int = -1
encoder_batch_ids: paddle.Tensor = None
encoder_tile_ids_per_batch: paddle.Tensor = None
encoder_num_blocks: paddle.Tensor = None
kv_batch_ids: paddle.Tensor = None
kv_tile_ids_per_batch: paddle.Tensor = None
kv_num_blocks: paddle.Tensor = None
decoder_batch_ids: paddle.Tensor = None
decoder_tile_ids_per_batch: paddle.Tensor = None
decoder_num_blocks: paddle.Tensor = None
encoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None
encoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None
encoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None
kv_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None
kv_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None
kv_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None
decoder_batch_ids: Annotated[paddle.Tensor, DynamicDims(0)] = None
decoder_tile_ids_per_batch: Annotated[paddle.Tensor, DynamicDims(0)] = None
decoder_num_blocks: Annotated[paddle.Tensor, DynamicDims(0)] = None

_dtype: paddle.dtype = paddle.bfloat16
encoder_max_partition_size: int = 32768
@@ -73,6 +74,9 @@ class AppendAttentionBackend(AttentionBackend):
AppendAttentionBackend backend implementation.
"""

__infer_dynamic_dims_fields__ = ["attention_metadata"]
attention_metadata: AppendAttentionMetadata

def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,
head_dim: int) -> None:
"""
@@ -109,7 +113,7 @@ def __init__(self, fd_config: FDConfig, kv_num_heads: int, num_heads: int,

self.rank, self.device_id = init_rank_and_device_id(fd_config)

def init_attention_metadata(self, forward_meta: ForwardMeta):
def init_attention_metadata(self, forward_meta: "ForwardMeta"):
"""Initialize attntion metadata hence all layers in the forward pass can reuse it."""
metadata = AppendAttentionMetadata()
metadata.encoder_block_shape_q = 64
@@ -184,7 +188,7 @@ def forward_mixed(
compressed_kv: paddle.Tensor,
k_pe: paddle.Tensor,
layer: Attention,
forward_meta: ForwardMeta,
forward_meta: "ForwardMeta",
) -> paddle.Tensor:
"""
forward_mixed
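A hedged sketch of the opt-in hook used above (__infer_dynamic_dims_fields__): a plain, non-dataclass class such as an attention backend names the attributes the resolver should descend into, and a class-level annotation supplies the type that ManualMarkedInnerFieldsDynamicDimTypeResolver reads through get_type_hints. ToyMetadata and ToyBackend are illustrative names, not part of this PR.

from dataclasses import dataclass
from typing import Annotated, Optional

import paddle

from fastdeploy.model_executor.graph_optimization.dynamic_dims_marker import (
    DynamicDims, resolve_dynamic_dims)


@dataclass
class ToyMetadata:
    batch_ids: Optional[Annotated[paddle.Tensor, DynamicDims(0)]] = None


class ToyBackend:
    # Opt-in marker checked by ManualMarkedInnerFieldsDynamicDimTypeResolver.check().
    __infer_dynamic_dims_fields__ = ("attention_metadata",)
    attention_metadata: ToyMetadata

    def __init__(self):
        self.attention_metadata = ToyMetadata(batch_ids=paddle.zeros([4], dtype="int32"))


backend = ToyBackend()
# Descends backend.attention_metadata -> ToyMetadata.batch_ids and marks dim 0 as dynamic.
resolve_dynamic_dims(backend, "backend", ToyBackend)

Note that, as written, the resolver walks every annotated attribute returned by get_type_hints(tp) rather than only the names listed in __infer_dynamic_dims_fields__; the attribute currently acts as an opt-in flag for the class as a whole.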