@@ -220,6 +220,7 @@ struct cmd_params {
220
220
std::vector<int > n_prompt;
221
221
std::vector<int > n_gen;
222
222
std::vector<std::pair<int , int >> n_pg;
223
+ std::vector<std::pair<int , int >> n_gp;
223
224
std::vector<int > n_batch;
224
225
std::vector<int > n_ubatch;
225
226
std::vector<ggml_type> type_k;
@@ -248,6 +249,7 @@ static const cmd_params cmd_params_defaults = {
248
249
/* n_prompt */ {512 },
249
250
/* n_gen */ {128 },
250
251
/* n_pg */ {},
252
+ /* n_gp */ {},
251
253
/* n_batch */ {2048 },
252
254
/* n_ubatch */ {512 },
253
255
/* type_k */ {GGML_TYPE_F16},
@@ -280,6 +282,7 @@ static void print_usage(int /* argc */, char ** argv) {
280
282
printf (" -p, --n-prompt <n> (default: %s)\n " , join (cmd_params_defaults.n_prompt , " ," ).c_str ());
281
283
printf (" -n, --n-gen <n> (default: %s)\n " , join (cmd_params_defaults.n_gen , " ," ).c_str ());
282
284
printf (" -pg <pp,tg> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.n_pg , pair_str), " ," ).c_str ());
285
+ printf (" -gp <pp,tg> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.n_gp , pair_str), " ," ).c_str ());
283
286
printf (" -b, --batch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_batch , " ," ).c_str ());
284
287
printf (" -ub, --ubatch-size <n> (default: %s)\n " , join (cmd_params_defaults.n_ubatch , " ," ).c_str ());
285
288
printf (" -ctk, --cache-type-k <t> (default: %s)\n " , join (transform_to_str (cmd_params_defaults.type_k , ggml_type_name), " ," ).c_str ());
@@ -393,6 +396,17 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
393
396
break ;
394
397
}
395
398
params.n_pg .push_back ({std::stoi (p[0 ]), std::stoi (p[1 ])});
399
+ } else if (arg == " -gp" ) {
400
+ if (++i >= argc) {
401
+ invalid_param = true ;
402
+ break ;
403
+ }
404
+ auto p = string_split<std::string>(argv[i], ' ,' );
405
+ if (p.size () != 2 ) {
406
+ invalid_param = true ;
407
+ break ;
408
+ }
409
+ params.n_gp .push_back ({ std::stoi (p[0 ]), std::stoi (p[1 ]) });
396
410
} else if (arg == " -b" || arg == " --batch-size" ) {
397
411
if (++i >= argc) {
398
412
invalid_param = true ;
@@ -596,6 +610,7 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
596
610
if (params.n_prompt .empty ()) { params.n_prompt = cmd_params_defaults.n_prompt ; }
597
611
if (params.n_gen .empty ()) { params.n_gen = cmd_params_defaults.n_gen ; }
598
612
if (params.n_pg .empty ()) { params.n_pg = cmd_params_defaults.n_pg ; }
613
+ if (params.n_gp .empty ()) { params.n_gp = cmd_params_defaults.n_gp ; }
599
614
if (params.n_batch .empty ()) { params.n_batch = cmd_params_defaults.n_batch ; }
600
615
if (params.n_ubatch .empty ()) { params.n_ubatch = cmd_params_defaults.n_ubatch ; }
601
616
if (params.type_k .empty ()) { params.type_k = cmd_params_defaults.type_k ; }
@@ -614,7 +629,19 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
614
629
return params;
615
630
}
616
631
632
+ enum test_kind_type {
633
+ // measure mean prompt processing rate without token generation
634
+ TEST_KIND_PP,
635
+ // measure mean token generation rate without prompt processing
636
+ TEST_KIND_TG,
637
+ // measure mean prompt processing and token generation rate
638
+ TEST_KIND_PG,
639
+ // measure mean token generation rate after processing prompt of given length
640
+ TEST_KIND_GP,
641
+ };
642
+
617
643
struct cmd_params_instance {
644
+ test_kind_type test_kind;
618
645
std::string model;
619
646
int n_prompt;
620
647
int n_gen;
@@ -701,6 +728,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
701
728
continue ;
702
729
}
703
730
cmd_params_instance instance = {
731
+ /* .test_kind = */ TEST_KIND_PP,
704
732
/* .model = */ m,
705
733
/* .n_prompt = */ n_prompt,
706
734
/* .n_gen = */ 0 ,
@@ -728,6 +756,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
728
756
continue ;
729
757
}
730
758
cmd_params_instance instance = {
759
+ /* .test_kind = */ TEST_KIND_PP,
731
760
/* .model = */ m,
732
761
/* .n_prompt = */ 0 ,
733
762
/* .n_gen = */ n_gen,
@@ -755,6 +784,7 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
755
784
continue ;
756
785
}
757
786
cmd_params_instance instance = {
787
+ /* .test_kind = */ TEST_KIND_PP,
758
788
/* .model = */ m,
759
789
/* .n_prompt = */ n_pg.first ,
760
790
/* .n_gen = */ n_pg.second ,
@@ -776,6 +806,34 @@ static std::vector<cmd_params_instance> get_cmd_params_instances(const cmd_param
776
806
};
777
807
instances.push_back (instance);
778
808
}
809
+
810
+ for (const auto & n_gp : params.n_gp ) {
811
+ if (n_gp.first == 0 && n_gp.second == 0 ) {
812
+ continue ;
813
+ }
814
+ cmd_params_instance instance = {
815
+ /* .test_kind = */ TEST_KIND_GP,
816
+ /* .model = */ m,
817
+ /* .n_prompt = */ n_gp.first ,
818
+ /* .n_gen = */ n_gp.second ,
819
+ /* .n_batch = */ nb,
820
+ /* .n_ubatch = */ nub,
821
+ /* .type_k = */ tk,
822
+ /* .type_v = */ tv,
823
+ /* .n_threads = */ nt,
824
+ /* .n_gpu_layers = */ nl,
825
+ /* .rpc_servers = */ rpc,
826
+ /* .split_mode = */ sm,
827
+ /* .main_gpu = */ mg,
828
+ /* .no_kv_offload= */ nkvo,
829
+ /* .flash_attn = */ fa,
830
+ /* .tensor_split = */ ts,
831
+ /* .use_mmap = */ mmp,
832
+ /* .embeddings = */ embd,
833
+ /* .repack = */ params.repack ,
834
+ };
835
+ instances.push_back (instance);
836
+ }
779
837
}
780
838
781
839
return instances;
@@ -816,6 +874,8 @@ struct test {
816
874
int n_gen;
817
875
std::string test_time;
818
876
std::vector<uint64_t > samples_ns;
877
+ test_kind_type test_kind;
878
+ std::string test_label;
819
879
820
880
test (const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
821
881
model_filename = inst.model ;
@@ -841,11 +901,32 @@ struct test {
841
901
repack = inst.repack ;
842
902
n_prompt = inst.n_prompt ;
843
903
n_gen = inst.n_gen ;
904
+ test_kind = inst.test_kind ;
844
905
// RFC 3339 date-time format
845
906
time_t t = time (NULL );
846
907
std::strftime (buf, sizeof (buf), " %FT%TZ" , gmtime (&t));
847
908
test_time = buf;
848
909
910
+ // prepare test label for printing
911
+ switch (test_kind) {
912
+ case TEST_KIND_PP:
913
+ snprintf (buf, sizeof (buf), " pp%d" , n_prompt);
914
+ break ;
915
+ case TEST_KIND_TG:
916
+ snprintf (buf, sizeof (buf), " tg%d" , n_gen);
917
+ break ;
918
+ case TEST_KIND_PG:
919
+ snprintf (buf, sizeof (buf), " pp%d+tg%d" , n_prompt, n_gen);
920
+ break ;
921
+ case TEST_KIND_GP:
922
+ snprintf (buf, sizeof (buf), " tg%d@pp%d" , n_gen, n_prompt);
923
+ break ;
924
+ default :
925
+ snprintf (buf, sizeof (buf), " unknown" );
926
+ break ;
927
+ }
928
+ test_label = buf;
929
+
849
930
(void ) ctx;
850
931
}
851
932
@@ -858,7 +939,7 @@ struct test {
858
939
}
859
940
860
941
std::vector<double > get_ts () const {
861
- int n_tokens = n_prompt + n_gen;
942
+ int n_tokens = (test_kind == TEST_KIND_GP ? 0 : n_prompt) + n_gen;
862
943
std::vector<double > ts;
863
944
std::transform (samples_ns.begin (), samples_ns.end (), std::back_inserter (ts), [n_tokens](uint64_t t) { return 1e9 * n_tokens / t; });
864
945
return ts;
@@ -911,7 +992,7 @@ struct test {
911
992
" tensor_split" , " use_mmap" , " embeddings" , " repack" ,
912
993
" n_prompt" , " n_gen" , " test_time" ,
913
994
" avg_ns" , " stddev_ns" ,
914
- " avg_ts" , " stddev_ts"
995
+ " avg_ts" , " stddev_ts" , " test " ,
915
996
};
916
997
return fields;
917
998
}
@@ -967,7 +1048,8 @@ struct test {
967
1048
tensor_split_str, std::to_string (use_mmap), std::to_string (embeddings), std::to_string (repack),
968
1049
std::to_string (n_prompt), std::to_string (n_gen), test_time,
969
1050
std::to_string (avg_ns ()), std::to_string (stdev_ns ()),
970
- std::to_string (avg_ts ()), std::to_string (stdev_ts ())
1051
+ std::to_string (avg_ts ()), std::to_string (stdev_ts ()),
1052
+ test_label
971
1053
};
972
1054
return values;
973
1055
}
@@ -1269,14 +1351,15 @@ struct markdown_printer : public printer {
1269
1351
value += " +RPC" ;
1270
1352
}
1271
1353
} else if (field == " test" ) {
1272
- if (t.n_prompt > 0 && t.n_gen == 0 ) {
1273
- snprintf (buf, sizeof (buf), " pp%d" , t.n_prompt );
1274
- } else if (t.n_gen > 0 && t.n_prompt == 0 ) {
1275
- snprintf (buf, sizeof (buf), " tg%d" , t.n_gen );
1276
- } else {
1277
- snprintf (buf, sizeof (buf), " pp%d+tg%d" , t.n_prompt , t.n_gen );
1278
- }
1279
- value = buf;
1354
+ // if (t.n_prompt > 0 && t.n_gen == 0) {
1355
+ // snprintf(buf, sizeof(buf), "pp%d", t.n_prompt);
1356
+ // } else if (t.n_gen > 0 && t.n_prompt == 0) {
1357
+ // snprintf(buf, sizeof(buf), "tg%d", t.n_gen);
1358
+ // } else {
1359
+ // snprintf(buf, sizeof(buf), "pp%d+tg%d", t.n_prompt, t.n_gen);
1360
+ // }
1361
+ // value = buf;
1362
+ value = t.test_label ;
1280
1363
} else if (field == " t/s" ) {
1281
1364
snprintf (buf, sizeof (buf), " %.2f ± %.2f" , t.avg_ts (), t.stdev_ts ());
1282
1365
value = buf;
@@ -1489,6 +1572,7 @@ int main(int argc, char ** argv) {
1489
1572
if (t.n_prompt > 0 ) {
1490
1573
test_prompt (ctx, t.n_prompt , 0 , t.n_batch , t.n_threads );
1491
1574
}
1575
+ if (t.test_kind == TEST_KIND_GP) t_start = get_time_ns ();
1492
1576
if (t.n_gen > 0 ) {
1493
1577
test_gen (ctx, t.n_gen , t.n_prompt , t.n_threads );
1494
1578
}
0 commit comments