@@ -2353,9 +2353,12 @@ struct test_bin_bcast : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const std::array<int, 4> nr;
+    int nf; // number of fused ops, nf == 1 -> single op (no fusion)
+
+    bool run_whole_graph() override { return true; }
 
     std::string vars() override {
-        return VARS_TO_STR3(type, ne, nr);
+        return VARS_TO_STR4(type, ne, nr, nf);
     }
 
     size_t op_size(ggml_tensor * t) override {
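
The new nf member controls how many binary ops are chained into the test graph (nf == 1 reproduces the old single-op behavior), and adding it to vars() gives each fusion variant a distinct test name. Overriding run_whole_graph() to return true appears to make the harness evaluate and compare the entire graph at once rather than op by op, which is what gives a backend the opportunity to fuse the chain; that reading is inferred from the flag's name, not verified against the harness internals.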
@@ -2364,24 +2367,35 @@ struct test_bin_bcast : public test_case {
 
     test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 1, 1},
-            std::array<int, 4> nr = {1, 2, 1, 1})
-        : op(op), type(type), ne(ne), nr(nr) {}
+            std::array<int, 4> nr = {1, 2, 1, 1},
+            int nf = 1)
+        : op(op), type(type), ne(ne), nr(nr), nf(nf) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
+        GGML_ASSERT(nf <= 8);
+
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
         ggml_set_name(a, "a");
 
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_set_name(b, "b");
+        ggml_tensor * b[8];
+        for (int i = 0; i < nf; ++i) {
+            b[i] = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str());
+        }
 
         // The backward pass supports broadcasting only for GGML_ADD:
-        const bool grad_supported = op == ggml_add || ggml_are_same_shape(a, b);
+        const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1;
         if (grad_supported) {
             ggml_set_param(a);
-            ggml_set_param(b);
+            ggml_set_param(b[0]);
+        }
+
+        ggml_tensor * out = a;
+
+        for (int i = 0; i < nf; ++i) {
+            out = op(ctx, out, b[i]);
         }
 
-        ggml_tensor * out = op(ctx, a, b);
         ggml_set_name(out, "out");
 
         return out;
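
The loop folds the nf sources into one left-leaning chain, so for nf == 3 the graph is add(add(add(a, b0), b1), b2). A minimal standalone sketch of the same construction (illustrative only; shapes simplified to 1-D, and build_fused_add_chain is a hypothetical helper, not part of the patch):

    #include "ggml.h"

    // Build a chain of three adds: out = add(add(add(a, b0), b1), b2),
    // mirroring what build_graph produces for nf == 3 with op == ggml_add.
    static ggml_tensor * build_fused_add_chain(ggml_context * ctx) {
        ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
        ggml_tensor * out = a;
        for (int i = 0; i < 3; ++i) {
            ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 10);
            out = ggml_add(ctx, out, b); // each iteration extends the chain
        }
        return out;
    }

Note also that grad_supported is now restricted to nf == 1 with matching shapes, so the backward pass is only exercised for the plain single-op case.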
@@ -5151,6 +5165,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         // add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1});
     }
 
+    // fusion
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 2, 1, 1}, 3));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1}, 4));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 2}, 5));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 2}, 6));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 2, 2}, 7));
+    test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {16, 5, 4, 3}, {2, 2, 2, 2}, 8));
+
     test_cases.emplace_back(new test_add1());
     test_cases.emplace_back(new test_scale());
     test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
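
The new cases sweep the fusion depth from 2 to 8 while varying which of the four dimensions broadcast. To run just these, test-backend-ops supports filtering by op; something like the following should work (exact flag names are from memory of the tool's usage string and worth double-checking):

    # build, then run only the ADD test cases, fused variants included
    ./build/bin/test-backend-ops test -o ADD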