
Commit 5907126

[Executor] Move forward_meta.py to fastdeploy/model_executor (#2774)
* Use PEP 563 in attention.py and fix conflict
* merge commit
* Change what was left out last time
1 parent 8c660a0 commit 5907126

27 files changed: +53 -55 lines

fastdeploy/worker/forward_meta.py renamed to fastdeploy/model_executor/forward_meta.py

Lines changed: 4 additions & 5 deletions
@@ -18,11 +18,10 @@
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, Optional
+from fastdeploy.model_executor.layers.attention import AttentionBackend

 import paddle
-
-if TYPE_CHECKING:
-    from fastdeploy.model_executor.layers.attention import AttentionBackend
+

 logger = logging.getLogger(__name__)

@@ -69,7 +68,7 @@ class ForwardMeta():
     is_decode_batch: bool = False

     # Attention backend object
-    attn_backend: 'AttentionBackend' = None
+    attn_backend: AttentionBackend = None
     # Forward mode used during attention
     forward_mode: ForwardMode = ForwardMode.MIXED
     # Attention mask
@@ -100,7 +99,7 @@ class ForwardMeta():
     # Block tables
     block_tables: Optional[paddle.Tensor] = None
     # KV caches
-    caches: Optional[paddle.Tensor] = None
+    caches: Optional[list[paddle.Tensor]] = None

     def clear_caches(self):
         """ Safely clean up the caches """

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 2 additions & 3 deletions
@@ -27,14 +27,13 @@
     init_signal_layerwise, open_shm_and_get_meta_signal)

 if TYPE_CHECKING:
-    from paddle._typing.dtype_like import _DTypeLiteral
+    from fastdeploy.model_executor.forward_meta import ForwardMeta

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
-from fastdeploy.worker.forward_meta import ForwardMeta


 @dataclass
@@ -54,7 +53,7 @@ class AppendAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None

-    _dtype: _DTypeLiteral = paddle.bfloat16
+    _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None
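
Besides rerouting the ForwardMeta import, the backend metadata dataclasses replace the private paddle._typing.dtype_like._DTypeLiteral alias with the public paddle.dtype class as the _dtype annotation. A small sketch of that field, assuming Paddle is installed (the class name is illustrative):

```python
from dataclasses import dataclass

import paddle


@dataclass
class MetadataSketch:  # illustrative, mirrors only the _dtype field shown above
    # paddle.dtype is a public runtime name, so the annotation needs no
    # TYPE_CHECKING-only import, unlike the private _DTypeLiteral alias it replaces.
    _dtype: paddle.dtype = paddle.bfloat16


print(MetadataSketch()._dtype)  # paddle.bfloat16
```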

fastdeploy/model_executor/layers/attention/attention.py

Lines changed: 5 additions & 2 deletions
@@ -14,7 +14,9 @@
 # limitations under the License.
 """

-from typing import Dict, Optional
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Dict, Optional

 import numpy as np
 import paddle
@@ -24,7 +26,8 @@
 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.quantization.quant_base import \
     QuantMethodBase
-from fastdeploy.worker.forward_meta import ForwardMeta
+if TYPE_CHECKING:
+    from fastdeploy.model_executor.forward_meta import ForwardMeta


 class Attention(nn.Layer):
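
attention.py is where the commit message's "Use PEP 563" applies: with from __future__ import annotations, every annotation in the module is stored as an unevaluated string, so ForwardMeta can be imported under TYPE_CHECKING only and still appear in signatures. A self-contained illustration of the effect (the function itself is illustrative, not FastDeploy code):

```python
from __future__ import annotations  # PEP 563: postponed evaluation of annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Visible to static type checkers only; never executed at runtime.
    from fastdeploy.model_executor.forward_meta import ForwardMeta


def run_attention(meta: ForwardMeta) -> None:  # illustrative signature
    """The annotation stays the string 'ForwardMeta' at runtime."""


print(run_attention.__annotations__)  # {'meta': 'ForwardMeta', 'return': 'None'}
```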

fastdeploy/model_executor/layers/attention/base_attention_backend.py

Lines changed: 3 additions & 2 deletions
@@ -21,10 +21,11 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
+from typing import TYPE_CHECKING

 import paddle
-
-from fastdeploy.worker.forward_meta import ForwardMeta
+if TYPE_CHECKING:
+    from fastdeploy.model_executor.forward_meta import ForwardMeta


 @dataclass
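
base_attention_backend.py defines the abstract class the attention backends in this commit inherit from, so the guard matters most here: forward_meta.py imports AttentionBackend at runtime, while the backend side mentions ForwardMeta only for type checking, which breaks the circular import between the two modules. A runnable sketch of that guard, with a string annotation standing in for however the real module spells it (the class name is illustrative):

```python
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated by type checkers only; skipped at runtime, so importing this
    # module never pulls in forward_meta (which itself imports the backends).
    from fastdeploy.model_executor.forward_meta import ForwardMeta


class DemoAttentionBackend(ABC):  # illustrative, not FastDeploy's AttentionBackend
    @abstractmethod
    def forward(self, forward_meta: "ForwardMeta") -> None:
        """Run one attention step; ForwardMeta is referenced as a string."""


print(TYPE_CHECKING)  # False at runtime: the guarded import above never ran
```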

fastdeploy/model_executor/layers/attention/block_multihead_attn_backend.py

Lines changed: 3 additions & 3 deletions
@@ -23,13 +23,13 @@
 import paddle

 if TYPE_CHECKING:
-    from paddle._typing.dtype_like import _DTypeLiteral
+    from fastdeploy.model_executor.forward_meta import ForwardMeta

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.worker.forward_meta import ForwardMeta
+

 @dataclass
 class BlockAttentionMetadata(AttentionMetadata):
@@ -48,7 +48,7 @@ class BlockAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None

-    _dtype: _DTypeLiteral = paddle.bfloat16
+    _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 3 additions & 2 deletions
@@ -18,7 +18,7 @@

 import os
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import List, Optional, TYPE_CHECKING

 import paddle

@@ -35,7 +35,8 @@
     get_block_shape_and_split_kv_block, gqa_rope_write_cache,
     init_signal_layerwise, open_shm_and_get_meta_signal, pre_cache_len_concat)
 from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id
-from fastdeploy.worker.forward_meta import ForwardMeta
+if TYPE_CHECKING:
+    from fastdeploy.model_executor.forward_meta import ForwardMeta


 @dataclass

fastdeploy/model_executor/layers/attention/iluvatar_attn_backend.py

Lines changed: 3 additions & 2 deletions
@@ -20,7 +20,7 @@
 import paddle

 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 from math import sqrt

 from paddle.nn.functional.flash_attention import flash_attn_unpadded
@@ -30,7 +30,8 @@
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.worker.forward_meta import ForwardMeta
+if TYPE_CHECKING:
+    from fastdeploy.model_executor.forward_meta import ForwardMeta


 @dataclass

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 3 additions & 5 deletions
@@ -35,15 +35,13 @@
     prefill_mla_write_cache)

 if TYPE_CHECKING:
-    from paddle._typing.dtype_like import _DTypeLiteral
+    from fastdeploy.model_executor.forward_meta import ForwardMeta

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.model_executor.layers.attention.utils import \
-    init_rank_and_device_id
-from fastdeploy.worker.forward_meta import ForwardMeta
+from fastdeploy.model_executor.layers.attention.utils import init_rank_and_device_id


 def yarn_get_mscale(scale=1, mscale=1):
@@ -71,7 +69,7 @@ class MLAAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None

-    _dtype: _DTypeLiteral = paddle.bfloat16
+    _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None

fastdeploy/model_executor/layers/attention/native_paddle_backend.py

Lines changed: 3 additions & 1 deletion
@@ -17,12 +17,14 @@

 from __future__ import annotations

+from typing import TYPE_CHECKING
 import paddle
 from paddle.nn.functional import scaled_dot_product_attention

 from fastdeploy.model_executor.layers.attention.base_attention_backend import \
     AttentionBackend
-from fastdeploy.worker.forward_meta import ForwardMeta
+if TYPE_CHECKING:
+    from fastdeploy.model_executor.forward_meta import ForwardMeta


 class PaddleNativeAttnBackend(AttentionBackend):

fastdeploy/model_executor/layers/attention/xpu_attn_backend.py

Lines changed: 2 additions & 3 deletions
@@ -26,13 +26,12 @@
     init_signal_layerwise, open_shm_and_get_meta_signal)

 if TYPE_CHECKING:
-    from paddle._typing.dtype_like import _DTypeLiteral
+    from fastdeploy.model_executor.forward_meta import ForwardMeta

 from fastdeploy.config import FDConfig
 from fastdeploy.model_executor.layers.attention.attention import Attention
 from fastdeploy.model_executor.layers.attention.base_attention_backend import (
     AttentionBackend, AttentionMetadata)
-from fastdeploy.worker.forward_meta import ForwardMeta


 @dataclass
@@ -52,7 +51,7 @@ class XPUAttentionMetadata(AttentionMetadata):
     decoder_tile_ids_per_batch: paddle.Tensor = None
     decoder_num_blocks: paddle.Tensor = None

-    _dtype: _DTypeLiteral = paddle.bfloat16
+    _dtype: paddle.dtype = paddle.bfloat16
     encoder_max_partition_size: int = 32768
     max_partition_size: int = 32768
     block_tables: Optional[paddle.Tensor] = None
