2 files changed: +6 -6 lines changed

@@ -187,10 +187,10 @@ def create_and_prepopulate_kv_cache(
 class MockAttentionLayer:
     """A mock attention layer for testing."""

-    def __init__(self):
-        self._q_scale = torch.tensor(1.0)
-        self._k_scale = torch.tensor(1.0)
-        self._v_scale = torch.tensor(1.0)
+    def __init__(self, device: torch.device):
+        self._q_scale = torch.tensor(1.0, device=device)
+        self._k_scale = torch.tensor(1.0, device=device)
+        self._v_scale = torch.tensor(1.0, device=device)
         # Add float versions for flashinfer
         self._k_scale_float = 1.0
         self._v_scale_float = 1.0
@@ -258,7 +258,7 @@ def mock_get_per_layer_parameters(vllm_config):
     )

     # Create mock layer and output buffer
-    mock_layer = MockAttentionLayer()
+    mock_layer = MockAttentionLayer(device)
     output = torch.empty_like(query)

     # Run forward pass
@@ -114,7 +114,7 @@ def get_attention_backend(backend_name: _Backend):
         _Backend.FLEX_ATTENTION:
         "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend",
         _Backend.TRITON_ATTN_VLLM_V1:
-        "vllm.v1.attention.backends.triton_attn.TritonAttnBackend",
+        "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend",
     }

     if backend_name not in backend_map:
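The mapped value is a dotted path that must name the class actually exported by vllm.v1.attention.backends.triton_attn, which is why the stale TritonAttnBackend string is corrected here. A rough sketch (not the repository's actual resolver) of how such a dotted path is typically turned into a class:

    # Rough sketch, not vLLM's actual resolver: an out-of-date class name such
    # as TritonAttnBackend would raise AttributeError at the getattr() below,
    # while TritonAttentionBackend resolves to the real backend class.
    import importlib

    def resolve_backend_class(qualname: str):
        module_name, _, class_name = qualname.rpartition(".")
        module = importlib.import_module(module_name)
        return getattr(module, class_name)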