@@ -13,6 +13,75 @@
from .args import TransformerModelArgs


+# TODO: keeping this for-loop implementation for comparison
+# and readability, may remove later
+@expert_parallel
+def _run_experts_for_loop(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        # NOTE: this would incur a synchronization between device and host
+        num_tokens_per_expert = num_tokens_per_expert.tolist()
+
+        # side-effect code due to the usage of generate_permute_indices
+        num_padding = x.shape[0] - sum(num_tokens_per_expert)
+
+        # a tuple of tensors indexed by experts
+        # each with shape (tokens_per_expert(varying), dim)
+        x = torch.split(
+            x[: sum(num_tokens_per_expert)],
+            split_size_or_sections=num_tokens_per_expert,
+            dim=0,
+        )
+        out_experts_splits = []
+        for expert_idx, x_expert in enumerate(x):
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
+            h = h * torch.matmul(x_expert, w3[expert_idx])
+            h = torch.matmul(h, w2[expert_idx])
+            # h shape (tokens_per_expert(varying), dim)
+            out_experts_splits.append(h)
+        out = torch.cat(out_experts_splits, dim=0)
+
+        # side-effect code due to the usage of generate_permute_indices
+        out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
+    else:
+        # x shape (num_experts, tokens_per_expert, dim)
+        h = F.silu(torch.bmm(x, w1))
+        h = h * torch.bmm(x, w3)
+        # out shape (num_experts, tokens_per_expert, dim)
+        out = torch.bmm(h, w2)
+
+    return out
+
+
+@expert_parallel
+def _run_experts_grouped_mm(
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w3: torch.Tensor,
+    x: torch.Tensor,
+    num_tokens_per_expert: torch.Tensor | None = None,
+) -> torch.Tensor:
+    if num_tokens_per_expert is not None:
+        offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
+        # grouped mm between a 2D tensor and a 3D tensor
+        assert x.dim() == 2
+    else:
+        offsets = None
+        # fall back to regular bmm between 3D tensors
+        assert x.dim() == 3
+
+    h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
+    h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
+    out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
+
+    return out
+
+
class GroupedExperts(nn.Module):
    def __init__(
        self,
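A minimal, hypothetical sketch (not part of the PR) of the token layout both helpers above assume: tokens arrive already sorted by expert, `num_tokens_per_expert` holds the per-expert row counts, the for-loop path slices rows with `torch.split`, and the grouped-mm path turns the same counts into the int32 end offsets passed as `offs` to `torch._grouped_mm`.

import torch

dim, hidden_dim = 4, 8
num_tokens_per_expert = torch.tensor([5, 0, 3])          # rows owned by each of 3 experts
x = torch.randn(int(num_tokens_per_expert.sum()), dim)   # (total_tokens, dim), expert-sorted
w1 = torch.randn(3, dim, hidden_dim)                     # per-expert weights

# For-loop path: slice x per expert and apply that expert's weight matrix.
splits = torch.split(x, num_tokens_per_expert.tolist(), dim=0)
out = torch.cat([xe @ w1[i] for i, xe in enumerate(splits)], dim=0)

# Grouped-mm path: the same grouping expressed as cumulative end offsets,
# here [5, 5, 8], the form the offs= argument in the diff above expects.
offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)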
@@ -34,83 +103,14 @@ def forward(
        num_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
+            return _run_experts_grouped_mm(
                self.w1, self.w2, self.w3, x, num_tokens_per_expert
            )
        else:
-            return GroupedExperts._run_experts_for_loop(
+            return _run_experts_for_loop(
                self.w1, self.w2, self.w3, x, num_tokens_per_expert
            )

-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
-
-        return out
-
-    @expert_parallel
-    @staticmethod
-    def _run_experts_grouped_mm(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32)
-            # grouped mm between a 2D tensor and a 3D tensor
-            assert x.dim() == 2
-        else:
-            offsets = None
-            # fall back to regular bmm between 3D tensors
-            assert x.dim() == 3
-
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
-        return out
-
    def init_weights(self, init_std: float):
        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
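With the helpers now at module level, `forward` simply dispatches to `_run_experts_grouped_mm` or `_run_experts_for_loop`. As a quick, hypothetical sanity sketch (not part of the PR), the dense fallback both helpers use when `num_tokens_per_expert` is `None` is a batched SwiGLU over an input of shape (num_experts, tokens_per_expert, dim); the snippet below checks that the bmm form matches an explicit per-expert loop.

import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, dim, hidden_dim = 4, 6, 8, 16
x = torch.randn(num_experts, tokens_per_expert, dim)
w1 = torch.randn(num_experts, dim, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, dim)
w3 = torch.randn(num_experts, dim, hidden_dim)

# Batched form, as in the else-branch of the helpers above.
h = F.silu(torch.bmm(x, w1)) * torch.bmm(x, w3)
out_bmm = torch.bmm(h, w2)

# The same computation written expert by expert.
out_loop = torch.stack(
    [(F.silu(x[e] @ w1[e]) * (x[e] @ w3[e])) @ w2[e] for e in range(num_experts)]
)
assert torch.allclose(out_bmm, out_loop, atol=1e-4)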