#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import json

import torch
import torch.distributed as dist

from vllm_ascend.eplb.adaptor.abstract_adaptor import EplbAdaptor
from vllm.logger import logger


class VllmEplbAdaptor(EplbAdaptor):

    def __init__(self, model, **args):
        super().__init__(**args)
        self.model = model
        self.rank_id = dist.get_rank()
        self.world_size = dist.get_world_size()
        self.param_dict = dict(self.model.named_parameters())
        self.num_dense_layers = self.model.config.first_k_dense_replace
        self.num_moe_layers = self.model.config.num_hidden_layers - self.num_dense_layers
        self.global_expert_num = self.model.config.n_routed_experts

        # TODO: init self.expert_weight_names depending on the model type;
        # only DeepSeek V3 W8A8 is supported here
        self.expert_weight_names = ["w13_weight", "w2_weight", "w13_weight_scale", "w13_weight_offset",
                                    "w2_weight_scale", "w2_weight_offset"]

        self.expert_map_per_layer = dict()      # reference to the expert map on device, used for expert map updates
        self.expert_map_per_layer_cpu = dict()  # CPU copy of the expert map, to avoid frequent device synchronization
        for layer_idx in range(self.num_moe_layers):
            self.expert_map_per_layer[self.num_dense_layers + layer_idx] = \
                self.model.get_expert_map(self.num_dense_layers + layer_idx)

        # TODO: here we set the number of buffer tensors equal to the number of
        # experts held per layer on this rank, which can be improved
        num_buffer_tensor = torch.where(self.expert_map_per_layer[self.num_dense_layers] != -1)[0].numel()
        self.buffer_tensor_list = [[] for _ in range(num_buffer_tensor)]
        self.init_buffer_tensor(num_buffer_tensor)

        self.expert_param_per_layer = dict()
        self.init_expert_param_per_layer()

        self.log2phy_map_per_layer = dict()
        for layer_idx in range(self.num_moe_layers):
            self.log2phy_map_per_layer[self.num_dense_layers + layer_idx] = \
                self.model.get_log2phy_map(self.num_dense_layers + layer_idx)

        self.all_topk_ids = []

    def init_buffer_tensor(self, num_buffer_tensor):
        for name in self.expert_weight_names:
            complete_name = "model.layers." + str(self.num_dense_layers) + ".mlp.experts." + name
            expert_tensor = self.param_dict[complete_name].data[0:num_buffer_tensor]
            buffer_tensors = torch.empty_like(expert_tensor)
            for buffer_id in range(num_buffer_tensor):
                self.buffer_tensor_list[buffer_id].append(buffer_tensors[buffer_id])
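
    # Note (added for clarity): after init_buffer_tensor runs, each
    # buffer_tensor_list[i] holds one staging tensor per entry in
    # expert_weight_names, shaped like a single expert's weights, so buffer
    # slot i can hold a complete replacement expert.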

    def init_expert_param_per_layer(self):
        num_local_expert = self.param_dict["model.layers." + str(self.num_dense_layers) +
                                           ".mlp.experts." + self.expert_weight_names[0]].data.shape[0]
        for moe_layer_id in range(self.num_moe_layers):
            layer_idx = self.num_dense_layers + moe_layer_id
            self.expert_param_per_layer[layer_idx] = list()
            for local_expert_id in range(num_local_expert):
                self.expert_param_per_layer[layer_idx].append(
                    [self.param_dict["model.layers." + str(layer_idx) + ".mlp.experts." + name].data[local_expert_id]
                     for name in self.expert_weight_names]
                )
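
    # Note (added for clarity): expert_param_per_layer[layer_idx][local_expert_id]
    # is the list of that expert's parameter tensors, ordered to match
    # expert_weight_names, so it can be zipped against a buffer slot in
    # do_update_expert_weight.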

    # def collect_topk_ids(self, dummy_run=False):
    #     if dummy_run:
    #         return
    #     self.all_topk_ids.append(self.model.get_all_topk_ids(self.num_moe_layers))

    def get_rank_expert_workload(self) -> torch.Tensor:
        self.moe_load = self.model.get_all_moe_loads()
        return self.moe_load

    def get_init_expert_map(self, num_moe_layers):
        expert_map = self.model.get_all_expert_map(num_moe_layers)
        if dist.is_initialized():
            world_size = dist.get_world_size()

            gathered = torch.empty((world_size, *expert_map.shape),  # [W, L, E]
                                   dtype=expert_map.dtype,
                                   device=expert_map.device)

            dist.all_gather_into_tensor(gathered, expert_map)
            all_maps = gathered.permute(1, 0, 2)  # [L, W, E]
            all_expert_maps = all_maps.cpu()

            for layer_idx in range(num_moe_layers):
                self.expert_map_per_layer_cpu[self.num_dense_layers + layer_idx] = \
                    all_expert_maps[layer_idx][self.rank_id]

            return all_expert_maps
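
    # Shape note (added for clarity): each rank contributes a [L, E] map, the
    # all-gather produces [W, L, E], and the permute yields [L, W, E], so
    # all_expert_maps[layer][rank] is one rank's expert map for one layer.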

    def get_init_expert_map_from_file(self, num_moe_layers, expert_map_path):
        try:
            expert_map_tensor, layers_num, ranks_num = self._expert_file_to_tensor(expert_map_path)
            expert_map_all = self.local2global(expert_map_tensor)
        except (TypeError, FileNotFoundError, OSError):
            logger.error(f"failed to read expert_map_path: {expert_map_path}")
            expert_map_all = self.determine_expert_map_all()

        for layer_idx in range(num_moe_layers):
            self.expert_map_per_layer_cpu[layer_idx + self.num_dense_layers] = \
                expert_map_all[layer_idx][self.rank_id]
        return expert_map_all

    def _expert_file_to_tensor(self, expert_map_path: str):
        with open(expert_map_path, "r") as f:
            data = json.load(f)
        layers_num = data["moe_layer_count"]
        gpus_num = data["layer_list"][0]["device_count"]

        tensor_data = []
        for layer in data["layer_list"]:
            device_data = []
            for device in layer["device_list"]:
                device_data.append(device["device_expert"])
            tensor_data.append(device_data)
        expert_map_tensor = torch.tensor(tensor_data, dtype=torch.int32)
        return expert_map_tensor, layers_num, gpus_num
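
    # Sketch of the expert-map JSON this parser expects, inferred from the keys
    # read above; the counts and expert ids below are illustrative only:
    #
    # {
    #   "moe_layer_count": 2,
    #   "layer_list": [
    #     {"device_count": 2,
    #      "device_list": [{"device_expert": [0, 2]},
    #                      {"device_expert": [1, 3]}]},
    #     ...
    #   ]
    # }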

    def do_update_expert_map(self, layer_id, updated_expert_map):
        # keep the device copy and the CPU copy in sync
        self.expert_map_per_layer[layer_id].copy_(updated_expert_map)
        self.expert_map_per_layer_cpu[layer_id].copy_(updated_expert_map)

    def do_update_expert_weight(self, layer_id, local_expert_to_replace, buffer_tensor_id):
        for expert_tensor, buffer_tensor in zip(
                self.expert_param_per_layer[layer_id][local_expert_to_replace],
                self.buffer_tensor_list[buffer_tensor_id]
        ):
            expert_tensor.copy_(buffer_tensor)
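
    # Flow sketch (an assumption, not spelled out in this file): the EPLB
    # worker first stages the incoming expert's weights into
    # self.buffer_tensor_list[buffer_tensor_id], and do_update_expert_weight
    # then commits them by overwriting the replaced local expert's parameter
    # slices in place.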

    def do_update_log2phy_map(self, layer_id, updated_log2phy_map):
        if self.log2phy_map_per_layer[layer_id] is not None:
            self.log2phy_map_per_layer[layer_id].copy_(updated_log2phy_map[self.rank_id])

    def local2global(self,
                     placement_local: torch.Tensor
                     ) -> torch.Tensor:
        # Convert a local placement [L, G, E_local], whose entries are global
        # expert ids (or -1 for an empty slot), into a global expert map
        # [L, G, E_global], whose entry [l, g, e] is the local slot index of
        # global expert e on device g (or -1 if e is not placed there).
        L, G, E_local = placement_local.shape
        device = placement_local.device

        max_id = torch.max(placement_local)
        E_global = (max_id + 1).item() if max_id >= 0 else 0

        if E_global == 0:
            return torch.empty((L, G, 0), dtype=torch.long, device=device)

        placement_global = torch.full((L, G, E_global),
                                      fill_value=-1,
                                      dtype=torch.long,
                                      device=device)

        valid = placement_local >= 0
        l_idx, g_idx, slot_idx = valid.nonzero(as_tuple=True)
        gid_idx = placement_local[l_idx, g_idx, slot_idx]

        placement_global[l_idx, g_idx, gid_idx] = slot_idx

        return placement_global
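
    # Worked example (illustrative): with L=1 layer, G=2 devices, and 2 local
    # slots per device holding global expert ids
    #     placement_local  = [[[0, 2], [1, 3]]]                   # [L, G, E_local]
    # the resulting global map is
    #     placement_global = [[[0, -1, 1, -1], [-1, 0, -1, 1]]]   # [L, G, E_global]
    # i.e. entry [l, g, e] is the local slot of global expert e on device g,
    # or -1 when that expert is not placed on the device.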

    def determine_expert_map_all(self):
        # Default contiguous split of global experts across ranks; the last
        # rank takes the remainder when the division is uneven.
        local_num_experts = self.global_expert_num // self.world_size

        expert_map_all = torch.full(
            (self.num_moe_layers, self.world_size, self.global_expert_num),
            -1,
            dtype=torch.int32
        )

        for r in range(self.world_size):
            if r < self.world_size - 1:
                start = r * local_num_experts
                end = (r + 1) * local_num_experts
                local_count = local_num_experts
            else:
                start = r * local_num_experts
                end = self.global_expert_num
                local_count = self.global_expert_num - r * local_num_experts

            local_ids = torch.arange(local_count, dtype=torch.int32)
            expert_map_all[:, r, start:end] = local_ids.unsqueeze(0).expand(self.num_moe_layers, -1)

        return expert_map_all
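
    # Illustrative default layout: with global_expert_num=8 and world_size=2,
    # every layer gets
    #     rank 0 -> [0, 1, 2, 3, -1, -1, -1, -1]
    #     rank 1 -> [-1, -1, -1, -1, 0, 1, 2, 3]
    # i.e. experts are split contiguously across ranks, and the last rank
    # absorbs the remainder when global_expert_num % world_size != 0.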