@@ -1204,6 +1204,7 @@ at::Tensor nope_qkv_varseq_prefill(
     at::Tensor varseq_seqpos,
     std::optional<at::Tensor> block_tables,
     int64_t page_size,
+    std::optional<at::Tensor> actual_batch_size,
     std::optional<at::Tensor> varseq_cache_seqpos,
     int64_t cache_logical_dtype_int,
     std::optional<int64_t> num_groups,
@@ -1252,6 +1253,11 @@ at::Tensor nope_qkv_varseq_prefill(
     block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
     block_tables_b_stride = block_tables.value().stride(0);
   }
+  int64_t* actual_batch_size_ptr = nullptr;
+  if (actual_batch_size.has_value()) {
+    actual_batch_size_ptr =
+        static_cast<int64_t*>(actual_batch_size.value().data_ptr());
+  }
   CacheLogicalDtype cache_logical_dtype =
       static_cast<CacheLogicalDtype>(cache_logical_dtype_int);
   if (cache_K.dtype() == at::kBFloat16) {
@@ -1273,7 +1279,7 @@ at::Tensor nope_qkv_varseq_prefill(
             block_tables_b_stride,
             varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            nullptr,
+            actual_batch_size_ptr,
             update_kv);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
@@ -1356,7 +1362,7 @@ at::Tensor nope_qkv_varseq_prefill(
             block_tables_b_stride,
             (varseq_cache_seqpos_
                  .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-            nullptr,
+            actual_batch_size_ptr,
             false,
             0,
             0,
@@ -1386,7 +1392,7 @@ at::Tensor nope_qkv_varseq_prefill(
             block_tables_b_stride,
             (varseq_cache_seqpos_
                  .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-            nullptr,
+            actual_batch_size_ptr,
             false,
             0,
             0,
@@ -1614,6 +1620,7 @@ at::Tensor rope_qkv_varseq_prefill(
     std::optional<int64_t> num_groups,
     std::optional<at::Tensor> block_tables,
     int64_t page_size,
+    std::optional<at::Tensor> actual_batch_size,
     std::optional<at::Tensor> varseq_cache_seqpos,
     int64_t cache_logical_dtype_int,
     bool rope_scaling = false,
@@ -1669,6 +1676,11 @@ at::Tensor rope_qkv_varseq_prefill(
     block_tables_ptr = static_cast<int32_t*>(block_tables.value().data_ptr());
     block_tables_b_stride = block_tables.value().stride(0);
   }
+  int64_t* actual_batch_size_ptr = nullptr;
+  if (actual_batch_size.has_value()) {
+    actual_batch_size_ptr =
+        static_cast<int64_t*>(actual_batch_size.value().data_ptr());
+  }
   if (cache_K.dtype() == at::kBFloat16) {
     rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::ROPE>
         <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -1690,7 +1702,7 @@ at::Tensor rope_qkv_varseq_prefill(
             block_tables_b_stride,
             varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            nullptr,
+            actual_batch_size_ptr,
             rope_scaling,
             old_context_len,
             scaling_factor,
@@ -1780,7 +1792,7 @@ at::Tensor rope_qkv_varseq_prefill(
            block_tables_b_stride,
            (varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-           nullptr,
+           actual_batch_size_ptr,
            rope_scaling,
            old_context_len,
            scaling_factor,
@@ -1810,7 +1822,7 @@ at::Tensor rope_qkv_varseq_prefill(
            block_tables_b_stride,
            (varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-           nullptr,
+           actual_batch_size_ptr,
            rope_scaling,
            old_context_len,
            scaling_factor,
@@ -1840,6 +1852,7 @@ at::Tensor xpos_qkv_varseq_prefill(
     std::optional<int64_t> num_groups,
     std::optional<at::Tensor> block_tables,
     int64_t page_size,
+    std::optional<at::Tensor> actual_batch_size,
     std::optional<at::Tensor> varseq_cache_seqpos,
     int64_t cache_logical_dtype_int,
     bool rope_scaling = false,
@@ -1876,6 +1889,11 @@ at::Tensor xpos_qkv_varseq_prefill(
     block_tables_b_stride = block_tables.value().stride(0);
   }
 
+  int64_t* actual_batch_size_ptr = nullptr;
+  if (actual_batch_size.has_value()) {
+    actual_batch_size_ptr =
+        static_cast<int64_t*>(actual_batch_size.value().data_ptr());
+  }
   if (cache_K.dtype() == at::kBFloat16) {
     rope_xpos_qkv_varseq_prefill_kernel<PositionEmbeddingMode::XPOS>
         <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
@@ -1897,7 +1915,7 @@ at::Tensor xpos_qkv_varseq_prefill(
             block_tables_b_stride,
             varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>(),
-            nullptr,
+            actual_batch_size_ptr,
             rope_scaling,
             old_context_len,
             scaling_factor,
@@ -1934,7 +1952,7 @@ at::Tensor xpos_qkv_varseq_prefill(
            block_tables_b_stride,
            (varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-           nullptr,
+           actual_batch_size_ptr,
            rope_scaling,
            old_context_len,
            scaling_factor,
@@ -1964,7 +1982,7 @@ at::Tensor xpos_qkv_varseq_prefill(
            block_tables_b_stride,
            (varseq_cache_seqpos_
                 .packed_accessor32<int32_t, 1, at::RestrictPtrTraits>()),
-           nullptr,
+           actual_batch_size_ptr,
            rope_scaling,
            old_context_len,
            scaling_factor,
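
Caller-side note (not part of this diff): the kernels read the new optional argument through a raw int64_t* device pointer, so a minimal sketch of how a caller might build it is shown below. This assumes a single-element int64 tensor on the same CUDA device as the KV caches; the helper name make_actual_batch_size is hypothetical and not from this PR.

// Hypothetical helper: allocate a one-element int64 CUDA tensor holding the
// effective batch size and pass it as the new actual_batch_size argument.
// Passing std::nullopt instead keeps the previous behavior (the kernel sees
// a null pointer, as before this change).
#include <ATen/ATen.h>

at::Tensor make_actual_batch_size(int64_t batch_size) {
  return at::full(
      {1},
      batch_size,
      at::TensorOptions().dtype(at::kLong).device(at::kCUDA));
}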