@@ -355,15 +355,15 @@ struct test_result {
355
355
bool passed;
356
356
std::string error_message;
357
357
double time_us;
358
- double flops_per_sec ;
358
+ double flops ;
359
359
double bandwidth_gb_s;
360
360
size_t memory_kb;
361
361
int n_runs;
362
362
363
363
test_result () {
364
364
// Initialize with default values
365
365
time_us = 0.0 ;
366
- flops_per_sec = 0.0 ;
366
+ flops = 0.0 ;
367
367
bandwidth_gb_s = 0.0 ;
368
368
memory_kb = 0 ;
369
369
n_runs = 0 ;
@@ -377,10 +377,24 @@ struct test_result {
377
377
test_time = buf;
378
378
}
379
379
380
+ test_result (const std::string& backend_name, const std::string& op_name, const std::string& op_params,
381
+ const std::string& test_mode, bool supported, bool passed, const std::string& error_message = " " ,
382
+ double time_us = 0.0 , double flops = 0.0 , double bandwidth_gb_s = 0.0 ,
383
+ size_t memory_kb = 0 , int n_runs = 0 )
384
+ : backend_name(backend_name), op_name(op_name), op_params(op_params), test_mode(test_mode),
385
+ supported (supported), passed(passed), error_message(error_message), time_us(time_us),
386
+ flops(flops), bandwidth_gb_s(bandwidth_gb_s), memory_kb(memory_kb), n_runs(n_runs) {
387
+ // Set test time
388
+ time_t t = time (NULL );
389
+ char buf[32 ];
390
+ std::strftime (buf, sizeof (buf), " %FT%TZ" , gmtime (&t));
391
+ test_time = buf;
392
+ }
393
+
380
394
static const std::vector<std::string> & get_fields () {
381
395
static const std::vector<std::string> fields = {
382
396
" test_time" , " backend_name" , " op_name" , " op_params" , " test_mode" ,
383
- " supported" , " passed" , " error_message" , " time_us" , " flops_per_sec " ,
397
+ " supported" , " passed" , " error_message" , " time_us" , " flops " ,
384
398
" bandwidth_gb_s" , " memory_kb" , " n_runs"
385
399
};
386
400
return fields;
@@ -395,7 +409,7 @@ struct test_result {
395
409
if (field == " memory_kb" || field == " n_runs" ) {
396
410
return INT;
397
411
}
398
- if (field == " time_us" || field == " flops_per_sec " || field == " bandwidth_gb_s" ) {
412
+ if (field == " time_us" || field == " flops " || field == " bandwidth_gb_s" ) {
399
413
return FLOAT;
400
414
}
401
415
return STRING;
@@ -412,7 +426,7 @@ struct test_result {
412
426
std::to_string (passed),
413
427
error_message,
414
428
std::to_string (time_us),
415
- std::to_string (flops_per_sec ),
429
+ std::to_string (flops ),
416
430
std::to_string (bandwidth_gb_s),
417
431
std::to_string (memory_kb),
418
432
std::to_string (n_runs)
@@ -521,7 +535,7 @@ struct console_printer : public printer {
521
535
result.n_runs ,
522
536
result.time_us );
523
537
524
- if (result.flops_per_sec > 0 ) {
538
+ if (result.flops > 0 ) {
525
539
auto format_flops = [](double flops) -> std::string {
526
540
char buf[256 ];
527
541
if (flops >= 1e12 ) {
@@ -531,14 +545,14 @@ struct console_printer : public printer {
531
545
} else if (flops >= 1e6 ) {
532
546
snprintf (buf, sizeof (buf), " %6.2f MFLOP" , flops / 1e6 );
533
547
} else {
534
- snprintf (buf, sizeof (buf), " %6.2f KFLOP " , flops / 1e3 );
548
+ snprintf (buf, sizeof (buf), " %6.2f kFLOP " , flops / 1e3 );
535
549
}
536
550
return buf;
537
551
};
538
- uint64_t op_flops_per_run = result.flops_per_sec * result.time_us / 1e6 ;
552
+ uint64_t op_flops_per_run = result.flops * result.time_us / 1e6 ;
539
553
printf (" %s/run - \033 [1;34m%sS\033 [0m" ,
540
554
format_flops (op_flops_per_run).c_str (),
541
- format_flops (result.flops_per_sec ).c_str ());
555
+ format_flops (result.flops ).c_str ());
542
556
} else {
543
557
printf (" %8zu kB/run - \033 [1;34m%7.2f GB/s\033 [0m" ,
544
558
result.memory_kb ,
@@ -565,7 +579,7 @@ struct sql_printer : public printer {
565
579
566
580
void print_header () override {
567
581
std::vector<std::string> fields = test_result::get_fields ();
568
- fprintf (fout, " CREATE TABLE IF NOT EXISTS test_results (\n " );
582
+ fprintf (fout, " CREATE TABLE IF NOT EXISTS test_backend_ops (\n " );
569
583
for (size_t i = 0 ; i < fields.size (); i++) {
570
584
fprintf (fout, " %s %s%s\n " , fields[i].c_str (), get_sql_field_type (fields[i]).c_str (),
571
585
i < fields.size () - 1 ? " ," : " " );
@@ -574,7 +588,7 @@ struct sql_printer : public printer {
574
588
}
575
589
576
590
void print_test_result (const test_result & result) override {
577
- fprintf (fout, " INSERT INTO test_results (" );
591
+ fprintf (fout, " INSERT INTO test_backend_ops (" );
578
592
std::vector<std::string> fields = test_result::get_fields ();
579
593
for (size_t i = 0 ; i < fields.size (); i++) {
580
594
fprintf (fout, " %s%s" , fields[i].c_str (), i < fields.size () - 1 ? " , " : " " );
@@ -602,21 +616,17 @@ struct sql_printer : public printer {
602
616
}
603
617
604
618
// No-op override: device information is not written by this printer.
// The format string and any variadic arguments are ignored.
void print_device_info(const char * format, ...) override {
    static_cast<void>(format); // silence the unused-parameter warning
}
608
621
609
622
// No-op override: the test summary is not written by this printer.
// The format string and any variadic arguments are ignored.
void print_test_summary(const char * format, ...) override {
    static_cast<void>(format); // silence the unused-parameter warning
}
613
625
614
626
// No-op override: this printer emits nothing for an "ok" status.
void print_status_ok() override {}
617
628
618
629
// No-op override: this printer emits nothing for a "fail" status.
void print_status_fail() override {}
621
631
};
622
632
@@ -782,19 +792,8 @@ struct test_case {
782
792
783
793
if (!supported) {
784
794
// Create test result for unsupported operation
785
- test_result result;
786
- result.backend_name = ggml_backend_name (backend1);
787
- result.op_name = current_op_name;
788
- result.op_params = vars ();
789
- result.test_mode = " test" ;
790
- result.supported = false ;
791
- result.passed = false ;
792
- result.error_message = " not supported" ;
793
- result.time_us = 0.0 ;
794
- result.flops_per_sec = 0.0 ;
795
- result.bandwidth_gb_s = 0.0 ;
796
- result.memory_kb = 0 ;
797
- result.n_runs = 0 ;
795
+ test_result result (ggml_backend_name (backend1), current_op_name, vars (), " test" ,
796
+ false , false , " not supported" );
798
797
799
798
if (output_printer) {
800
799
output_printer->print_test_result (result);
@@ -910,19 +909,9 @@ struct test_case {
910
909
911
910
// Create test result
912
911
bool test_passed = ud.ok && cmp_ok;
913
- test_result result;
914
- result.backend_name = ggml_backend_name (backend1);
915
- result.op_name = current_op_name;
916
- result.op_params = vars ();
917
- result.test_mode = " test" ;
918
- result.supported = supported;
919
- result.passed = test_passed;
920
- result.error_message = test_passed ? " " : (!cmp_ok ? " compare failed" : " test failed" );
921
- result.time_us = 0.0 ; // Not measured in test mode
922
- result.flops_per_sec = 0.0 ;
923
- result.bandwidth_gb_s = 0.0 ;
924
- result.memory_kb = 0 ;
925
- result.n_runs = 0 ;
912
+ std::string error_msg = test_passed ? " " : (!cmp_ok ? " compare failed" : " test failed" );
913
+ test_result result (ggml_backend_name (backend1), current_op_name, vars (), " test" ,
914
+ supported, test_passed, error_msg);
926
915
927
916
if (output_printer) {
928
917
output_printer->print_test_result (result);
@@ -954,19 +943,8 @@ struct test_case {
954
943
// check if backends support op
955
944
if (!ggml_backend_supports_op (backend, out)) {
956
945
// Create test result for unsupported performance test
957
- test_result result;
958
- result.backend_name = ggml_backend_name (backend);
959
- result.op_name = current_op_name;
960
- result.op_params = vars ();
961
- result.test_mode = " perf" ;
962
- result.supported = false ;
963
- result.passed = false ;
964
- result.error_message = " not supported" ;
965
- result.time_us = 0.0 ;
966
- result.flops_per_sec = 0.0 ;
967
- result.bandwidth_gb_s = 0.0 ;
968
- result.memory_kb = 0 ;
969
- result.n_runs = 0 ;
946
+ test_result result (ggml_backend_name (backend), current_op_name, vars (), " perf" ,
947
+ false , false , " not supported" );
970
948
971
949
if (output_printer) {
972
950
output_printer->print_test_result (result);
@@ -1059,19 +1037,14 @@ struct test_case {
1059
1037
} while (total_time_us < 1000 *1000 ); // run for at least 1 second
1060
1038
1061
1039
// Create test result
1062
- test_result result;
1063
- result.backend_name = ggml_backend_name (backend);
1064
- result.op_name = current_op_name;
1065
- result.op_params = vars ();
1066
- result.test_mode = " perf" ;
1067
- result.supported = true ; // If we got this far, it's supported
1068
- result.passed = true ; // Performance tests don't fail
1069
- result.error_message = " " ;
1070
- result.time_us = (double )total_time_us / total_runs;
1071
- result.flops_per_sec = (op_flops (out) > 0 ) ? (op_flops (out) * total_runs) / (total_time_us / 1e6 ) : 0.0 ;
1072
- result.bandwidth_gb_s = (op_flops (out) == 0 ) ? total_mem / (total_time_us / 1e6 ) / 1024.0 / 1024.0 / 1024.0 : 0.0 ;
1073
- result.memory_kb = op_size (out) / 1024 ;
1074
- result.n_runs = total_runs;
1040
+ double avg_time_us = (double )total_time_us / total_runs;
1041
+ double calculated_flops = (op_flops (out) > 0 ) ? (op_flops (out) * total_runs) / (total_time_us / 1e6 ) : 0.0 ;
1042
+ double calculated_bandwidth = (op_flops (out) == 0 ) ? total_mem / (total_time_us / 1e6 ) / 1024.0 / 1024.0 / 1024.0 : 0.0 ;
1043
+ size_t calculated_memory_kb = op_size (out) / 1024 ;
1044
+
1045
+ test_result result (ggml_backend_name (backend), current_op_name, vars (), " perf" ,
1046
+ true , true , " " , avg_time_us, calculated_flops, calculated_bandwidth,
1047
+ calculated_memory_kb, total_runs);
1075
1048
1076
1049
if (output_printer) {
1077
1050
output_printer->print_test_result (result);
0 commit comments