@@ -489,8 +489,12 @@ struct result_timings {
489
489
double predicted_per_token_ms;
490
490
double predicted_per_second;
491
491
492
+ // Optional speculative metrics - only included in the JSON output when draft_n > 0
493
+ int32_t draft_n = 0 ;
494
+ int32_t draft_n_accepted = 0 ;
495
+
492
496
json to_json () const {
493
- return {
497
+ json base = {
494
498
{" prompt_n" , prompt_n},
495
499
{" prompt_ms" , prompt_ms},
496
500
{" prompt_per_token_ms" , prompt_per_token_ms},
@@ -501,6 +505,13 @@ struct result_timings {
501
505
{" predicted_per_token_ms" , predicted_per_token_ms},
502
506
{" predicted_per_second" , predicted_per_second},
503
507
};
508
+
509
+ if (draft_n > 0 ) {
510
+ base[" draft_n" ] = draft_n;
511
+ base[" draft_n_accepted" ] = draft_n_accepted;
512
+ }
513
+
514
+ return base;
504
515
}
505
516
};
506
517
@@ -1299,6 +1310,10 @@ struct server_slot {
1299
1310
1300
1311
std::function<void (int )> callback_on_release;
1301
1312
1313
+ // Speculative decoding stats
1314
+ int32_t n_draft_total = 0 ; // Total draft tokens generated
1315
+ int32_t n_draft_accepted = 0 ; // Draft tokens actually accepted
1316
+
1302
1317
void reset () {
1303
1318
SLT_DBG (*this , " %s" , " \n " );
1304
1319
@@ -1315,6 +1330,10 @@ struct server_slot {
1315
1330
1316
1331
generated_tokens.clear ();
1317
1332
generated_token_probs.clear ();
1333
+
1334
+ // clear speculative decoding stats
1335
+ n_draft_total = 0 ;
1336
+ n_draft_accepted = 0 ;
1318
1337
}
1319
1338
1320
1339
bool is_non_causal () const {
@@ -1381,6 +1400,12 @@ struct server_slot {
1381
1400
timings.predicted_per_token_ms = t_token_generation / n_decoded;
1382
1401
timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
1383
1402
1403
+ // Add speculative metrics
1404
+ if (n_draft_total > 0 ) {
1405
+ timings.draft_n = n_draft_total;
1406
+ timings.draft_n_accepted = n_draft_accepted;
1407
+ }
1408
+
1384
1409
return timings;
1385
1410
}
1386
1411
@@ -1428,6 +1453,15 @@ struct server_slot {
1428
1453
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
1429
1454
t_token_generation, n_decoded, t_gen, n_gen_second,
1430
1455
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
1456
+
1457
+ if (n_draft_total > 0 ) {
1458
+ const float draft_ratio = (float ) n_draft_accepted / n_draft_total;
1459
+ SLT_INF (*this ,
1460
+ " \n "
1461
+ " draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n " ,
1462
+ draft_ratio, n_draft_accepted, n_draft_total
1463
+ );
1464
+ }
1431
1465
}
1432
1466
1433
1467
json to_json () const {
@@ -3290,6 +3324,9 @@ struct server_context {
3290
3324
3291
3325
llama_tokens draft = common_speculative_gen_draft (slot.spec , params_spec, slot.cache_tokens , id);
3292
3326
3327
+ // keep track of the total number of drafted tokens (note: counted before the small-draft check below)
3328
+ slot.n_draft_total += draft.size ();
3329
+
3293
3330
// ignore small drafts
3294
3331
if (slot.params .speculative .n_min > (int ) draft.size ()) {
3295
3332
SLT_DBG (slot, " ignoring small draft: %d < %d\n " , (int ) draft.size (), slot.params .speculative .n_min );
@@ -3315,6 +3352,9 @@ struct server_context {
3315
3352
slot.n_past += ids.size ();
3316
3353
slot.n_decoded += ids.size ();
3317
3354
3355
+ // update how many of the drafted tokens were accepted
3356
+ slot.n_draft_accepted += ids.size () - 1 ;
3357
+
3318
3358
slot.cache_tokens .push_back (id);
3319
3359
slot.cache_tokens .insert (slot.cache_tokens .end (), ids.begin (), ids.end () - 1 );
3320
3360
0 commit comments