@@ -261,9 +261,20 @@ def create_attention_instances(self) -> dict[int, Attention]:
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(self.config.num_hidden_layers,
                                     self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        if hasattr(self.config, "global_attention_layers") and isinstance(
+                self.config.global_attention_layers, list):
+            global_attention_layers = self.config.global_attention_layers
+        else:
+            global_attention_layers = []  # avoid `i in None` in the loop below
+
+        for i in range(start, end):
+            sliding_window = None
+            if i in global_attention_layers:
+                assert self.config.sliding_window is not None
+                sliding_window = self.config.sliding_window
+            attention_instances[i] = Attention(
                 num_heads=num_heads,
                 head_size=head_size,
                 # NOTE: We use Llama scale as default, if it's set by
@@ -272,9 +283,10 @@ def create_attention_instances(self) -> dict[int, Attention]:
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                 prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+
+        return attention_instances

     def init_buffers(self, module: nn.Module):
         """