Skip to content

Commit e8bc633

Browse files
slarenyeahdongcn
authored andcommitted
remove visitor nonsense
1 parent a3252a0 commit e8bc633

File tree

1 file changed

+38
-106
lines changed

1 file changed

+38
-106
lines changed

tests/test-backend-ops.cpp

Lines changed: 38 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -442,34 +442,14 @@ struct test_result {
442442
}
443443
};
444444

445-
// Forward declarations for the visitor pattern
446-
struct message_visitor;
447-
448-
// Base class for all message types that can be printed
449-
struct message_data {
450-
virtual ~message_data() {}
451-
virtual void accept(message_visitor& visitor) const = 0;
452-
};
453-
454-
// Message visitor interface
455-
struct message_visitor {
456-
virtual ~message_visitor() {}
457-
virtual void visit(const struct test_operation_info& info) = 0;
458-
virtual void visit(const struct test_summary_info& info) = 0;
459-
virtual void visit(const struct testing_start_info& info) = 0;
460-
virtual void visit(const struct backend_init_info& info) = 0;
461-
virtual void visit(const struct backend_status_info& info) = 0;
462-
virtual void visit(const struct overall_summary_info& info) = 0;
463-
};
464-
465445
// Printer classes for different output formats
466446
enum class test_status_t {
467447
NOT_SUPPORTED,
468448
OK,
469449
FAIL
470450
};
471451

472-
struct test_operation_info : public message_data {
452+
struct test_operation_info {
473453
std::string op_name;
474454
std::string op_params;
475455
std::string backend_name;
@@ -501,10 +481,6 @@ struct test_operation_info : public message_data {
501481
test_status_t status = test_status_t::OK, const std::string& failure_reason = "")
502482
: op_name(op_name), op_params(op_params), backend_name(backend_name), status(status), failure_reason(failure_reason) {}
503483

504-
void accept(message_visitor& visitor) const override {
505-
visitor.visit(*this);
506-
}
507-
508484
// Set error information
509485
void set_error(const std::string& component, const std::string& details) {
510486
has_error = true;
@@ -550,32 +526,24 @@ struct test_operation_info : public message_data {
550526
}
551527
};
552528

553-
struct test_summary_info : public message_data {
529+
struct test_summary_info {
554530
size_t tests_passed;
555531
size_t tests_total;
556532
bool is_backend_summary = false; // true for backend summary, false for test summary
557533

558534
test_summary_info() = default;
559535
test_summary_info(size_t tests_passed, size_t tests_total, bool is_backend_summary = false)
560536
: tests_passed(tests_passed), tests_total(tests_total), is_backend_summary(is_backend_summary) {}
561-
562-
void accept(message_visitor& visitor) const override {
563-
visitor.visit(*this);
564-
}
565537
};
566538

567-
struct testing_start_info : public message_data {
539+
struct testing_start_info {
568540
size_t device_count;
569541

570542
testing_start_info() = default;
571543
testing_start_info(size_t device_count) : device_count(device_count) {}
572-
573-
void accept(message_visitor& visitor) const override {
574-
visitor.visit(*this);
575-
}
576544
};
577545

578-
struct backend_init_info : public message_data {
546+
struct backend_init_info {
579547
size_t device_index;
580548
size_t total_devices;
581549
std::string device_name;
@@ -593,75 +561,39 @@ struct backend_init_info : public message_data {
593561
: device_index(device_index), total_devices(total_devices), device_name(device_name), skipped(skipped),
594562
skip_reason(skip_reason), description(description), memory_total_mb(memory_total_mb),
595563
memory_free_mb(memory_free_mb), has_memory_info(has_memory_info) {}
596-
597-
void accept(message_visitor& visitor) const override {
598-
visitor.visit(*this);
599-
}
600564
};
601565

602-
struct backend_status_info : public message_data {
566+
struct backend_status_info {
603567
std::string backend_name;
604568
test_status_t status;
605569

606570
backend_status_info() = default;
607571
backend_status_info(const std::string& backend_name, test_status_t status)
608572
: backend_name(backend_name), status(status) {}
609-
610-
void accept(message_visitor& visitor) const override {
611-
visitor.visit(*this);
612-
}
613573
};
614574

615-
struct overall_summary_info : public message_data {
575+
struct overall_summary_info {
616576
size_t backends_passed;
617577
size_t backends_total;
618578
bool all_passed;
619579

620580
overall_summary_info() = default;
621581
overall_summary_info(size_t backends_passed, size_t backends_total, bool all_passed)
622582
: backends_passed(backends_passed), backends_total(backends_total), all_passed(all_passed) {}
623-
624-
void accept(message_visitor& visitor) const override {
625-
visitor.visit(*this);
626-
}
627583
};
628584

629-
struct printer : public message_visitor {
585+
struct printer {
630586
virtual ~printer() {}
631587
FILE * fout = stdout;
632588
virtual void print_header() {}
633589
virtual void print_test_result(const test_result & result) = 0;
634590
virtual void print_footer() {}
635-
636-
void print_message(const message_data& data) {
637-
data.accept(*this);
638-
}
639-
640-
// Default implementations for all visit methods (no-op)
641-
// Derived classes can override only the ones they care about
642-
void visit(const test_operation_info& info) override {
643-
(void)info;
644-
}
645-
646-
void visit(const test_summary_info& info) override {
647-
(void)info;
648-
}
649-
650-
void visit(const testing_start_info& info) override {
651-
(void)info;
652-
}
653-
654-
void visit(const backend_init_info& info) override {
655-
(void)info;
656-
}
657-
658-
void visit(const backend_status_info& info) override {
659-
(void)info;
660-
}
661-
662-
void visit(const overall_summary_info& info) override {
663-
(void)info;
664-
}
591+
virtual void print_operation(const test_operation_info & info) { (void) info; }
592+
virtual void print_summary(const test_summary_info & info) { (void) info; }
593+
virtual void print_testing_start(const testing_start_info & info) { (void) info; }
594+
virtual void print_backend_init(const backend_init_info & info) { (void) info; }
595+
virtual void print_backend_status(const backend_status_info & info) { (void) info; }
596+
virtual void print_overall_summary(const overall_summary_info & info) { (void) info; }
665597
};
666598

667599
struct console_printer : public printer {
@@ -674,7 +606,7 @@ struct console_printer : public printer {
674606
}
675607

676608
// Visitor pattern implementations
677-
void visit(const test_operation_info& info) override {
609+
void print_operation(const test_operation_info& info) override {
678610
printf(" %s(%s): ", info.op_name.c_str(), info.op_params.c_str());
679611
fflush(stdout);
680612

@@ -729,15 +661,15 @@ struct console_printer : public printer {
729661
}
730662
}
731663

732-
void visit(const test_summary_info& info) override {
664+
void print_summary(const test_summary_info& info) override {
733665
if (info.is_backend_summary) {
734666
printf("%zu/%zu backends passed\n", info.tests_passed, info.tests_total);
735667
} else {
736668
printf(" %zu/%zu tests passed\n", info.tests_passed, info.tests_total);
737669
}
738670
}
739671

740-
void visit(const backend_status_info& info) override {
672+
void print_backend_status(const backend_status_info& info) override {
741673
printf(" Backend %s: ", info.backend_name.c_str());
742674
if (info.status == test_status_t::OK) {
743675
printf("\033[1;32mOK\033[0m\n");
@@ -746,11 +678,11 @@ struct console_printer : public printer {
746678
}
747679
}
748680

749-
void visit(const testing_start_info& info) override {
681+
void print_testing_start(const testing_start_info& info) override {
750682
printf("Testing %zu devices\n\n", info.device_count);
751683
}
752684

753-
void visit(const backend_init_info& info) override {
685+
void print_backend_init(const backend_init_info& info) override {
754686
printf("Backend %zu/%zu: %s\n", info.device_index + 1, info.total_devices, info.device_name.c_str());
755687

756688
if (info.skipped) {
@@ -769,7 +701,7 @@ struct console_printer : public printer {
769701
printf("\n");
770702
}
771703

772-
void visit(const overall_summary_info& info) override {
704+
void print_overall_summary(const overall_summary_info& info) override {
773705
printf("%zu/%zu backends passed\n", info.backends_passed, info.backends_total);
774706
if (info.all_passed) {
775707
printf("\033[1;32mOK\033[0m\n");
@@ -1331,12 +1263,12 @@ struct test_case {
13311263
}
13321264

13331265
if (out->type != GGML_TYPE_F32) {
1334-
output_printer->print_message(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, out->name + std::string("->type != FP32")));
1266+
output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, out->name + std::string("->type != FP32")));
13351267
return true;
13361268
}
13371269

13381270
// Print operation info first
1339-
output_printer->print_message(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend)));
1271+
output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend)));
13401272

13411273
// check if the backend supports the ops
13421274
bool supported = true;
@@ -1364,7 +1296,7 @@ struct test_case {
13641296
}
13651297

13661298
if (!supported) {
1367-
output_printer->print_message(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, failure_reason));
1299+
output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, failure_reason));
13681300
return true;
13691301
}
13701302

@@ -1377,7 +1309,7 @@ struct test_case {
13771309
if (ngrads > grad_nmax()) {
13781310
test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
13791311
info.set_large_tensor_skip();
1380-
output_printer->print_message(info);
1312+
output_printer->print_operation(info);
13811313
return true;
13821314
}
13831315

@@ -1400,12 +1332,12 @@ struct test_case {
14001332

14011333
for (ggml_tensor * t = ggml_get_first_tensor(ctx.get()); t != NULL; t = ggml_get_next_tensor(ctx.get(), t)) {
14021334
if (!ggml_backend_supports_op(backend, t)) {
1403-
output_printer->print_message(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, ggml_backend_name(backend)));
1335+
output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, ggml_backend_name(backend)));
14041336
supported = false;
14051337
break;
14061338
}
14071339
if ((t->flags & GGML_TENSOR_FLAG_PARAM) && t->type != GGML_TYPE_F32) {
1408-
output_printer->print_message(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, std::string(t->name) + "->type != FP32"));
1340+
output_printer->print_operation(test_operation_info(op_desc(out), vars(), ggml_backend_name(backend), test_status_t::NOT_SUPPORTED, std::string(t->name) + "->type != FP32"));
14091341
supported = false;
14101342
break;
14111343
}
@@ -1419,7 +1351,7 @@ struct test_case {
14191351
if (buf == NULL) {
14201352
test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
14211353
info.set_error("allocation", "");
1422-
output_printer->print_message(info);
1354+
output_printer->print_operation(info);
14231355
return false;
14241356
}
14251357

@@ -1459,7 +1391,7 @@ struct test_case {
14591391
if (!std::isfinite(ga[i])) {
14601392
test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
14611393
info.set_gradient_info(i, bn, ga[i]);
1462-
output_printer->print_message(info);
1394+
output_printer->print_operation(info);
14631395
ok = false;
14641396
break;
14651397
}
@@ -1529,7 +1461,7 @@ struct test_case {
15291461
if (err > max_maa_err()) {
15301462
test_operation_info info(op_desc(out), vars(), ggml_backend_name(backend));
15311463
info.set_maa_error(err, max_maa_err());
1532-
output_printer->print_message(info);
1464+
output_printer->print_operation(info);
15331465
ok = false;
15341466
break;
15351467
}
@@ -1544,7 +1476,7 @@ struct test_case {
15441476
final_info.set_compare_failure();
15451477
}
15461478
final_info.status = ok ? test_status_t::OK : test_status_t::FAIL;
1547-
output_printer->print_message(final_info);
1479+
output_printer->print_operation(final_info);
15481480

15491481
if (ok) {
15501482
return true;
@@ -5582,7 +5514,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
55825514
if (backend_cpu == NULL) {
55835515
test_operation_info info("", "", "CPU");
55845516
info.set_error("backend", "Failed to initialize CPU backend");
5585-
output_printer->print_message(info);
5517+
output_printer->print_operation(info);
55865518
return false;
55875519
}
55885520

@@ -5592,7 +5524,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
55925524
n_ok++;
55935525
}
55945526
}
5595-
output_printer->print_message(test_summary_info(n_ok, test_cases.size(), false));
5527+
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
55965528

55975529
ggml_backend_free(backend_cpu);
55985530

@@ -5608,7 +5540,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
56085540
n_ok++;
56095541
}
56105542
}
5611-
output_printer->print_message(test_summary_info(n_ok, test_cases.size(), false));
5543+
output_printer->print_summary(test_summary_info(n_ok, test_cases.size(), false));
56125544

56135545
return n_ok == test_cases.size();
56145546
}
@@ -5695,21 +5627,21 @@ int main(int argc, char ** argv) {
56955627
output_printer->print_header();
56965628
}
56975629

5698-
output_printer->print_message(testing_start_info(ggml_backend_dev_count()));
5630+
output_printer->print_testing_start(testing_start_info(ggml_backend_dev_count()));
56995631

57005632
size_t n_ok = 0;
57015633

57025634
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
57035635
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
57045636

57055637
if (backend_filter != NULL && strcmp(backend_filter, ggml_backend_dev_name(dev)) != 0) {
5706-
output_printer->print_message(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping"));
5638+
output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping"));
57075639
n_ok++;
57085640
continue;
57095641
}
57105642

57115643
if (backend_filter == NULL && ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU && mode != MODE_GRAD) {
5712-
output_printer->print_message(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend"));
5644+
output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), true, "Skipping CPU backend"));
57135645
n_ok++;
57145646
continue;
57155647
}
@@ -5726,14 +5658,14 @@ int main(int argc, char ** argv) {
57265658

57275659
size_t free, total; // NOLINT
57285660
ggml_backend_dev_memory(dev, &free, &total);
5729-
output_printer->print_message(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), false, "", ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024, true));
5661+
output_printer->print_backend_init(backend_init_info(i, ggml_backend_dev_count(), ggml_backend_dev_name(dev), false, "", ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024, true));
57305662

57315663
bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
57325664

57335665
if (ok) {
57345666
n_ok++;
57355667
}
5736-
output_printer->print_message(backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL));
5668+
output_printer->print_backend_status(backend_status_info(ggml_backend_name(backend), ok ? test_status_t::OK : test_status_t::FAIL));
57375669

57385670
ggml_backend_free(backend);
57395671
}
@@ -5744,7 +5676,7 @@ int main(int argc, char ** argv) {
57445676
output_printer->print_footer();
57455677
}
57465678

5747-
output_printer->print_message(overall_summary_info(n_ok, ggml_backend_dev_count(), n_ok == ggml_backend_dev_count()));
5679+
output_printer->print_overall_summary(overall_summary_info(n_ok, ggml_backend_dev_count(), n_ok == ggml_backend_dev_count()));
57485680

57495681
if (n_ok != ggml_backend_dev_count()) {
57505682
return 1;

0 commit comments

Comments
 (0)