 import torch
 import torch.distributed as dist
+import numpy as np
 
 from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
 from vllm.logger import logger
+import random
+
 
 class VllmEplbAdaptor(EplbAdaptor):
@@ -29,6 +32,7 @@ def __init__(self, model, **args):
         self.param_dict = dict(self.model.named_parameters())
         self.num_dense_layers = self.model.config.first_k_dense_replace
         self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
+        self.global_expert_num = 256
 
         # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 is supported here
         self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
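
Note on the new `self.global_expert_num = 256` above (reviewer sketch, not part of the patch): 256 is DeepSeek-V3's routed-expert count hard-coded. If the config is known to expose that count (DeepSeek-style configs usually name it `n_routed_experts`; an assumption here), it could be derived instead, keeping 256 as the fallback:

    # Sketch only: prefer the config value when available, fall back to the
    # hard-coded DeepSeek-V3 default of 256 routed experts.
    self.global_expert_num = getattr(self.model.config, "n_routed_experts", 256)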
@@ -66,18 +70,16 @@ def get_init_expert_map(self, num_moe_layers):
         expert_map = self.model.get_all_expert_map(num_moe_layers)
         if dist.is_initialized():
             world_size = dist.get_world_size()
+            rank = dist.get_rank()
 
-            rank = dist.get_rank()
-
-            tensor_list = [
-                torch.zeros_like(expert_map) for _ in range(world_size)
-            ]
+            gathered = torch.empty((world_size, *expert_map.shape),  # [W, L, E]
+                                   dtype=expert_map.dtype,
+                                   device=expert_map.device)
 
-            dist.all_gather(tensor_list, expert_map)
-            gathered = torch.stack(tensor_list, dim=0)
-            all_maps = gathered.permute(1, 0, 2).contiguous()
+            dist.all_gather_into_tensor(gathered, expert_map)
+            all_maps = gathered.permute(1, 0, 2)
+            all_expert_maps = all_maps.cpu()
 
-            all_expert_maps = all_maps.to(torch.device("cpu"))
         return all_expert_maps
 
     def do_update_expert_map(self, layer_id, updated_expert_map):
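
Shape intuition for the gather above (illustration only, fabricated numbers): each rank contributes an `expert_map` of shape `[num_moe_layers, num_experts]`, `dist.all_gather_into_tensor` fills a `[world_size, L, E]` buffer in one collective (no per-rank list and no extra `torch.stack` copy), and `permute(1, 0, 2)` regroups it per layer as `[L, W, E]`. A single-process sketch of the same bookkeeping:

    import torch

    # Pretend there are 2 ranks, 3 MoE layers and 8 global experts.
    world_size, num_layers, num_experts = 2, 3, 8
    per_rank = [torch.randint(-1, 4, (num_layers, num_experts)) for _ in range(world_size)]

    gathered = torch.stack(per_rank, dim=0)   # what the collective produces: [W, L, E]
    all_maps = gathered.permute(1, 0, 2)      # regrouped per layer: [L, W, E]
    assert all_maps.shape == (num_layers, world_size, num_experts)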
@@ -105,6 +107,8 @@ def generate_index_dicts(self,tensor_2d):
         return dict_list
 
     def generate_log2phy_map(self, expert_map):
+        num_local_experts = expert_map.max() + 1
+        expert_map = self.global2local(expert_map, num_local_experts)
         ranks_num, global_expert_num = expert_map.shape
         concatenated = torch.flatten(expert_map)
         rank_expert_to_global = self.generate_index_dicts(
@@ -116,7 +120,7 @@ def generate_log2phy_map(self, expert_map):
                 result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((ranks_num, global_expert_num),
+        log2phy_map = torch.full((ranks_num, self.global_expert_num),
                                  -1,
                                  dtype=torch.int32)
         for rank in range(ranks_num):
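
Why the table width changes above (illustration only): after the `global2local` call, `expert_map` has shape `[ranks_num, num_local_experts]`, so its second dimension no longer equals the global expert count; the logical-to-physical table therefore has to be sized with `self.global_expert_num` columns explicitly. A toy shape check with assumed numbers (2 ranks, 4 local slots, 8 global experts):

    import torch

    global_expert_num, num_local_experts, ranks = 8, 4, 2
    local_map = torch.randint(0, num_local_experts, (ranks, num_local_experts))

    # Sizing by local_map.shape[1] would give only num_local_experts columns;
    # the table needs one column per *global* expert.
    log2phy = torch.full((ranks, global_expert_num), -1, dtype=torch.int32)
    assert log2phy.shape[1] == global_expert_num != local_map.shape[1]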
@@ -130,7 +134,27 @@ def generate_log2phy_map(self, expert_map):
         return log2phy_map
 
     def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
-
+        rank_id = torch.distributed.get_rank()
         if self.log2phy_map_per_layer[layer_id] is not None:
-            rank_id = torch.distributed.get_rank()
-            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[rank_id])
+            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[rank_id])
+
+    def global2local(self,
+                     placement: torch.Tensor,
+                     E_local: int
+                     ) -> torch.Tensor:
+
+        G, _ = placement.shape
+        device = placement.device
+
+        pt_local = torch.full((G, E_local),
+                              fill_value=-1,
+                              dtype=torch.long,
+                              device=device)
+
+        valid = placement >= 0  # entries >= 0 are the experts hosted on this rank
+        g_idx, k_idx = valid.nonzero(as_tuple=True)
+        slot_idx = placement[g_idx, k_idx]
+
+        pt_local[g_idx, slot_idx] = k_idx  # pt_local[rank, local_slot] = global expert id
+
+        return pt_local
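
A small worked example of the new `global2local` helper (fabricated placement; mirrors the code above rather than calling the adaptor): `expert_map` stores, per rank, the local slot of each global expert (`-1` when the expert is not hosted there); the helper inverts that into slot-to-global-expert rows, and `expert_map.max() + 1` recovers the number of local slots used in `generate_log2phy_map`:

    import torch

    # placement[rank][global_expert] = local slot on that rank, -1 if absent.
    placement = torch.tensor([[ 0, -1,  1, -1],
                              [-1,  0, -1,  1]])
    num_local_experts = int(placement.max()) + 1   # 2 slots per rank

    pt_local = torch.full((placement.shape[0], num_local_experts), -1, dtype=torch.long)
    valid = placement >= 0
    g_idx, k_idx = valid.nonzero(as_tuple=True)
    pt_local[g_idx, placement[g_idx, k_idx]] = k_idx

    print(pt_local)  # tensor([[0, 2], [1, 3]]): slot -> global expert id per rank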