Skip to content

Commit 38d4930

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
1 parent 1d5f25c commit 38d4930

File tree

1 file changed

+67
-97
lines changed

1 file changed

+67
-97
lines changed

tests/test-backend-ops.cpp

Lines changed: 67 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -451,20 +451,20 @@ struct test_result {
451451
};
452452

453453
// Printer classes for different output formats
454+
enum class message_type {
455+
INFO,
456+
ERROR,
457+
STATUS_OK,
458+
STATUS_FAIL
459+
};
460+
454461
struct printer {
455462
virtual ~printer() {}
456463
FILE * fout = stdout;
457464
virtual void print_header() {}
458465
virtual void print_test_result(const test_result & result) = 0;
459466
virtual void print_footer() {}
460-
461-
// General purpose output methods
462-
virtual void print_info(const char * format, ...) = 0;
463-
virtual void print_error(const char * format, ...) = 0;
464-
virtual void print_device_info(const char * format, ...) = 0;
465-
virtual void print_test_summary(const char * format, ...) = 0;
466-
virtual void print_status_ok() = 0;
467-
virtual void print_status_fail() = 0;
467+
virtual void print_message(message_type type, const char * format, ...) = 0;
468468
};
469469

470470
struct console_printer : public printer {
@@ -476,42 +476,28 @@ struct console_printer : public printer {
476476
}
477477
}
478478

479-
void print_info(const char * format, ...) override {
480-
va_list args;
481-
va_start(args, format);
482-
vprintf(format, args);
483-
va_end(args);
484-
}
485-
486-
void print_error(const char * format, ...) override {
479+
void print_message(message_type type, const char * format, ...) override {
487480
va_list args;
488481
va_start(args, format);
489-
vfprintf(stderr, format, args);
490-
va_end(args);
491-
}
492482

493-
void print_device_info(const char * format, ...) override {
494-
va_list args;
495-
va_start(args, format);
496-
vprintf(format, args);
497-
va_end(args);
498-
}
483+
switch (type) {
484+
case message_type::INFO:
485+
vprintf(format, args);
486+
break;
487+
case message_type::ERROR:
488+
vfprintf(stderr, format, args);
489+
break;
490+
case message_type::STATUS_OK:
491+
printf("\033[1;32mOK\033[0m\n");
492+
break;
493+
case message_type::STATUS_FAIL:
494+
printf("\033[1;31mFAIL\033[0m\n");
495+
break;
496+
}
499497

500-
void print_test_summary(const char * format, ...) override {
501-
va_list args;
502-
va_start(args, format);
503-
vprintf(format, args);
504498
va_end(args);
505499
}
506500

507-
void print_status_ok() override {
508-
printf("\033[1;32mOK\033[0m\n");
509-
}
510-
511-
void print_status_fail() override {
512-
printf("\033[1;31mFAIL\033[0m\n");
513-
}
514-
515501
private:
516502
void print_test_console(const test_result & result) {
517503
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
@@ -617,32 +603,16 @@ struct sql_printer : public printer {
617603
fprintf(fout, ");\n");
618604
}
619605

620-
// SQL printer ignores general output - only outputs test results
621-
void print_info(const char * format, ...) override {
622-
// Do nothing - SQL format only outputs test results
623-
(void)format;
624-
}
625-
626-
void print_error(const char * format, ...) override {
627-
// Still output errors to stderr for SQL format
628-
va_list args;
629-
va_start(args, format);
630-
vfprintf(stderr, format, args);
631-
va_end(args);
632-
}
633-
634-
void print_device_info(const char * format, ...) override {
635-
(void)format;
636-
}
637-
638-
void print_test_summary(const char * format, ...) override {
639-
(void)format;
640-
}
641-
642-
void print_status_ok() override {
643-
}
644-
645-
void print_status_fail() override {
606+
// SQL printer ignores most output types - only outputs test results and errors
607+
void print_message(message_type type, const char * format, ...) override {
608+
if (type == message_type::ERROR) {
609+
// Still output errors to stderr for SQL format
610+
va_list args;
611+
va_start(args, format);
612+
vfprintf(stderr, format, args);
613+
va_end(args);
614+
}
615+
// All other message types are ignored in SQL format
646616
}
647617
};
648618

@@ -1087,14 +1057,14 @@ struct test_case {
10871057
ggml_tensor * out = build_graph(ctx.get());
10881058

10891059
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1090-
//output_printer->print_info(" %s: skipping\n", op_desc(out).c_str());
1060+
//output_printer->print_message(message_type::INFO, " %s: skipping\n", op_desc(out).c_str());
10911061
return true;
10921062
}
10931063

1094-
output_printer->print_info(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
1064+
output_printer->print_message(message_type::INFO, " %s(%s): ", op_desc(out).c_str(), vars().c_str());
10951065

10961066
if (out->type != GGML_TYPE_F32) {
1097-
output_printer->print_info("not supported [%s->type != FP32]\n", out->name);
1067+
output_printer->print_message(message_type::INFO, "not supported [%s->type != FP32]\n", out->name);
10981068
return true;
10991069
}
11001070

@@ -1103,25 +1073,25 @@ struct test_case {
11031073
bool any_params = false;
11041074
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
11051075
if (!ggml_backend_supports_op(backend, t)) {
1106-
output_printer->print_info("not supported [%s] ", ggml_backend_name(backend));
1076+
output_printer->print_message(message_type::INFO, "not supported [%s] ", ggml_backend_name(backend));
11071077
supported = false;
11081078
break;
11091079
}
11101080
if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
11111081
any_params = true;
11121082
if (t->type != GGML_TYPE_F32) {
1113-
output_printer->print_info("not supported [%s->type != FP32] ", t->name);
1083+
output_printer->print_message(message_type::INFO, "not supported [%s->type != FP32] ", t->name);
11141084
supported = false;
11151085
break;
11161086
}
11171087
}
11181088
}
11191089
if (!any_params) {
1120-
output_printer->print_info("not supported [%s] \n", op_desc(out).c_str());
1090+
output_printer->print_message(message_type::INFO, "not supported [%s] \n", op_desc(out).c_str());
11211091
supported = false;
11221092
}
11231093
if (!supported) {
1124-
output_printer->print_info("\n");
1094+
output_printer->print_message(message_type::INFO, "\n");
11251095
return true;
11261096
}
11271097

@@ -1132,7 +1102,7 @@ struct test_case {
11321102
}
11331103
}
11341104
if (ngrads > grad_nmax()) {
1135-
output_printer->print_info("skipping large tensors for speed \n");
1105+
output_printer->print_message(message_type::INFO, "skipping large tensors for speed \n");
11361106
return true;
11371107
}
11381108

@@ -1155,25 +1125,25 @@ struct test_case {
11551125

11561126
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
11571127
if (!ggml_backend_supports_op(backend, t)) {
1158-
output_printer->print_info("not supported [%s] ", ggml_backend_name(backend));
1128+
output_printer->print_message(message_type::INFO, "not supported [%s] ", ggml_backend_name(backend));
11591129
supported = false;
11601130
break;
11611131
}
11621132
if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
1163-
output_printer->print_info("not supported [%s->type != FP32] ", t->name);
1133+
output_printer->print_message(message_type::INFO, "not supported [%s->type != FP32] ", t->name);
11641134
supported = false;
11651135
break;
11661136
}
11671137
}
11681138
if (!supported) {
1169-
output_printer->print_info("\n");
1139+
output_printer->print_message(message_type::INFO, "\n");
11701140
return true;
11711141
}
11721142

11731143
// allocate
11741144
ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
11751145
if (buf == NULL) {
1176-
output_printer->print_error("failed to allocate tensors [%s] ", ggml_backend_name(backend));
1146+
output_printer->print_message(message_type::ERROR, "failed to allocate tensors [%s] ", ggml_backend_name(backend));
11771147
return false;
11781148
}
11791149

@@ -1211,7 +1181,7 @@ struct test_case {
12111181
for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
12121182
// check for nans
12131183
if (!std::isfinite(ga[i])) {
1214-
output_printer->print_info("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
1184+
output_printer->print_message(message_type::INFO, "[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
12151185
ok = false;
12161186
break;
12171187
}
@@ -1279,7 +1249,7 @@ struct test_case {
12791249

12801250
const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);
12811251
if (err > max_maa_err()) {
1282-
output_printer->print_info("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
1252+
output_printer->print_message(message_type::INFO, "[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
12831253
ok = false;
12841254
break;
12851255
}
@@ -1289,15 +1259,15 @@ struct test_case {
12891259
}
12901260

12911261
if (!ok) {
1292-
output_printer->print_info("compare failed ");
1262+
output_printer->print_message(message_type::INFO, "compare failed ");
12931263
}
12941264

12951265
if (ok) {
1296-
output_printer->print_status_ok();
1266+
output_printer->print_message(message_type::STATUS_OK, "");
12971267
return true;
12981268
}
12991269

1300-
output_printer->print_status_fail();
1270+
output_printer->print_message(message_type::STATUS_FAIL, "");
13011271
return false;
13021272
}
13031273
};
@@ -5002,7 +4972,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
50024972
filter_test_cases(test_cases, params_filter);
50034973
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
50044974
if (backend_cpu == NULL) {
5005-
output_printer->print_error(" Failed to initialize CPU backend\n");
4975+
output_printer->print_message(message_type::ERROR, " Failed to initialize CPU backend\n");
50064976
return false;
50074977
}
50084978

@@ -5012,7 +4982,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
50124982
n_ok++;
50134983
}
50144984
}
5015-
output_printer->print_test_summary(" %zu/%zu tests passed\n", n_ok, test_cases.size());
4985+
output_printer->print_message(message_type::INFO, " %zu/%zu tests passed\n", n_ok, test_cases.size());
50164986

50174987
ggml_backend_free(backend_cpu);
50184988

@@ -5028,7 +4998,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
50284998
n_ok++;
50294999
}
50305000
}
5031-
output_printer->print_test_summary(" %zu/%zu tests passed\n", n_ok, test_cases.size());
5001+
output_printer->print_message(message_type::INFO, " %zu/%zu tests passed\n", n_ok, test_cases.size());
50325002

50335003
return n_ok == test_cases.size();
50345004
}
@@ -5115,23 +5085,23 @@ int main(int argc, char ** argv) {
51155085
output_printer->print_header();
51165086
}
51175087

5118-
output_printer->print_info("Testing %zu devices\n\n", ggml_backend_dev_count());
5088+
output_printer->print_message(message_type::INFO, "Testing %zu devices\n\n", ggml_backend_dev_count());
51195089

51205090
size_t n_ok = 0;
51215091

51225092
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
51235093
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
51245094

5125-
output_printer->print_device_info("Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
5095+
output_printer->print_message(message_type::INFO, "Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
51265096

51275097
if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
5128-
output_printer->print_device_info(" Skipping\n");
5098+
output_printer->print_message(message_type::INFO, " Skipping\n");
51295099
n_ok++;
51305100
continue;
51315101
}
51325102

51335103
if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
5134-
output_printer->print_device_info(" Skipping CPU backend\n");
5104+
output_printer->print_message(message_type::INFO, " Skipping CPU backend\n");
51355105
n_ok++;
51365106
continue;
51375107
}
@@ -5146,23 +5116,23 @@ int main(int argc, char ** argv) {
51465116
ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
51475117
}
51485118

5149-
output_printer->print_device_info(" Device description: %s\n", ggml_backend_dev_description(dev));
5119+
output_printer->print_message(message_type::INFO, " Device description: %s\n", ggml_backend_dev_description(dev));
51505120
size_t free, total; // NOLINT
51515121
ggml_backend_dev_memory(dev, &free, &total);
5152-
output_printer->print_device_info(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
5153-
output_printer->print_device_info("\n");
5122+
output_printer->print_message(message_type::INFO, " Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
5123+
output_printer->print_message(message_type::INFO, "\n");
51545124

51555125
bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
51565126

5157-
output_printer->print_device_info(" Backend %s: ", ggml_backend_name(backend));
5127+
output_printer->print_message(message_type::INFO, " Backend %s: ", ggml_backend_name(backend));
51585128
if (ok) {
5159-
output_printer->print_status_ok();
5129+
output_printer->print_message(message_type::STATUS_OK, "");
51605130
n_ok++;
51615131
} else {
5162-
output_printer->print_status_fail();
5132+
output_printer->print_message(message_type::STATUS_FAIL, "");
51635133
}
51645134

5165-
output_printer->print_device_info("\n");
5135+
output_printer->print_message(message_type::INFO, "\n");
51665136

51675137
ggml_backend_free(backend);
51685138
}
@@ -5173,13 +5143,13 @@ int main(int argc, char ** argv) {
51735143
output_printer->print_footer();
51745144
}
51755145

5176-
output_printer->print_test_summary("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
5146+
output_printer->print_message(message_type::INFO, "%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
51775147

51785148
if (n_ok != ggml_backend_dev_count()) {
5179-
output_printer->print_status_fail();
5149+
output_printer->print_message(message_type::STATUS_FAIL, "");
51805150
return 1;
51815151
}
51825152

5183-
output_printer->print_status_ok();
5153+
output_printer->print_message(message_type::STATUS_OK, "");
51845154
return 0;
51855155
}

0 commit comments

Comments
 (0)