Skip to content

Commit 2a5cd8e

Browse files
committed
Address review comments
Signed-off-by: Xiaodong Ye <yeahdongcn@gmail.com>
1 parent 3cabda8 commit 2a5cd8e

File tree

1 file changed

+67
-97
lines changed

1 file changed

+67
-97
lines changed

tests/test-backend-ops.cpp

Lines changed: 67 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -451,20 +451,20 @@ struct test_result {
451451
};
452452

453453
// Printer classes for different output formats
454+
enum class message_type {
455+
INFO,
456+
ERROR,
457+
STATUS_OK,
458+
STATUS_FAIL
459+
};
460+
454461
struct printer {
455462
virtual ~printer() {}
456463
FILE * fout = stdout;
457464
virtual void print_header() {}
458465
virtual void print_test_result(const test_result & result) = 0;
459466
virtual void print_footer() {}
460-
461-
// General purpose output methods
462-
virtual void print_info(const char * format, ...) = 0;
463-
virtual void print_error(const char * format, ...) = 0;
464-
virtual void print_device_info(const char * format, ...) = 0;
465-
virtual void print_test_summary(const char * format, ...) = 0;
466-
virtual void print_status_ok() = 0;
467-
virtual void print_status_fail() = 0;
467+
virtual void print_message(message_type type, const char * format, ...) = 0;
468468
};
469469

470470
struct console_printer : public printer {
@@ -476,42 +476,28 @@ struct console_printer : public printer {
476476
}
477477
}
478478

479-
void print_info(const char * format, ...) override {
480-
va_list args;
481-
va_start(args, format);
482-
vprintf(format, args);
483-
va_end(args);
484-
}
485-
486-
void print_error(const char * format, ...) override {
479+
void print_message(message_type type, const char * format, ...) override {
487480
va_list args;
488481
va_start(args, format);
489-
vfprintf(stderr, format, args);
490-
va_end(args);
491-
}
492482

493-
void print_device_info(const char * format, ...) override {
494-
va_list args;
495-
va_start(args, format);
496-
vprintf(format, args);
497-
va_end(args);
498-
}
483+
switch (type) {
484+
case message_type::INFO:
485+
vprintf(format, args);
486+
break;
487+
case message_type::ERROR:
488+
vfprintf(stderr, format, args);
489+
break;
490+
case message_type::STATUS_OK:
491+
printf("\033[1;32mOK\033[0m\n");
492+
break;
493+
case message_type::STATUS_FAIL:
494+
printf("\033[1;31mFAIL\033[0m\n");
495+
break;
496+
}
499497

500-
void print_test_summary(const char * format, ...) override {
501-
va_list args;
502-
va_start(args, format);
503-
vprintf(format, args);
504498
va_end(args);
505499
}
506500

507-
void print_status_ok() override {
508-
printf("\033[1;32mOK\033[0m\n");
509-
}
510-
511-
void print_status_fail() override {
512-
printf("\033[1;31mFAIL\033[0m\n");
513-
}
514-
515501
private:
516502
void print_test_console(const test_result & result) {
517503
printf(" %s(%s): ", result.op_name.c_str(), result.op_params.c_str());
@@ -617,32 +603,16 @@ struct sql_printer : public printer {
617603
fprintf(fout, ");\n");
618604
}
619605

620-
// SQL printer ignores general output - only outputs test results
621-
void print_info(const char * format, ...) override {
622-
// Do nothing - SQL format only outputs test results
623-
(void)format;
624-
}
625-
626-
void print_error(const char * format, ...) override {
627-
// Still output errors to stderr for SQL format
628-
va_list args;
629-
va_start(args, format);
630-
vfprintf(stderr, format, args);
631-
va_end(args);
632-
}
633-
634-
void print_device_info(const char * format, ...) override {
635-
(void)format;
636-
}
637-
638-
void print_test_summary(const char * format, ...) override {
639-
(void)format;
640-
}
641-
642-
void print_status_ok() override {
643-
}
644-
645-
void print_status_fail() override {
606+
// SQL printer ignores most output types - only outputs test results and errors
607+
void print_message(message_type type, const char * format, ...) override {
608+
if (type == message_type::ERROR) {
609+
// Still output errors to stderr for SQL format
610+
va_list args;
611+
va_start(args, format);
612+
vfprintf(stderr, format, args);
613+
va_end(args);
614+
}
615+
// All other message types are ignored in SQL format
646616
}
647617
};
648618

@@ -1089,14 +1059,14 @@ struct test_case {
10891059
ggml_tensor * out = build_graph(ctx.get());
10901060

10911061
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1092-
//output_printer->print_info(" %s: skipping\n", op_desc(out).c_str());
1062+
//output_printer->print_message(message_type::INFO, " %s: skipping\n", op_desc(out).c_str());
10931063
return true;
10941064
}
10951065

1096-
output_printer->print_info(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
1066+
output_printer->print_message(message_type::INFO, " %s(%s): ", op_desc(out).c_str(), vars().c_str());
10971067

10981068
if (out->type != GGML_TYPE_F32) {
1099-
output_printer->print_info("not supported [%s->type != FP32]\n", out->name);
1069+
output_printer->print_message(message_type::INFO, "not supported [%s->type != FP32]\n", out->name);
11001070
return true;
11011071
}
11021072

@@ -1105,25 +1075,25 @@ struct test_case {
11051075
bool any_params = false;
11061076
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
11071077
if (!ggml_backend_supports_op(backend, t)) {
1108-
output_printer->print_info("not supported [%s] ", ggml_backend_name(backend));
1078+
output_printer->print_message(message_type::INFO, "not supported [%s] ", ggml_backend_name(backend));
11091079
supported = false;
11101080
break;
11111081
}
11121082
if ((t->flags & GGML_TENSOR_FLAG_PARAM)) {
11131083
any_params = true;
11141084
if (t->type != GGML_TYPE_F32) {
1115-
output_printer->print_info("not supported [%s->type != FP32] ", t->name);
1085+
output_printer->print_message(message_type::INFO, "not supported [%s->type != FP32] ", t->name);
11161086
supported = false;
11171087
break;
11181088
}
11191089
}
11201090
}
11211091
if (!any_params) {
1122-
output_printer->print_info("not supported [%s] \n", op_desc(out).c_str());
1092+
output_printer->print_message(message_type::INFO, "not supported [%s] \n", op_desc(out).c_str());
11231093
supported = false;
11241094
}
11251095
if (!supported) {
1126-
output_printer->print_info("\n");
1096+
output_printer->print_message(message_type::INFO, "\n");
11271097
return true;
11281098
}
11291099

@@ -1134,7 +1104,7 @@ struct test_case {
11341104
}
11351105
}
11361106
if (ngrads > grad_nmax()) {
1137-
output_printer->print_info("skipping large tensors for speed \n");
1107+
output_printer->print_message(message_type::INFO, "skipping large tensors for speed \n");
11381108
return true;
11391109
}
11401110

@@ -1157,25 +1127,25 @@ struct test_case {
11571127

11581128
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
11591129
if (!ggml_backend_supports_op(backend, t)) {
1160-
output_printer->print_info("not supported [%s] ", ggml_backend_name(backend));
1130+
output_printer->print_message(message_type::INFO, "not supported [%s] ", ggml_backend_name(backend));
11611131
supported = false;
11621132
break;
11631133
}
11641134
if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
1165-
output_printer->print_info("not supported [%s->type != FP32] ", t->name);
1135+
output_printer->print_message(message_type::INFO, "not supported [%s->type != FP32] ", t->name);
11661136
supported = false;
11671137
break;
11681138
}
11691139
}
11701140
if (!supported) {
1171-
output_printer->print_info("\n");
1141+
output_printer->print_message(message_type::INFO, "\n");
11721142
return true;
11731143
}
11741144

11751145
// allocate
11761146
ggml_backend_buffer_ptr buf(ggml_backend_alloc_ctx_tensors(ctx.get(), backend)); // smart ptr
11771147
if (buf == NULL) {
1178-
output_printer->print_error("failed to allocate tensors [%s] ", ggml_backend_name(backend));
1148+
output_printer->print_message(message_type::ERROR, "failed to allocate tensors [%s] ", ggml_backend_name(backend));
11791149
return false;
11801150
}
11811151

@@ -1213,7 +1183,7 @@ struct test_case {
12131183
for (int64_t i = 0; i < ne; ++i) { // gradient algebraic
12141184
// check for nans
12151185
if (!std::isfinite(ga[i])) {
1216-
output_printer->print_info("[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
1186+
output_printer->print_message(message_type::INFO, "[%s] nonfinite gradient at index %" PRId64 " (%s=%f) ", ggml_op_desc(t), i, bn, ga[i]);
12171187
ok = false;
12181188
break;
12191189
}
@@ -1281,7 +1251,7 @@ struct test_case {
12811251

12821252
const double err = mean_abs_asymm(gn.data(), ga.data(), gn.size(), expect);
12831253
if (err > max_maa_err()) {
1284-
output_printer->print_info("[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
1254+
output_printer->print_message(message_type::INFO, "[%s] MAA = %.9f > %.9f ", ggml_op_desc(t), err, max_maa_err());
12851255
ok = false;
12861256
break;
12871257
}
@@ -1291,15 +1261,15 @@ struct test_case {
12911261
}
12921262

12931263
if (!ok) {
1294-
output_printer->print_info("compare failed ");
1264+
output_printer->print_message(message_type::INFO, "compare failed ");
12951265
}
12961266

12971267
if (ok) {
1298-
output_printer->print_status_ok();
1268+
output_printer->print_message(message_type::STATUS_OK, "");
12991269
return true;
13001270
}
13011271

1302-
output_printer->print_status_fail();
1272+
output_printer->print_message(message_type::STATUS_FAIL, "");
13031273
return false;
13041274
}
13051275
};
@@ -5330,7 +5300,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
53305300
filter_test_cases(test_cases, params_filter);
53315301
ggml_backend_t backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, NULL);
53325302
if (backend_cpu == NULL) {
5333-
output_printer->print_error(" Failed to initialize CPU backend\n");
5303+
output_printer->print_message(message_type::ERROR, " Failed to initialize CPU backend\n");
53345304
return false;
53355305
}
53365306

@@ -5340,7 +5310,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
53405310
n_ok++;
53415311
}
53425312
}
5343-
output_printer->print_test_summary(" %zu/%zu tests passed\n", n_ok, test_cases.size());
5313+
output_printer->print_message(message_type::INFO, " %zu/%zu tests passed\n", n_ok, test_cases.size());
53445314

53455315
ggml_backend_free(backend_cpu);
53465316

@@ -5356,7 +5326,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
53565326
n_ok++;
53575327
}
53585328
}
5359-
output_printer->print_test_summary(" %zu/%zu tests passed\n", n_ok, test_cases.size());
5329+
output_printer->print_message(message_type::INFO, " %zu/%zu tests passed\n", n_ok, test_cases.size());
53605330

53615331
return n_ok == test_cases.size();
53625332
}
@@ -5443,23 +5413,23 @@ int main(int argc, char ** argv) {
54435413
output_printer->print_header();
54445414
}
54455415

5446-
output_printer->print_info("Testing %zu devices\n\n", ggml_backend_dev_count());
5416+
output_printer->print_message(message_type::INFO, "Testing %zu devices\n\n", ggml_backend_dev_count());
54475417

54485418
size_t n_ok = 0;
54495419

54505420
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
54515421
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
54525422

5453-
output_printer->print_device_info("Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
5423+
output_printer->print_message(message_type::INFO, "Backend %zu/%zu: %s\n", i + 1, ggml_backend_dev_count(), ggml_backend_dev_name(dev));
54545424

54555425
if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
5456-
output_printer->print_device_info(" Skipping\n");
5426+
output_printer->print_message(message_type::INFO, " Skipping\n");
54575427
n_ok++;
54585428
continue;
54595429
}
54605430

54615431
if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
5462-
output_printer->print_device_info(" Skipping CPU backend\n");
5432+
output_printer->print_message(message_type::INFO, " Skipping CPU backend\n");
54635433
n_ok++;
54645434
continue;
54655435
}
@@ -5474,23 +5444,23 @@ int main(int argc, char ** argv) {
54745444
ggml_backend_set_n_threads_fn(backend, std::thread::hardware_concurrency());
54755445
}
54765446

5477-
output_printer->print_device_info(" Device description: %s\n", ggml_backend_dev_description(dev));
5447+
output_printer->print_message(message_type::INFO, " Device description: %s\n", ggml_backend_dev_description(dev));
54785448
size_t free, total; // NOLINT
54795449
ggml_backend_dev_memory(dev, &free, &total);
5480-
output_printer->print_device_info(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
5481-
output_printer->print_device_info("\n");
5450+
output_printer->print_message(message_type::INFO, " Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024);
5451+
output_printer->print_message(message_type::INFO, "\n");
54825452

54835453
bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
54845454

5485-
output_printer->print_device_info(" Backend %s: ", ggml_backend_name(backend));
5455+
output_printer->print_message(message_type::INFO, " Backend %s: ", ggml_backend_name(backend));
54865456
if (ok) {
5487-
output_printer->print_status_ok();
5457+
output_printer->print_message(message_type::STATUS_OK, "");
54885458
n_ok++;
54895459
} else {
5490-
output_printer->print_status_fail();
5460+
output_printer->print_message(message_type::STATUS_FAIL, "");
54915461
}
54925462

5493-
output_printer->print_device_info("\n");
5463+
output_printer->print_message(message_type::INFO, "\n");
54945464

54955465
ggml_backend_free(backend);
54965466
}
@@ -5501,13 +5471,13 @@ int main(int argc, char ** argv) {
55015471
output_printer->print_footer();
55025472
}
55035473

5504-
output_printer->print_test_summary("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
5474+
output_printer->print_message(message_type::INFO, "%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count());
55055475

55065476
if (n_ok != ggml_backend_dev_count()) {
5507-
output_printer->print_status_fail();
5477+
output_printer->print_message(message_type::STATUS_FAIL, "");
55085478
return 1;
55095479
}
55105480

5511-
output_printer->print_status_ok();
5481+
output_printer->print_message(message_type::STATUS_OK, "");
55125482
return 0;
55135483
}

0 commit comments

Comments
 (0)