@@ -36,7 +36,6 @@ template <
36
36
bool PONG,
37
37
bool COOP,
38
38
bool FAST_ACCUM,
39
- bool USE_BIAS,
40
39
typename INPUT_DTYPE,
41
40
typename BIAS_DTYPE>
42
41
at::Tensor f8f8bf16_rowwise_impl (
@@ -158,10 +157,7 @@ at::Tensor f8f8bf16_rowwise_impl(
158
157
159
158
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
160
159
cutlass::multiplies,
161
- cute::conditional_t < // Second stage output type.
162
- USE_BIAS,
163
- ElementBias,
164
- ElementOutput>,
160
+ ElementBias, // Second stage output type.
165
161
ElementComputeEpilogue, // Second stage input types.
166
162
cutlass::FloatRoundStyle::round_to_nearest>;
167
163
@@ -174,11 +170,8 @@ at::Tensor f8f8bf16_rowwise_impl(
174
170
ElementBias, // Final stage input types.
175
171
cutlass::FloatRoundStyle::round_to_nearest>;
176
172
177
- using EVTComputeBias =
178
- cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
179
-
180
173
using EpilogueEVT =
181
- cute:: conditional_t <USE_BIAS, EVTComputeBias , EVTCompute1>;
174
+ cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias , EVTCompute1>;
182
175
183
176
using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
184
177
using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
@@ -273,38 +266,26 @@ at::Tensor f8f8bf16_rowwise_impl(
273
266
(ElementOutput*)Y.data_ptr <at::BFloat16>(),
274
267
stride_output}};
275
268
276
- if constexpr (USE_BIAS) {
277
- arguments.epilogue .thread = {
278
- {reinterpret_cast <ElementBias*>(bias.value ().data_ptr ())}, // bias
279
- // compute_1
280
- {
281
- {reinterpret_cast <ElementComputeEpilogue*>(
282
- w_scale.data_ptr ())}, // x_scale
283
- // compute_0
284
- {
285
- {reinterpret_cast <ElementComputeEpilogue*>(
286
- x_scale.data_ptr ())}, // w_scale
287
- {}, // Accumulator
288
- {} // Multiplies
289
- },
290
- {}, // Multiplies
291
- },
292
- {}, // Plus
293
- };
294
- } else {
295
- arguments.epilogue .thread = {
296
- {reinterpret_cast <ElementComputeEpilogue*>(
297
- w_scale.data_ptr ())}, // x_scale
298
- // compute_0
299
- {
300
- {reinterpret_cast <ElementComputeEpilogue*>(
301
- x_scale.data_ptr ())}, // w_scale
302
- {}, // Accumulator
303
- {} // Multiplies
304
- },
305
- {}, // Multiplies
306
- };
307
- }
269
+ arguments.epilogue .thread = {
270
+ {bias.has_value ()
271
+ ? reinterpret_cast <ElementBias*>(bias.value ().data_ptr ())
272
+ : nullptr }, // bias. Note Cutlass EVT will skip node if argument is
273
+ // nullptr
274
+ // compute_1
275
+ {
276
+ {reinterpret_cast <ElementComputeEpilogue*>(
277
+ w_scale.data_ptr ())}, // x_scale
278
+ // compute_0
279
+ {
280
+ {reinterpret_cast <ElementComputeEpilogue*>(
281
+ x_scale.data_ptr ())}, // w_scale
282
+ {}, // Accumulator
283
+ {} // Multiplies
284
+ },
285
+ {}, // Multiplies
286
+ },
287
+ {}, // Plus
288
+ };
308
289
309
290
Gemm gemm;
310
291
@@ -367,144 +348,71 @@ at::Tensor f8f8bf16_rowwise_wrapper(
367
348
bias.value ().dtype () == at::kBFloat16 ,
368
349
" Bias type must be bfloat16 or float32 if provided." );
369
350
}
370
- bool use_bias = bias.has_value ();
371
- bool bf16_bias = use_bias && bias.value ().dtype () == at::kBFloat16 ;
351
+ bool bf16_bias = bias.has_value () && bias.value ().dtype () == at::kBFloat16 ;
372
352
373
353
// Templatize based on input dtype.
374
354
bool use_e5m2 = XQ.dtype () == at::kFloat8_e5m2 ;
375
355
376
- if (use_bias) {
377
- if (bf16_bias) {
378
- if (use_fast_accum) {
379
- if (use_e5m2) {
380
- return f8f8bf16_rowwise_impl<
381
- TB_M,
382
- TB_N,
383
- TB_K,
384
- TBS_M,
385
- TBS_N,
386
- TBS_K,
387
- ARCH,
388
- PONG,
389
- COOP,
390
- true ,
391
- true ,
392
- cutlass::float_e5m2_t ,
393
- cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
394
- } else {
395
- return f8f8bf16_rowwise_impl<
396
- TB_M,
397
- TB_N,
398
- TB_K,
399
- TBS_M,
400
- TBS_N,
401
- TBS_K,
402
- ARCH,
403
- PONG,
404
- COOP,
405
- true ,
406
- true ,
407
- cutlass::float_e4m3_t ,
408
- cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
409
- }
356
+ if (bf16_bias) {
357
+ if (use_fast_accum) {
358
+ if (use_e5m2) {
359
+ return f8f8bf16_rowwise_impl<
360
+ TB_M,
361
+ TB_N,
362
+ TB_K,
363
+ TBS_M,
364
+ TBS_N,
365
+ TBS_K,
366
+ ARCH,
367
+ PONG,
368
+ COOP,
369
+ true ,
370
+ cutlass::float_e5m2_t ,
371
+ cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
410
372
} else {
411
- if (use_e5m2) {
412
- return f8f8bf16_rowwise_impl<
413
- TB_M,
414
- TB_N,
415
- TB_K,
416
- TBS_M,
417
- TBS_N,
418
- TBS_K,
419
- ARCH,
420
- PONG,
421
- COOP,
422
- false ,
423
- true ,
424
- cutlass::float_e5m2_t ,
425
- cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
426
- } else {
427
- return f8f8bf16_rowwise_impl<
428
- TB_M,
429
- TB_N,
430
- TB_K,
431
- TBS_M,
432
- TBS_N,
433
- TBS_K,
434
- ARCH,
435
- PONG,
436
- COOP,
437
- false ,
438
- true ,
439
- cutlass::float_e4m3_t ,
440
- cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
441
- }
373
+ return f8f8bf16_rowwise_impl<
374
+ TB_M,
375
+ TB_N,
376
+ TB_K,
377
+ TBS_M,
378
+ TBS_N,
379
+ TBS_K,
380
+ ARCH,
381
+ PONG,
382
+ COOP,
383
+ true ,
384
+ cutlass::float_e4m3_t ,
385
+ cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
442
386
}
443
387
} else {
444
- if (use_fast_accum) {
445
- if (use_e5m2) {
446
- return f8f8bf16_rowwise_impl<
447
- TB_M,
448
- TB_N,
449
- TB_K,
450
- TBS_M,
451
- TBS_N,
452
- TBS_K,
453
- ARCH,
454
- PONG,
455
- COOP,
456
- true ,
457
- true ,
458
- cutlass::float_e5m2_t ,
459
- float >(XQ, WQ, x_scale, w_scale, bias, output);
460
- } else {
461
- return f8f8bf16_rowwise_impl<
462
- TB_M,
463
- TB_N,
464
- TB_K,
465
- TBS_M,
466
- TBS_N,
467
- TBS_K,
468
- ARCH,
469
- PONG,
470
- COOP,
471
- true ,
472
- true ,
473
- cutlass::float_e4m3_t ,
474
- float >(XQ, WQ, x_scale, w_scale, bias, output);
475
- }
388
+ if (use_e5m2) {
389
+ return f8f8bf16_rowwise_impl<
390
+ TB_M,
391
+ TB_N,
392
+ TB_K,
393
+ TBS_M,
394
+ TBS_N,
395
+ TBS_K,
396
+ ARCH,
397
+ PONG,
398
+ COOP,
399
+ false ,
400
+ cutlass::float_e5m2_t ,
401
+ cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
476
402
} else {
477
- if (use_e5m2) {
478
- return f8f8bf16_rowwise_impl<
479
- TB_M,
480
- TB_N,
481
- TB_K,
482
- TBS_M,
483
- TBS_N,
484
- TBS_K,
485
- ARCH,
486
- PONG,
487
- COOP,
488
- false ,
489
- true ,
490
- cutlass::float_e5m2_t ,
491
- float >(XQ, WQ, x_scale, w_scale, bias, output);
492
- } else {
493
- return f8f8bf16_rowwise_impl<
494
- TB_M,
495
- TB_N,
496
- TB_K,
497
- TBS_M,
498
- TBS_N,
499
- TBS_K,
500
- ARCH,
501
- PONG,
502
- COOP,
503
- false ,
504
- true ,
505
- cutlass::float_e4m3_t ,
506
- float >(XQ, WQ, x_scale, w_scale, bias, output);
507
- }
403
+ return f8f8bf16_rowwise_impl<
404
+ TB_M,
405
+ TB_N,
406
+ TB_K,
407
+ TBS_M,
408
+ TBS_N,
409
+ TBS_K,
410
+ ARCH,
411
+ PONG,
412
+ COOP,
413
+ false ,
414
+ cutlass::float_e4m3_t ,
415
+ cutlass::bfloat16_t >(XQ, WQ, x_scale, w_scale, bias, output);
508
416
}
509
417
}
510
418
} else {
@@ -521,7 +429,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
521
429
PONG,
522
430
COOP,
523
431
true ,
524
- false ,
525
432
cutlass::float_e5m2_t ,
526
433
float >(XQ, WQ, x_scale, w_scale, bias, output);
527
434
} else {
@@ -536,7 +443,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
536
443
PONG,
537
444
COOP,
538
445
true ,
539
- false ,
540
446
cutlass::float_e4m3_t ,
541
447
float >(XQ, WQ, x_scale, w_scale, bias, output);
542
448
}
@@ -553,7 +459,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
553
459
PONG,
554
460
COOP,
555
461
false ,
556
- false ,
557
462
cutlass::float_e5m2_t ,
558
463
float >(XQ, WQ, x_scale, w_scale, bias, output);
559
464
} else {
@@ -568,7 +473,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
568
473
PONG,
569
474
COOP,
570
475
false ,
571
- false ,
572
476
cutlass::float_e4m3_t ,
573
477
float >(XQ, WQ, x_scale, w_scale, bias, output);
574
478
}
0 commit comments