@@ -64,10 +64,11 @@ def __init__(

    @staticmethod
    def prepare_inputs(
-        # [batch_size + 1]
-        cu_target_query_lens: torch.Tensor,
-        # [batch_size]
-        num_rejected_tokens: torch.Tensor,
+        # [batch_size + 1]
+        cu_target_query_lens: torch.Tensor,
+        # [batch_size]
+        num_rejected_tokens: torch.Tensor,
+        force_one_token: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
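For reference, here is a minimal, invented example of the shapes this signature documents: three requests with target query lengths a=3, b=4, c=5, and per-request rejection counts n1=1, n2=0, n3=2.

    import torch

    # [batch_size + 1] cumulative query lengths: [0, a, a + b, a + b + c]
    cu_target_query_lens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
    # [batch_size] rejected draft tokens per request: [n1, n2, n3]
    num_rejected_tokens = torch.tensor([1, 0, 2], dtype=torch.int32)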
@@ -76,32 +77,39 @@ def prepare_inputs(
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
-
        # [0, a, a + b, a + b + c] -> [a, b, c]
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = query_len_per_req - num_rejected_tokens
+        if force_one_token:
+            # force_one_token keeps only the last accepted token position of each request
+            # token_indices: [batch_size]
+            cu_num_tokens = torch.arange(cu_target_query_lens.size(0),
+                                         device=cu_target_query_lens.device,
+                                         dtype=torch.int32)
+            relative_index = query_len_per_req - num_rejected_tokens - 1
+            token_indices = cu_target_query_lens[:-1] + relative_index
+        else:
+            cu_num_tokens = torch.empty_like(cu_target_query_lens)
+            torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
+            cu_num_tokens[0] = 0
+
+            # FIXME(woosuk): Avoid synchronization.
+            num_tokens = cu_num_tokens[-1].item()
+            token_indices = torch.empty(
+                num_tokens,
+                dtype=torch.int32,
+                device=cu_num_tokens.device,
+            )

-        cu_num_tokens = torch.empty_like(cu_target_query_lens)
-        torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        cu_num_tokens[0] = 0
-
-        # FIXME(woosuk): Avoid synchronization.
-        num_tokens = cu_num_tokens[-1].item()
-        token_indices = torch.empty(
-            num_tokens,
-            dtype=torch.int32,
-            device=cu_num_tokens.device,
-        )
-
-        BLOCK_SIZE = 1024
-        prepare_input_kernel(
-            token_indices,
-            cu_target_query_lens,
-            cu_num_tokens,
-            block_size=BLOCK_SIZE,
-        )
+            BLOCK_SIZE = 1024
+            prepare_input_kernel(
+                token_indices,
+                cu_target_query_lens,
+                cu_num_tokens,
+                block_size=BLOCK_SIZE,
+            )

        return cu_num_tokens, token_indices

    def propose(
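To make the two branches concrete, here is a pure-torch sketch of what prepare_inputs computes; the Python loop stands in for the prepare_input_kernel fill, and the helper name and sample values are invented for illustration:

    import torch

    def prepare_inputs_reference(cu_target_query_lens: torch.Tensor,
                                 num_rejected_tokens: torch.Tensor,
                                 force_one_token: bool = False):
        query_len_per_req = (cu_target_query_lens[1:] -
                             cu_target_query_lens[:-1])
        num_tokens_per_req = query_len_per_req - num_rejected_tokens
        if force_one_token:
            # One surviving position per request: its last accepted token.
            cu_num_tokens = torch.arange(cu_target_query_lens.size(0),
                                         dtype=torch.int32)
            token_indices = cu_target_query_lens[:-1] + num_tokens_per_req - 1
        else:
            # Cumulative counts of kept tokens, then one index per kept token.
            cu_num_tokens = torch.zeros_like(cu_target_query_lens)
            cu_num_tokens[1:] = torch.cumsum(num_tokens_per_req, dim=0)
            token_indices = torch.cat([
                torch.arange(start, start + n, dtype=torch.int32)
                for start, n in zip(cu_target_query_lens[:-1].tolist(),
                                    num_tokens_per_req.tolist())
            ])
        return cu_num_tokens, token_indices

    cu_lens = torch.tensor([0, 3, 7, 12], dtype=torch.int32)
    rejected = torch.tensor([1, 0, 2], dtype=torch.int32)
    # Default path: cu_num_tokens = [0, 2, 6, 9],
    # token_indices = [0, 1, 3, 4, 5, 6, 7, 8, 9].
    print(prepare_inputs_reference(cu_lens, rejected))
    # force_one_token path: cu_num_tokens = [0, 1, 2, 3],
    # token_indices = [1, 6, 9] (last accepted position of each request).
    print(prepare_inputs_reference(cu_lens, rejected, force_one_token=True))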
@@ -160,7 +168,9 @@ def propose(
            common_prefix_len=0,
            common_attn_metadata=common_attn_metadata,
        )
-
+        # When proposing, we set the prefill query_lens to 1.
+        if attn_metadata.prefill is not None:
+            attn_metadata.prefill.query_lens[:] = 1
        with set_ascend_forward_context(attn_metadata, self.vllm_config):
            hidden_states = self.model(
                input_ids=input_ids,
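The query_lens override is an in-place tensor write, so it mutates the metadata that set_ascend_forward_context will see. A hypothetical stand-in for the prefill metadata (the real Ascend attention metadata class is not part of this diff) shows the effect:

    from dataclasses import dataclass

    import torch

    @dataclass
    class FakePrefillMetadata:
        # Hypothetical stand-in for attn_metadata.prefill.
        query_lens: torch.Tensor

    prefill = FakePrefillMetadata(query_lens=torch.tensor([3, 4, 5]))
    prefill.query_lens[:] = 1  # every request becomes a single-token query
    assert prefill.query_lens.tolist() == [1, 1, 1]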