Skip to content

Commit bd51a6b

Browse files
committed
Revert "llama : add high-throughput mode (ggml-org#14363)"
1 parent bdff33e commit bd51a6b

27 files changed

+444
-1064
lines changed

common/arg.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,14 +1466,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
14661466
params.swa_full = true;
14671467
}
14681468
).set_env("LLAMA_ARG_SWA_FULL"));
1469-
add_opt(common_arg(
1470-
{"--kv-unified", "-kvu"},
1471-
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
1472-
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
1473-
[](common_params & params) {
1474-
params.kv_unified = true;
1475-
}
1476-
).set_env("LLAMA_ARG_KV_SPLIT"));
14771469
add_opt(common_arg(
14781470
{"--no-context-shift"},
14791471
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),

common/common.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1171,7 +1171,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
11711171
cparams.no_perf = params.no_perf;
11721172
cparams.op_offload = !params.no_op_offload;
11731173
cparams.swa_full = params.swa_full;
1174-
cparams.kv_unified = params.kv_unified;
11751174

11761175
cparams.type_k = params.cache_type_k;
11771176
cparams.type_v = params.cache_type_v;

common/common.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,6 @@ struct common_params {
337337
bool no_perf = false; // disable performance metrics
338338
bool ctx_shift = true; // context shift on inifinite text generation
339339
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
340-
bool kv_unified = false; // enable unified KV cache
341340

342341
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
343342
bool use_mmap = true; // use mmap for faster loads

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
107107
const llama_vocab * vocab = llama_model_get_vocab(model);
108108

109109
const int n_ctx_train = llama_model_n_ctx_train(model);
110-
const int n_ctx = llama_n_ctx(ctx);
110+
const int n_ctx = llama_n_ctx(ctx);
111111

112112
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
113113

ggml/src/ggml-cuda/fattn-common.cuh

Lines changed: 19 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,8 @@ typedef void (* fattn_kernel_t)(
3333
const int ne13,
3434
const int ne31,
3535
const int ne32,
36-
const int ne33,
3736
const int nb31,
3837
const int nb32,
39-
const int nb33,
4038
const int nb01,
4139
const int nb02,
4240
const int nb03,
@@ -523,7 +521,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
523521
template<int D, int ncols1, int ncols2> // D == head size
524522
__launch_bounds__(D, 1)
525523
static __global__ void flash_attn_stream_k_fixup(
526-
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
524+
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
527525
constexpr int ncols = ncols1*ncols2;
528526

529527
const int bidx0 = blockIdx.x;
@@ -537,8 +535,8 @@ static __global__ void flash_attn_stream_k_fixup(
537535
const int iter_k = ne11 / FATTN_KQ_STRIDE;
538536
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
539537

540-
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
541-
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
538+
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
539+
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
542540

543541
const bool did_not_have_any_data = kbc0 == kbc0_stop;
544542
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -547,15 +545,14 @@ static __global__ void flash_attn_stream_k_fixup(
547545
return;
548546
}
549547

550-
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
551-
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
552-
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
548+
const int channel = kbc0 / (iter_k*iter_j);
549+
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
553550

554551
if (jt*ncols1 + j >= ne01) {
555552
return;
556553
}
557554

558-
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
555+
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
559556

560557
// Load the partial result that needs a fixup:
561558
float dst_val = 0.0f;
@@ -574,7 +571,7 @@ static __global__ void flash_attn_stream_k_fixup(
574571
int bidx = bidx0 - 1;
575572
int kbc_stop = kbc0;
576573
while(true) {
577-
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
574+
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
578575
if (kbc == kbc_stop) { // Did not have any data.
579576
bidx--;
580577
kbc_stop = kbc;
@@ -620,31 +617,16 @@ static __global__ void flash_attn_combine_results(
620617
const float2 * __restrict__ VKQ_meta,
621618
float * __restrict__ dst,
622619
const int parallel_blocks) {
623-
// Dimension 0: threadIdx.x
624-
// Dimension 1: blockIdx.x
625-
// Dimension 2: blockIdx.y
626-
// Dimension 3: blockIdx.z
627-
// Memory layout is permuted with [0, 2, 1, 3]
628-
629-
const int ne01 = gridDim.x;
630-
const int ne02 = gridDim.y;
631-
632-
const int col = blockIdx.x;
633-
const int head = blockIdx.y;
634-
const int sequence = blockIdx.z;
635-
636-
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
637-
638-
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
639-
VKQ_meta += j_dst_unrolled * parallel_blocks;
640-
dst += j_dst_unrolled * D;
620+
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
621+
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
622+
dst += D * gridDim.z*blockIdx.x;
641623

642624
const int tid = threadIdx.x;
643625
__builtin_assume(tid < D);
644626

645627
extern __shared__ float2 meta[];
646628
for (int i = tid; i < 2*parallel_blocks; i += D) {
647-
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
629+
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
648630
}
649631

650632
__syncthreads();
@@ -662,11 +644,11 @@ static __global__ void flash_attn_combine_results(
662644
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
663645
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
664646

665-
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
647+
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
666648
VKQ_denominator += KQ_max_scale * meta[l].y;
667649
}
668650

669-
dst[tid] = VKQ_numerator / VKQ_denominator;
651+
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
670652
}
671653

672654
[[noreturn]]
@@ -723,6 +705,8 @@ void launch_fattn(
723705

724706
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
725707

708+
GGML_ASSERT(Q->ne[3] == 1);
709+
726710
ggml_cuda_pool & pool = ctx.pool();
727711
cudaStream_t main_stream = ctx.stream();
728712
const int id = ggml_cuda_get_device();
@@ -869,8 +853,8 @@ void launch_fattn(
869853
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
870854
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
871855
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
872-
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
873-
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
856+
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
857+
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
874858
Q->nb[1], Q->nb[2], Q->nb[3],
875859
nb11, nb12, nb13,
876860
nb21, nb22, nb23,
@@ -885,11 +869,11 @@ void launch_fattn(
885869

886870
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
887871
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
888-
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
872+
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
889873
}
890874
} else if (parallel_blocks > 1) {
891875
const dim3 block_dim_combine(DV, 1, 1);
892-
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
876+
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
893877
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
894878

895879
flash_attn_combine_results<DV>

ggml/src/ggml-cuda/fattn-mma-f16.cuh

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1224,10 +1224,8 @@ static __global__ void flash_attn_ext_f16(
12241224
const int ne13,
12251225
const int ne31,
12261226
const int ne32,
1227-
const int ne33,
12281227
const int nb31,
12291228
const int nb32,
1230-
const int nb33,
12311229
const int nb01,
12321230
const int nb02,
12331231
const int nb03,
@@ -1276,8 +1274,8 @@ static __global__ void flash_attn_ext_f16(
12761274
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
12771275

12781276
// kbc == k block continuous, current index in continuous ijk space.
1279-
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1280-
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1277+
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
1278+
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
12811279

12821280
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
12831281
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1287,19 +1285,18 @@ static __global__ void flash_attn_ext_f16(
12871285
int kb0_start = kbc % iter_k;
12881286
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
12891287
while (kbc < kbc_stop && kb0_stop == iter_k) {
1290-
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1291-
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1292-
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1288+
const int channel = kbc / (iter_k*iter_j);
1289+
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
12931290

1294-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1295-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1291+
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1292+
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
12961293
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1297-
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1298-
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
1294+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1295+
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
12991296

1300-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1297+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
13011298

1302-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
1299+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
13031300

13041301
const int kb0_start_kernel = kb0_start * kb_niter;
13051302
const int kb0_stop_kernel = kb0_stop * kb_niter;
@@ -1328,19 +1325,18 @@ static __global__ void flash_attn_ext_f16(
13281325
return;
13291326
}
13301327

1331-
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1332-
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
1333-
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
1328+
const int channel = kbc / (iter_k*iter_j);
1329+
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
13341330

1335-
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
1336-
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
1331+
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
1332+
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
13371333
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1338-
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1339-
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
1334+
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
1335+
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
13401336

1341-
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
1337+
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
13421338

1343-
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
1339+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
13441340

13451341
const int kb0_start_kernel = kb0_start * kb_niter;
13461342
const int kb0_stop_kernel = kb0_stop * kb_niter;

ggml/src/ggml-cuda/fattn-tile-f16.cu

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,8 @@ static __global__ void flash_attn_tile_ext_f16(
3131
const int ne13,
3232
const int ne31,
3333
const int ne32,
34-
const int ne33,
3534
const int nb31,
3635
const int nb32,
37-
const int nb33,
3836
const int nb01,
3937
const int nb02,
4038
const int nb03,
@@ -64,17 +62,15 @@ static __global__ void flash_attn_tile_ext_f16(
6462

6563
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
6664

67-
const int sequence = blockIdx.z / ne02;
68-
const int head = blockIdx.z - sequence*ne02;
6965
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
70-
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
71-
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
72-
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
73-
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
66+
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
67+
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
68+
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
69+
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
7470

7571
const int stride_KV2 = nb11 / sizeof(half2);
7672

77-
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
73+
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
7874
const half slopeh = __float2half(slopef);
7975

8076
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
@@ -259,8 +255,6 @@ static __global__ void flash_attn_tile_ext_f16(
259255
__syncthreads();
260256
}
261257

262-
float2 * dst2 = (float2 *) dst;
263-
264258
#pragma unroll
265259
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
266260
const int j_VKQ = j_VKQ_0 + threadIdx.y;
@@ -272,21 +266,21 @@ static __global__ void flash_attn_tile_ext_f16(
272266
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
273267
kqsum_j = warp_reduce_sum((float)kqsum_j);
274268

275-
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
276-
277269
#pragma unroll
278-
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
279-
const int i0 = i00 + threadIdx.x;
270+
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
271+
const int i0 = i00 + 2*threadIdx.x;
280272

281-
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
273+
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
282274
if (gridDim.y == 1) {
283275
dst_val /= __half2half2(kqsum_j);
284276
}
285-
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
277+
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
278+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
279+
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
286280
}
287281

288282
if (gridDim.y != 1 && threadIdx.x == 0) {
289-
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
283+
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
290284
}
291285
}
292286
#else
@@ -296,8 +290,8 @@ static __global__ void flash_attn_tile_ext_f16(
296290
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
297291
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
298292
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
299-
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
300-
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
293+
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
294+
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
301295
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
302296
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
303297
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);

0 commit comments

Comments
 (0)