Skip to content

Commit bea01ea

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
1 parent 679a141 commit bea01ea

File tree

1 file changed

+40
-67
lines changed

1 file changed

+40
-67
lines changed

tests/test-backend-ops.cpp

Lines changed: 40 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -355,15 +355,15 @@ struct test_result {
355355
bool passed;
356356
std::string error_message;
357357
double time_us;
358-
double flops_per_sec;
358+
double flops;
359359
double bandwidth_gb_s;
360360
size_t memory_kb;
361361
int n_runs;
362362

363363
test_result() {
364364
// Initialize with default values
365365
time_us = 0.0;
366-
flops_per_sec = 0.0;
366+
flops = 0.0;
367367
bandwidth_gb_s = 0.0;
368368
memory_kb = 0;
369369
n_runs = 0;
@@ -377,10 +377,24 @@ struct test_result {
377377
test_time = buf;
378378
}
379379

380+
test_result(const std::string& backend_name, const std::string& op_name, const std::string& op_params,
381+
const std::string& test_mode, bool supported, bool passed, const std::string& error_message = "",
382+
double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0,
383+
size_t memory_kb = 0, int n_runs = 0)
384+
: backend_name(backend_name), op_name(op_name), op_params(op_params), test_mode(test_mode),
385+
supported(supported), passed(passed), error_message(error_message), time_us(time_us),
386+
flops(flops), bandwidth_gb_s(bandwidth_gb_s), memory_kb(memory_kb), n_runs(n_runs) {
387+
// Set test time
388+
time_t t = time(NULL);
389+
char buf[32];
390+
std::strftime(buf, sizeof(buf), "%FT%TZ", gmtime(&t));
391+
test_time = buf;
392+
}
393+
380394
static const std::vector<std::string> & get_fields() {
381395
static const std::vector<std::string> fields = {
382396
"test_time", "backend_name", "op_name", "op_params", "test_mode",
383-
"supported", "passed", "error_message", "time_us", "flops_per_sec",
397+
"supported", "passed", "error_message", "time_us", "flops",
384398
"bandwidth_gb_s", "memory_kb", "n_runs"
385399
};
386400
return fields;
@@ -395,7 +409,7 @@ struct test_result {
395409
if (field == "memory_kb" || field == "n_runs") {
396410
return INT;
397411
}
398-
if (field == "time_us" || field == "flops_per_sec" || field == "bandwidth_gb_s") {
412+
if (field == "time_us" || field == "flops" || field == "bandwidth_gb_s") {
399413
return FLOAT;
400414
}
401415
return STRING;
@@ -412,7 +426,7 @@ struct test_result {
412426
std::to_string(passed),
413427
error_message,
414428
std::to_string(time_us),
415-
std::to_string(flops_per_sec),
429+
std::to_string(flops),
416430
std::to_string(bandwidth_gb_s),
417431
std::to_string(memory_kb),
418432
std::to_string(n_runs)
@@ -521,7 +535,7 @@ struct console_printer : public printer {
521535
result.n_runs,
522536
result.time_us);
523537

524-
if (result.flops_per_sec > 0) {
538+
if (result.flops > 0) {
525539
auto format_flops = [](double flops) -> std::string {
526540
char buf[256];
527541
if (flops >= 1e12) {
@@ -531,14 +545,14 @@ struct console_printer : public printer {
531545
} else if (flops >= 1e6) {
532546
snprintf(buf, sizeof(buf), "%6.2f MFLOP", flops / 1e6);
533547
} else {
534-
snprintf(buf, sizeof(buf), "%6.2f KFLOP", flops / 1e3);
548+
snprintf(buf, sizeof(buf), "%6.2f kFLOP", flops / 1e3);
535549
}
536550
return buf;
537551
};
538-
uint64_t op_flops_per_run = result.flops_per_sec * result.time_us / 1e6;
552+
uint64_t op_flops_per_run = result.flops * result.time_us / 1e6;
539553
printf("%s/run - \033[1;34m%sS\033[0m",
540554
format_flops(op_flops_per_run).c_str(),
541-
format_flops(result.flops_per_sec).c_str());
555+
format_flops(result.flops).c_str());
542556
} else {
543557
printf("%8zu kB/run - \033[1;34m%7.2f GB/s\033[0m",
544558
result.memory_kb,
@@ -565,7 +579,7 @@ struct sql_printer : public printer {
565579

566580
void print_header() override {
567581
std::vector<std::string> fields = test_result::get_fields();
568-
fprintf(fout, "CREATE TABLE IF NOT EXISTS test_results (\n");
582+
fprintf(fout, "CREATE TABLE IF NOT EXISTS test_backend_ops (\n");
569583
for (size_t i = 0; i < fields.size(); i++) {
570584
fprintf(fout, " %s %s%s\n", fields[i].c_str(), get_sql_field_type(fields[i]).c_str(),
571585
i < fields.size() - 1 ? "," : "");
@@ -574,7 +588,7 @@ struct sql_printer : public printer {
574588
}
575589

576590
void print_test_result(const test_result & result) override {
577-
fprintf(fout, "INSERT INTO test_results (");
591+
fprintf(fout, "INSERT INTO test_backend_ops (");
578592
std::vector<std::string> fields = test_result::get_fields();
579593
for (size_t i = 0; i < fields.size(); i++) {
580594
fprintf(fout, "%s%s", fields[i].c_str(), i < fields.size() - 1 ? ", " : "");
@@ -602,21 +616,17 @@ struct sql_printer : public printer {
602616
}
603617

604618
void print_device_info(const char * format, ...) override {
605-
// Do nothing - SQL format only outputs test results
606619
(void)format;
607620
}
608621

609622
void print_test_summary(const char * format, ...) override {
610-
// Do nothing - SQL format only outputs test results
611623
(void)format;
612624
}
613625

614626
void print_status_ok() override {
615-
// Do nothing - SQL format only outputs test results
616627
}
617628

618629
void print_status_fail() override {
619-
// Do nothing - SQL format only outputs test results
620630
}
621631
};
622632

@@ -782,19 +792,8 @@ struct test_case {
782792

783793
if (!supported) {
784794
// Create test result for unsupported operation
785-
test_result result;
786-
result.backend_name = ggml_backend_name(backend1);
787-
result.op_name = current_op_name;
788-
result.op_params = vars();
789-
result.test_mode = "test";
790-
result.supported = false;
791-
result.passed = false;
792-
result.error_message = "not supported";
793-
result.time_us = 0.0;
794-
result.flops_per_sec = 0.0;
795-
result.bandwidth_gb_s = 0.0;
796-
result.memory_kb = 0;
797-
result.n_runs = 0;
795+
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test",
796+
false, false, "not supported");
798797

799798
if (output_printer) {
800799
output_printer->print_test_result(result);
@@ -910,19 +909,9 @@ struct test_case {
910909

911910
// Create test result
912911
bool test_passed = ud.ok && cmp_ok;
913-
test_result result;
914-
result.backend_name = ggml_backend_name(backend1);
915-
result.op_name = current_op_name;
916-
result.op_params = vars();
917-
result.test_mode = "test";
918-
result.supported = supported;
919-
result.passed = test_passed;
920-
result.error_message = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed");
921-
result.time_us = 0.0; // Not measured in test mode
922-
result.flops_per_sec = 0.0;
923-
result.bandwidth_gb_s = 0.0;
924-
result.memory_kb = 0;
925-
result.n_runs = 0;
912+
std::string error_msg = test_passed ? "" : (!cmp_ok ? "compare failed" : "test failed");
913+
test_result result(ggml_backend_name(backend1), current_op_name, vars(), "test",
914+
supported, test_passed, error_msg);
926915

927916
if (output_printer) {
928917
output_printer->print_test_result(result);
@@ -954,19 +943,8 @@ struct test_case {
954943
// check if backends support op
955944
if (!ggml_backend_supports_op(backend, out)) {
956945
// Create test result for unsupported performance test
957-
test_result result;
958-
result.backend_name = ggml_backend_name(backend);
959-
result.op_name = current_op_name;
960-
result.op_params = vars();
961-
result.test_mode = "perf";
962-
result.supported = false;
963-
result.passed = false;
964-
result.error_message = "not supported";
965-
result.time_us = 0.0;
966-
result.flops_per_sec = 0.0;
967-
result.bandwidth_gb_s = 0.0;
968-
result.memory_kb = 0;
969-
result.n_runs = 0;
946+
test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf",
947+
false, false, "not supported");
970948

971949
if (output_printer) {
972950
output_printer->print_test_result(result);
@@ -1059,19 +1037,14 @@ struct test_case {
10591037
} while (total_time_us < 1000*1000); // run for at least 1 second
10601038

10611039
// Create test result
1062-
test_result result;
1063-
result.backend_name = ggml_backend_name(backend);
1064-
result.op_name = current_op_name;
1065-
result.op_params = vars();
1066-
result.test_mode = "perf";
1067-
result.supported = true; // If we got this far, it's supported
1068-
result.passed = true; // Performance tests don't fail
1069-
result.error_message = "";
1070-
result.time_us = (double)total_time_us / total_runs;
1071-
result.flops_per_sec = (op_flops(out) > 0) ? (op_flops(out) * total_runs) / (total_time_us / 1e6) : 0.0;
1072-
result.bandwidth_gb_s = (op_flops(out) == 0) ? total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0 : 0.0;
1073-
result.memory_kb = op_size(out) / 1024;
1074-
result.n_runs = total_runs;
1040+
double avg_time_us = (double)total_time_us / total_runs;
1041+
double calculated_flops = (op_flops(out) > 0) ? (op_flops(out) * total_runs) / (total_time_us / 1e6) : 0.0;
1042+
double calculated_bandwidth = (op_flops(out) == 0) ? total_mem / (total_time_us / 1e6) / 1024.0 / 1024.0 / 1024.0 : 0.0;
1043+
size_t calculated_memory_kb = op_size(out) / 1024;
1044+
1045+
test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf",
1046+
true, true, "", avg_time_us, calculated_flops, calculated_bandwidth,
1047+
calculated_memory_kb, total_runs);
10751048

10761049
if (output_printer) {
10771050
output_printer->print_test_result(result);

0 commit comments

Comments
 (0)