Skip to content

Commit 95bae74

Browse files
bottler authored and facebook-github-bot committed
add actual_batch_size to rope_qkv_varseq_prefill (#4380)
Summary: Pull Request resolved: #4380 X-link: facebookresearch/FBGEMM#1450 The validation pass in Parallel Decoding uses prefill logic inside a cudagraph, and can need this for correctness. Reviewed By: jianyuh Differential Revision: D76900768 fbshipit-source-id: c87f057654a2839a3416d48b047f386fc828fe6a
1 parent ef65192 commit 95bae74

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ at::Tensor rope_qkv_varseq_prefill_meta(
5252
std::optional<int64_t> /* num_groups */,
5353
std::optional<at::Tensor> /* block_tables */,
5454
int64_t /* page_size */,
55+
std::optional<at::Tensor> /* actual_batch_size */,
5556
std::optional<at::Tensor> /* varseq_cache_seqpos */,
5657
int64_t /* cache_logical_dtype_int */,
5758
bool /* rope_scaling */,
@@ -109,6 +110,7 @@ at::Tensor nope_qkv_varseq_prefill_meta(
109110
at::Tensor /* varseq_seqpos */,
110111
std::optional<at::Tensor> /* block_tables */,
111112
int64_t /* page_size */,
113+
std::optional<at::Tensor> /* actual_batch_size */,
112114
std::optional<at::Tensor> /* varseq_cache_seqpos */,
113115
int64_t /* cache_logical_dtype_int */,
114116
std::optional<int64_t> /* num_groups */,
@@ -160,6 +162,7 @@ at::Tensor xpos_qkv_varseq_prefill_meta(
160162
std::optional<int64_t> /* num_groups */,
161163
std::optional<at::Tensor> /* block_tables */,
162164
int64_t /* page_size */,
165+
std::optional<at::Tensor> /* actual_batch_size */,
163166
std::optional<at::Tensor> /* varseq_cache_seqpos */,
164167
int64_t /* cache_logical_dtype_int */,
165168
bool /* rope_scaling */,

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.cu

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1204,6 +1204,7 @@ at::Tensor nope_qkv_varseq_prefill(
12041204
at::Tensor varseq_seqpos,
12051205
std::optional<at::Tensor> block_tables,
12061206
int64_t page_size,
1207+
std::optional<at::Tensor> actual_batch_size,
12071208
std::optional<at::Tensor> varseq_cache_seqpos,
12081209
int64_t cache_logical_dtype_int,
12091210
std::optional<int64_t> num_groups,
@@ -1252,6 +1253,11 @@ at::Tensor nope_qkv_varseq_prefill(
12521253
block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
12531254
block_tables_b_stride = block_tables.value().stride(0);
12541255
}
1256+
int64_t* actual_batch_size_ptr = nullptr;
1257+
if (actual_batch_size.has_value()) {
1258+
actual_batch_size_ptr =
1259+
static_cast<int64_t*>(actual_batch_size.value().data_ptr());
1260+
}
12551261
CacheLogicalDtype cache_logical_dtype =
12561262
static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
12571263
if (cache_K.dtype() == at::kBFloat16) {
@@ -1273,7 +1279,7 @@ at::Tensor nope_qkv_varseq_prefill(
12731279
block_tables_b_stride,
12741280
varseq_cache_seqpos_
12751281
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1276-
nullptr,
1282+
actual_batch_size_ptr,
12771283
update_kv);
12781284
C10_CUDA_KERNEL_LAUNCH_CHECK();
12791285
} else {
@@ -1356,7 +1362,7 @@ at::Tensor nope_qkv_varseq_prefill(
13561362
block_tables_b_stride,
13571363
(varseq_cache_seqpos_
13581364
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
1359-
nullptr,
1365+
actual_batch_size_ptr,
13601366
false,
13611367
0,
13621368
0,
@@ -1386,7 +1392,7 @@ at::Tensor nope_qkv_varseq_prefill(
13861392
block_tables_b_stride,
13871393
(varseq_cache_seqpos_
13881394
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
1389-
nullptr,
1395+
actual_batch_size_ptr,
13901396
false,
13911397
0,
13921398
0,
@@ -1614,6 +1620,7 @@ at::Tensor rope_qkv_varseq_prefill(
16141620
std::optional<int64_t> num_groups,
16151621
std::optional<at::Tensor> block_tables,
16161622
int64_t page_size,
1623+
std::optional<at::Tensor> actual_batch_size,
16171624
std::optional<at::Tensor> varseq_cache_seqpos,
16181625
int64_t cache_logical_dtype_int,
16191626
bool rope_scaling = false,
@@ -1669,6 +1676,11 @@ at::Tensor rope_qkv_varseq_prefill(
16691676
block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
16701677
block_tables_b_stride = block_tables.value().stride(0);
16711678
}
1679+
int64_t* actual_batch_size_ptr = nullptr;
1680+
if (actual_batch_size.has_value()) {
1681+
actual_batch_size_ptr =
1682+
static_cast<int64_t*>(actual_batch_size.value().data_ptr());
1683+
}
16721684
if (cache_K.dtype() == at::kBFloat16) {
16731685
rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>
16741686
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -1690,7 +1702,7 @@ at::Tensor rope_qkv_varseq_prefill(
16901702
block_tables_b_stride,
16911703
varseq_cache_seqpos_
16921704
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1693-
nullptr,
1705+
actual_batch_size_ptr,
16941706
rope_scaling,
16951707
old_context_len,
16961708
scaling_factor,
@@ -1780,7 +1792,7 @@ at::Tensor rope_qkv_varseq_prefill(
17801792
block_tables_b_stride,
17811793
(varseq_cache_seqpos_
17821794
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
1783-
nullptr,
1795+
actual_batch_size_ptr,
17841796
rope_scaling,
17851797
old_context_len,
17861798
scaling_factor,
@@ -1810,7 +1822,7 @@ at::Tensor rope_qkv_varseq_prefill(
18101822
block_tables_b_stride,
18111823
(varseq_cache_seqpos_
18121824
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
1813-
nullptr,
1825+
actual_batch_size_ptr,
18141826
rope_scaling,
18151827
old_context_len,
18161828
scaling_factor,
@@ -1840,6 +1852,7 @@ at::Tensor xpos_qkv_varseq_prefill(
18401852
std::optional<int64_t> num_groups,
18411853
std::optional<at::Tensor> block_tables,
18421854
int64_t page_size,
1855+
std::optional<at::Tensor> actual_batch_size,
18431856
std::optional<at::Tensor> varseq_cache_seqpos,
18441857
int64_t cache_logical_dtype_int,
18451858
bool rope_scaling = false,
@@ -1876,6 +1889,11 @@ at::Tensor xpos_qkv_varseq_prefill(
18761889
block_tables_b_stride = block_tables.value().stride(0);
18771890
}
18781891
1892+
int64_t* actual_batch_size_ptr = nullptr;
1893+
if (actual_batch_size.has_value()) {
1894+
actual_batch_size_ptr =
1895+
static_cast<int64_t*>(actual_batch_size.value().data_ptr());
1896+
}
18791897
if (cache_K.dtype() == at::kBFloat16) {
18801898
rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>
18811899
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -1897,7 +1915,7 @@ at::Tensor xpos_qkv_varseq_prefill(
18971915
block_tables_b_stride,
18981916
varseq_cache_seqpos_
18991917
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
1900-
nullptr,
1918+
actual_batch_size_ptr,
19011919
rope_scaling,
19021920
old_context_len,
19031921
scaling_factor,
@@ -1934,7 +1952,7 @@ at::Tensor xpos_qkv_varseq_prefill(
19341952
block_tables_b_stride,
19351953
(varseq_cache_seqpos_
19361954
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
1937-
nullptr,
1955+
actual_batch_size_ptr,
19381956
rope_scaling,
19391957
old_context_len,
19401958
scaling_factor,
@@ -1964,7 +1982,7 @@ at::Tensor xpos_qkv_varseq_prefill(
19641982
block_tables_b_stride,
19651983
(varseq_cache_seqpos_
19661984
.packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
1967-
nullptr,
1985+
actual_batch_size_ptr,
19681986
rope_scaling,
19691987
old_context_len,
19701988
scaling_factor,

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ at::Tensor nope_qkv_varseq_prefill(
2020
at::Tensor varseq_seqpos,
2121
std::optional<at::Tensor> block_tables,
2222
int64_t page_size,
23+
std::optional<at::Tensor> actual_batch_size,
2324
std::optional<at::Tensor> varseq_cache_seqpos,
2425
int64_t cache_logical_dtype_int,
2526
std::optional<int64_t> num_groups,
@@ -62,6 +63,7 @@ at::Tensor rope_qkv_varseq_prefill(
6263
std::optional<int64_t> num_groups,
6364
std::optional<at::Tensor> block_tables,
6465
int64_t page_size,
66+
std::optional<at::Tensor> actual_batch_size,
6567
std::optional<at::Tensor> varseq_cache_seqpos,
6668
int64_t cache_logical_dtype_int,
6769
bool rope_scaling,
@@ -118,6 +120,7 @@ at::Tensor xpos_qkv_varseq_prefill(
118120
std::optional<int64_t> num_groups,
119121
std::optional<at::Tensor> block_tables,
120122
int64_t page_size,
123+
std::optional<at::Tensor> actual_batch_size,
121124
std::optional<at::Tensor> varseq_cache_seqpos,
122125
int64_t cache_logical_dtype_int,
123126
bool rope_scaling,

fbgemm_gpu/experimental/gen_ai/src/kv_cache/kv_cache_defs.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,16 @@ namespace fbgemm_gpu {
1616

1717
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
1818
m.def("rope_qkv_varseq_prefill(Tensor XQ, Tensor(a!)? XK, Tensor? XV, Tensor(b!) cache_K, Tensor(c!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
19-
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
19+
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192"
2020
", float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool write_k_back=False, bool k_norm=False,bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor");
2121
m.def("rope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
2222
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None) -> Tensor");
2323
m.def("nope_qkv_varseq_prefill(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, Tensor? block_tables=None, int page_size=" STRING(
24-
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor");
24+
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None, Tensor?kv_quant_scale_precomputed=None) -> Tensor");
2525
m.def("nope_qkv_decoding(Tensor XQ, Tensor? XK, Tensor? XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, Tensor? block_tables=None, int page_size=" STRING(
2626
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, int? num_groups=1, Tensor? qparam_k=None, Tensor? qparam_v=None, bool k_norm=False, bool update_kv=True, Tensor?amax_qkv=None) -> Tensor");
2727
m.def("xpos_qkv_varseq_prefill(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor varseq_batch, Tensor varseq_seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
28-
DEFAULT_PAGE_SIZE) ", Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
28+
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? varseq_cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
2929
m.def("xpos_qkv_decoding(Tensor XQ, Tensor XK, Tensor XV, Tensor(a!) cache_K, Tensor(b!) cache_V, Tensor seqpos, float theta, float gamma, float scale_base, float exponent_offset, int? num_groups=1, Tensor? block_tables=None, int page_size=" STRING(
3030
DEFAULT_PAGE_SIZE) ", Tensor? actual_batch_size=None, Tensor? batch=None, Tensor? cache_seqpos=None, int cache_logical_dtype_int=0, bool rope_scaling=False, int old_context_len=8192, float scaling_factor=16, float lo_freq_factor=1, float hi_freq_factor=32, Tensor? qparam_k=None, Tensor? qparam_v=None) -> Tensor");
3131
m.def(

0 commit comments

Comments (0)