Commit 7e2c722

Add Continuous Decoding support in GQA (microsoft#21523)
### Description

This PR adds support for Continuous Decoding with batch_size = 1 input. From now on, GQA can take arbitrary-length input by passing seqlens_k as total_sequence_length - 1 and using the sequence length of qkv as new_sequence_length. **This change does not affect the default behavior of GQA.**

### Motivation and Context

Prior to this change it was impossible to support sequence_length > 1 inputs when past context was given. This use case is essential to making continuous decoding work, which is one of our current efforts in ORT-GenAI.
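For illustration, here is a minimal sketch of how a caller would derive the two length inputs for a continuous-decoding step under the new convention. This is not ORT API code and the names are hypothetical; it only encodes the relationship stated above (seqlens_k = total_sequence_length - 1 per batch entry, total = past + new).

```cpp
// Hypothetical helper (illustration only, not part of ORT): derive GQA's length
// inputs for a continuous-decoding step with batch_size == 1.
#include <cstdint>
#include <vector>

struct GqaLengthInputs {
  std::vector<int32_t> seqlens_k;  // shape (batch_size); each entry = total_sequence_length - 1
  int32_t total_sequence_length;   // scalar; past + new tokens
};

GqaLengthInputs MakeLengthInputs(int32_t past_sequence_length, int32_t new_sequence_length) {
  GqaLengthInputs in;
  in.total_sequence_length = past_sequence_length + new_sequence_length;
  in.seqlens_k = {in.total_sequence_length - 1};  // batch_size == 1
  return in;
}

// Example: 16 cached tokens plus 4 new tokens -> seqlens_k = {19},
// total_sequence_length = 20, while qkv are fed with sequence_length = 4.
```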
1 parent 59b7b6b · commit 7e2c722

17 files changed: +498 −502 lines

docs/ContribOperators.md

Lines changed: 4 additions & 2 deletions
@@ -2521,6 +2521,8 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
   Only supports causal and local attention.
   Supports rotary position embedding for CPU and CUDA.
   Supports packed input for CPU and CUDA.
+  Supports continuous decoding for batch_size == 1 for CPU and CUDA.
+
 
 #### Version
 
@@ -2561,9 +2563,9 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 <dt><tt>past_value</tt> (optional) : T</dt>
 <dd>past state value with support for format BNSH. When past_value uses same tensor as present_value(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.</dd>
 <dt><tt>seqlens_k</tt> : M</dt>
-<dd>1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.</dd>
+<dd>1D Tensor of shape (batch_size). Equivalent to (total_sequence_lengths - 1).</dd>
 <dt><tt>total_sequence_length</tt> : M</dt>
-<dd>Scalar tensor of total sequence length (past + new).</dd>
+<dd>Scalar tensor equivalent to the maximum total sequence length (past + new) of the batch. Used for checking inputs and determining prompt vs token generation case.</dd>
 <dt><tt>cos_cache</tt> (optional) : T</dt>
 <dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
 <dt><tt>sin_cache</tt> (optional) : T</dt>

onnxruntime/contrib_ops/cpu/bert/attention_common.h

Lines changed: 2 additions & 1 deletion
@@ -114,7 +114,8 @@ struct GroupQueryAttentionParameters {
   int local_window_size;
   bool kv_share_buffer;
   bool is_packed_qkv;
-  bool is_prompt;  // determines if seqlens_k is past or kv sequence length tensor
+  bool is_subsequent_prompt;  // indicates whether we have past context and seqlen > 1
+  bool is_first_prompt;       // indicates whether this is first decoding step
   bool do_rotary;
   bool rotary_interleaved;
   bool use_smooth_softmax;

onnxruntime/contrib_ops/cpu/bert/attention_helper.h

Lines changed: 4 additions & 7 deletions
@@ -236,19 +236,16 @@ T* ConcatStateChunkGQA(const T* past,
                        size_t past_buff_chunk_length,
                        size_t past_chunk_length,
                        size_t new_chunk_length,
-                       bool is_prompt,
                        bool past_present_share_buffer,
                        std::ptrdiff_t i) {
   T* start = present + i * present_buff_chunk_length;
 
   T* p = start;
-  if (!is_prompt) {
-    if (!past_present_share_buffer) {
-      const T* src_past = past + i * past_buff_chunk_length;
-      memcpy(p, src_past, past_chunk_length * sizeof(T));
-    }
-    p += past_chunk_length;
+  if (!past_present_share_buffer && past_chunk_length > 0) {
+    const T* src_past = past + i * past_buff_chunk_length;
+    memcpy(p, src_past, past_chunk_length * sizeof(T));
   }
+  p += past_chunk_length;
 
   memcpy(p, chunk, new_chunk_length * sizeof(T));
   return start;
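For reference, a standalone sketch of the behavior after this refactor (simplified signature, not the ORT helper itself): with the is_prompt flag removed, callers pass past_chunk_length == 0 for the first prompt, which naturally skips the past copy and only appends the new chunk.

```cpp
// Simplified stand-in for ConcatStateChunkGQA (illustration only): copy the past
// chunk when buffers are not shared and there is past context, then append the
// new chunk. Passing past_len == 0 covers the first-prompt case without a flag.
#include <cassert>
#include <cstring>
#include <vector>

template <typename T>
T* ConcatChunk(const T* past, const T* chunk, T* present,
               size_t past_len, size_t new_len, bool share_buffer) {
  T* p = present;
  if (!share_buffer && past_len > 0) {
    std::memcpy(p, past, past_len * sizeof(T));
  }
  p += past_len;  // with a shared buffer the past is already in place
  std::memcpy(p, chunk, new_len * sizeof(T));
  return present;
}

int main() {
  std::vector<float> past = {1, 2}, chunk = {3, 4, 5}, present(5);
  ConcatChunk(past.data(), chunk.data(), present.data(), past.size(), chunk.size(), false);
  assert(present[0] == 1.0f && present[4] == 5.0f);
  // First prompt: past_len == 0, only the new chunk is written.
  std::vector<float> prompt_out(3);
  ConcatChunk<float>(nullptr, chunk.data(), prompt_out.data(), 0, chunk.size(), false);
  assert(prompt_out[2] == 5.0f);
  return 0;
}
```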

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h

Lines changed: 89 additions & 88 deletions
Large diffs are not rendered by default.

onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc

Lines changed: 23 additions & 8 deletions
@@ -45,7 +45,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   const Tensor* past_key = context->Input<Tensor>(3);
   const Tensor* past_value = context->Input<Tensor>(4);
   const Tensor* seqlens_k = context->Input<Tensor>(5);
-  const Tensor* total_seqlen = context->Input<Tensor>(6);
+  const Tensor* total_seqlen_tensor = context->Input<Tensor>(6);
   const Tensor* cos_cache = context->Input<Tensor>(7);
   const Tensor* sin_cache = context->Input<Tensor>(8);
 
@@ -61,7 +61,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
                                                 num_heads_,
                                                 kv_num_heads_,
                                                 seqlens_k,
-                                                total_seqlen,
+                                                total_seqlen_tensor,
                                                 scale_,
                                                 softcap_));
 
@@ -103,6 +103,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
   }
 
   if (do_rotary_) {
+    // Initialize rotary parameters
     rotary_embedding_helper::RotaryParameters rotary_params = {};
     rotary_params.batch_size = batch_size;
     rotary_params.sequence_length = sequence_length;
@@ -114,17 +115,29 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
     rotary_params.seq_stride = head_size;
     rotary_params.head_stride = sequence_length * rotary_params.seq_stride;
     rotary_params.batch_stride = (packed_qkv ? (num_heads_ + 2 * kv_num_heads_) : num_heads_) * rotary_params.head_stride;
-    rotary_params.position_ids_format = sequence_length == 1 ? 1 : 0;
+    rotary_params.position_ids_format = !parameters.is_first_prompt ? 1 : 0;
     rotary_params.transposed = true;
     auto* tp = context->GetOperatorThreadPool();
-    std::vector<int64_t> pos_ids(sequence_length == 1 ? batch_size : 1);
-    if (sequence_length == 1) {
+    // Generate position ids
+    const int pos_ids_size = parameters.is_first_prompt ? 1 : batch_size * sequence_length;
+    std::vector<int64_t> pos_ids(pos_ids_size);
+    if (parameters.is_first_prompt) {
+      pos_ids[0] = static_cast<int64_t>(0);
+    } else {
+      // Note: As of now, interactive decoding supports only batch size 1 and token generation supports only sequence length 1.
       for (int b = 0; b < batch_size; b++) {
-        pos_ids[b] = static_cast<int64_t>(seqlens_k->Data<int32_t>()[b]);
+        const int total_seqlen = seqlens_k->Data<int32_t>()[b] + 1;
+        const int past_seqlen = total_seqlen - sequence_length;
+        for (int s = 0; s < sequence_length; s++) {
+          if (past_seqlen + s < total_seqlen) {
+            pos_ids[b * sequence_length + s] = static_cast<int64_t>(past_seqlen) + s;
+          } else {
+            pos_ids[b * sequence_length + s] = static_cast<int64_t>(1);
+          }
+        }
       }
-    } else {
-      pos_ids[0] = static_cast<int64_t>(0);
     }
+    // Initialize separate buffers for rotary embeddings
     const T* q_input;
     const T* k_input;
     T* q_rotary;
@@ -149,6 +162,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
       Q = RotaryQ;
       K = RotaryK;
     }
+    // Run rotary embedding for Q and K
     ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
                                               pos_ids.data(), cos_cache->Data<T>(),
                                               sin_cache->Data<T>(), q_rotary, rotary_interleaved_));
@@ -161,6 +175,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
     ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, k_input,
                                               pos_ids.data(), cos_cache->Data<T>(),
                                               sin_cache->Data<T>(), k_rotary, rotary_interleaved_));
+    // Pack V into rotary QKV buffer
     if (packed_qkv) {
       const T* v_input = k_input + kv_num_heads_ * sequence_length * head_size;
       T* v_rotary = k_rotary + kv_num_heads_ * sequence_length * head_size;
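To make the position-id logic above easier to follow, here is a standalone sketch of the non-first-prompt path (not the kernel code; the helper name is made up): the past length is recovered from seqlens_k[b] = total_sequence_length - 1, and each new token at offset s gets position past_seqlen + s.

```cpp
// Illustrative helper (not ORT code): compute RoPE position ids for a
// continuous-decoding step from seqlens_k and the new sequence length.
#include <cstdint>
#include <vector>

std::vector<int64_t> MakePositionIds(const std::vector<int32_t>& seqlens_k,
                                     int sequence_length) {
  const int batch_size = static_cast<int>(seqlens_k.size());
  std::vector<int64_t> pos_ids(static_cast<size_t>(batch_size) * sequence_length);
  for (int b = 0; b < batch_size; ++b) {
    const int total_seqlen = seqlens_k[b] + 1;             // seqlens_k = total - 1
    const int past_seqlen = total_seqlen - sequence_length;
    for (int s = 0; s < sequence_length; ++s) {
      pos_ids[b * sequence_length + s] = static_cast<int64_t>(past_seqlen) + s;
    }
  }
  return pos_ids;
}

// Example: 4 cached tokens and 3 new tokens -> seqlens_k = {6}, sequence_length = 3,
// giving positions {4, 5, 6}.
```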

onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h

Lines changed: 28 additions & 8 deletions
@@ -168,14 +168,13 @@ Status CheckInputs(const Tensor* query,
                            "Input 'past_key' and 'past_value' shall be both present or both absent.");
   }
 
-  // Check seqlens_k tensor (holding past seqlen for token gen)
-  const auto& seqlens_dim = seqlens_k->Shape().GetDims();
-  if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
+  const auto& seqlens_k_dim = seqlens_k->Shape().GetDims();
+  if (seqlens_k_dim.size() != 1 && seqlens_k_dim[0] != batch_size) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "seqlens_k must be shape (batch_size).");
   }
 
-  // Set present sequence length and kv_share_buffer from input total_seqlen tensor
+  // Set present sequence length from input total_seqlen tensor
   if (!onnxruntime::IsScalarOr1ElementVector(total_seqlen)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "total_sequence_length tensor must be of one element.");
@@ -195,11 +194,11 @@ Status CheckInputs(const Tensor* query,
     }
     if (cos_dims[0] < total_sequence_length) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "cos_cache dimension 0 should be not be less than total_sequence_length.");
+                             "cos_cache dimension 0 shall not be less than total_sequence_length.");
     }
     if (sin_dims[0] < total_sequence_length) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                             "sin_cache dimension 0 should be not be less than total_sequence_length.");
+                             "sin_cache dimension 0 shall not be less than total_sequence_length.");
     }
     if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
       return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -219,14 +218,34 @@ Status CheckInputs(const Tensor* query,
                            "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
   }
 
-  bool is_prompt = sequence_length != 1;
+  bool is_subsequent_prompt = false;
+  if (sequence_length > 1 && sequence_length != total_sequence_length) {
+    if (batch_size != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "batch_size must be 1 when sequence_length > 1 and past context is given.");
+    }
+    is_subsequent_prompt = true;
+  }
+
+  bool is_first_prompt;
+  if (is_subsequent_prompt) {
+    is_first_prompt = false;  // irrelevant for interactive decoding
+  } else {
+    // If not interactive, sequence_length is 1 for token gen and arbitrarily large for prompt
+    is_first_prompt = (sequence_length == total_sequence_length);
+    if (!is_first_prompt && sequence_length != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "sequence_length shall be 1 when it is not prompt.");
+    }
+  }
 
   if (parameters != nullptr) {
     GroupQueryAttentionParameters* output_parameters = reinterpret_cast<GroupQueryAttentionParameters*>(parameters);
     output_parameters->batch_size = batch_size;
     output_parameters->sequence_length = sequence_length;                  // sequence length of Q
     output_parameters->seqlen_past_kv_cache = past_sequence_length;        // max sequence length of past kv tensors
     output_parameters->seqlen_present_kv_cache = present_sequence_length;  // max sequence length of present kv tensors
+    output_parameters->total_sequence_length = total_sequence_length;      // total sequence length
     output_parameters->hidden_size = q_hidden_size;
     output_parameters->num_heads = num_heads;
     output_parameters->head_size = head_size;
@@ -235,7 +254,8 @@ Status CheckInputs(const Tensor* query,
     output_parameters->rotary_dim = rotary_dim;
     output_parameters->is_packed_qkv = is_packed_qkv;
     output_parameters->is_unidirectional = true;
-    output_parameters->is_prompt = is_prompt;
+    output_parameters->is_subsequent_prompt = is_subsequent_prompt;
+    output_parameters->is_first_prompt = is_first_prompt;
     output_parameters->scale = scale;
     output_parameters->softcap = softcap;
     output_parameters->qkv_format = qkv_format;
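As a summary of the new input validation, here is a sketch of the three cases CheckInputs now distinguishes. The names are hypothetical and this is not the ORT helper; it only mirrors the branching shown in the diff above.

```cpp
// Illustration only: classify a GQA call into first prompt, subsequent prompt
// (continuous decoding with past context), or single-token generation.
#include <stdexcept>

enum class GqaPhase { kFirstPrompt, kSubsequentPrompt, kTokenGeneration };

GqaPhase ClassifyPhase(int batch_size, int sequence_length, int total_sequence_length) {
  if (sequence_length > 1 && sequence_length != total_sequence_length) {
    if (batch_size != 1) {
      throw std::invalid_argument("batch_size must be 1 when sequence_length > 1 and past context is given");
    }
    return GqaPhase::kSubsequentPrompt;  // past context plus multiple new tokens
  }
  if (sequence_length == total_sequence_length) {
    return GqaPhase::kFirstPrompt;  // no past context
  }
  if (sequence_length != 1) {
    throw std::invalid_argument("sequence_length shall be 1 when it is not prompt");
  }
  return GqaPhase::kTokenGeneration;  // single new token with past context
}
```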

onnxruntime/contrib_ops/cpu/sparse/sparse_attention_base.h

Lines changed: 2 additions & 2 deletions
@@ -184,7 +184,7 @@ class SparseAttentionBase {
       // Concatenate past_k + k -> present_k
       // TODO: avoid copying mutiple times for a group.
       k = ConcatStateChunkGQA(past_key, k, present_key, present_buff_chunk_length, past_buff_chunk_length,
-                              past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+                              is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
                               i / kv_num_heads_factor);
 
       // Compute Q*K' + AttentionMask
@@ -365,7 +365,7 @@ class SparseAttentionBase {
 
       // Concatenate past_v + v -> present_v
       v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length,
-                              past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer,
+                              is_prompt ? 0 : past_chunk_length, kv_input_chunk_length, past_present_share_buffer,
                               i / kv_num_heads_factor);
 
       DUMP_CPU_TENSOR("present_value", v, total_seq_len, head_size);

onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@ struct RightPaddingBatchHook {
 
     auto lse_dim = ceil_div((int32_t)(p.num_queries), kAlignLSE) * kAlignLSE;
 
-    // Advance to current batch - in case of different sequence lengths
     if (p.seqlen_k_ptr) {
       p.num_keys = p.seqlen_k_ptr[batch_id];
     }

onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc

Lines changed: 2 additions & 3 deletions
@@ -5,7 +5,7 @@
 #include "core/platform/env_var_utils.h"
 #include "contrib_ops/cuda/bert/group_query_attention_impl.h"
 #include "contrib_ops/cuda/bert/group_query_attention.h"
-#include "contrib_ops/cuda/bert/group_query_attention_helper.h"
+#include "contrib_ops/cpu/bert/group_query_attention_helper.h"
 #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
 
@@ -95,7 +95,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                                 kv_num_heads_,
                                                 seqlens_k,
                                                 total_seqlen,
-                                                is_past_bsnh_,
                                                 scale_,
                                                 softcap_,
                                                 device_prop.maxThreadsPerBlock));
@@ -253,7 +252,7 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
     data.out_accum = reinterpret_cast<CudaT*>(out_accum_buffer.get());
   }
   if (seqlens_k_buffer != nullptr) {
-    data.seqlens_k_total = reinterpret_cast<int*>(seqlens_k_buffer.get());
+    data.seqlens_k_buff = reinterpret_cast<int*>(seqlens_k_buffer.get());
   }
   // Memory Efficient Buffers
   if (k_buffer != nullptr) {
