
Commit 462a8b3

cthi authored and facebook-github-bot committed
Don't compile additional kernels without bias for cutlass FP8 (pytorch#4409)
Summary:
Pull Request resolved: pytorch#4409
X-link: facebookresearch/FBGEMM#1480

Currently we compile multiple variants of FP8 rowwise and FP8 rowwise batched. One of these is due to `USE_BIAS`. While examining the torch core implementation of FP8 rowwise (which has drifted slightly from the FBGEMM version), [I noticed it doesn't do this](https://github.com/pytorch/pytorch/blob/644cc58dfffe1b5bd15688495551b49462c163f6/aten/src/ATen/native/cuda/RowwiseScaledMM.cu#L253). In [this torch PR](pytorch/pytorch#134113) it was found that CUTLASS will simply skip that part of the epilogue if the argument is nullptr. I tested this out and it seems to be the case: correctness and performance both seem to check out. This should reduce the footprint of these two kernels by 33%.

As a side note, this really highlights one big win we could get by merging the implementations: the kernels would benefit from more eyes on them.

Reviewed By: jiawenliu64

Differential Revision: D77447346

fbshipit-source-id: ea3fbc61d06d0199eed9c40bfcdd06aa9735335b
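As a minimal sketch of the pattern this change relies on (hypothetical struct and function names, not the actual FBGEMM/CUTLASS code): the same epilogue-argument slot is filled with either a real bias pointer or nullptr, and CUTLASS EVT skips the bias node in the nullptr case, so a single kernel instantiation can serve both the bias and no-bias call paths.

```cpp
// Hypothetical sketch of the nullptr-bias pattern; the real code fills
// CUTLASS EVT epilogue arguments, shown in the diff below.
#include <cstdio>
#include <optional>

struct EpilogueThreadArgs {
  const float* bias;     // stand-in for the EVT bias argument; the node is skipped when nullptr
  const float* w_scale;  // row-wise scale applied to the weights
  const float* x_scale;  // row-wise scale applied to the activations
};

EpilogueThreadArgs make_epilogue_args(
    const float* w_scale,
    const float* x_scale,
    std::optional<const float*> bias) {
  // One code path for both variants, mirroring the committed
  // bias.has_value() ? ptr : nullptr expression.
  return EpilogueThreadArgs{bias.value_or(nullptr), w_scale, x_scale};
}

int main() {
  float ws[2] = {1.f, 1.f}, xs[2] = {1.f, 1.f}, b[2] = {0.5f, 0.5f};
  EpilogueThreadArgs with_bias = make_epilogue_args(ws, xs, b);
  EpilogueThreadArgs no_bias = make_epilogue_args(ws, xs, std::nullopt);
  std::printf("bias ptr with bias: %p, without: %p\n",
              static_cast<const void*>(with_bias.bias),
              static_cast<const void*>(no_bias.bias));
  return 0;
}
```

Counting the wrapper's instantiations in this diff, each tile configuration previously emitted 12 `f8f8bf16_rowwise_impl` variants (bf16 bias, fp32 bias, and no bias, each crossed with fast-accum on/off and e5m2/e4m3 input) and now emits 8, which is presumably where the ~33% footprint figure comes from.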
1 parent 387e8a0 · commit 462a8b3

File tree

3 files changed: +187 −358 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise/f8f8bf16_rowwise_common.cuh

Lines changed: 79 additions & 175 deletions
@@ -36,7 +36,6 @@ template <
     bool PONG,
     bool COOP,
     bool FAST_ACCUM,
-    bool USE_BIAS,
     typename INPUT_DTYPE,
     typename BIAS_DTYPE>
 at::Tensor f8f8bf16_rowwise_impl(
@@ -158,10 +157,7 @@ at::Tensor f8f8bf16_rowwise_impl(
 
   using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
       cutlass::multiplies,
-      cute::conditional_t< // Second stage output type.
-          USE_BIAS,
-          ElementBias,
-          ElementOutput>,
+      ElementBias, // Second stage output type.
       ElementComputeEpilogue, // Second stage input types.
       cutlass::FloatRoundStyle::round_to_nearest>;

@@ -174,11 +170,8 @@ at::Tensor f8f8bf16_rowwise_impl(
       ElementBias, // Final stage input types.
       cutlass::FloatRoundStyle::round_to_nearest>;
 
-  using EVTComputeBias =
-      cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
-
   using EpilogueEVT =
-      cute::conditional_t<USE_BIAS, EVTComputeBias, EVTCompute1>;
+      cutlass::epilogue::fusion::Sm90EVT<ComputeBias, Bias, EVTCompute1>;
 
   using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
   using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
@@ -273,38 +266,26 @@ at::Tensor f8f8bf16_rowwise_impl(
        (ElementOutput*)Y.data_ptr<at::BFloat16>(),
        stride_output}};
 
-  if constexpr (USE_BIAS) {
-    arguments.epilogue.thread = {
-        {reinterpret_cast<ElementBias*>(bias.value().data_ptr())}, // bias
-        // compute_1
-        {
-            {reinterpret_cast<ElementComputeEpilogue*>(
-                w_scale.data_ptr())}, // x_scale
-            // compute_0
-            {
-                {reinterpret_cast<ElementComputeEpilogue*>(
-                    x_scale.data_ptr())}, // w_scale
-                {}, // Accumulator
-                {} // Multiplies
-            },
-            {}, // Multiplies
-        },
-        {}, // Plus
-    };
-  } else {
-    arguments.epilogue.thread = {
-        {reinterpret_cast<ElementComputeEpilogue*>(
-            w_scale.data_ptr())}, // x_scale
-        // compute_0
-        {
-            {reinterpret_cast<ElementComputeEpilogue*>(
-                x_scale.data_ptr())}, // w_scale
-            {}, // Accumulator
-            {} // Multiplies
-        },
-        {}, // Multiplies
-    };
-  }
+  arguments.epilogue.thread = {
+      {bias.has_value()
+           ? reinterpret_cast<ElementBias*>(bias.value().data_ptr())
+           : nullptr}, // bias. Note Cutlass EVT will skip node if argument is
+                       // nullptr
+      // compute_1
+      {
+          {reinterpret_cast<ElementComputeEpilogue*>(
+              w_scale.data_ptr())}, // x_scale
+          // compute_0
+          {
+              {reinterpret_cast<ElementComputeEpilogue*>(
+                  x_scale.data_ptr())}, // w_scale
+              {}, // Accumulator
+              {} // Multiplies
+          },
+          {}, // Multiplies
+      },
+      {}, // Plus
+  };
 
   Gemm gemm;

@@ -367,144 +348,71 @@ at::Tensor f8f8bf16_rowwise_wrapper(
         bias.value().dtype() == at::kBFloat16,
         "Bias type must be bfloat16 or float32 if provided.");
   }
-  bool use_bias = bias.has_value();
-  bool bf16_bias = use_bias && bias.value().dtype() == at::kBFloat16;
+  bool bf16_bias = bias.has_value() && bias.value().dtype() == at::kBFloat16;
 
   // Templatize based on input dtype.
   bool use_e5m2 = XQ.dtype() == at::kFloat8_e5m2;
 
-  if (use_bias) {
-    if (bf16_bias) {
-      if (use_fast_accum) {
-        if (use_e5m2) {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              true,
-              true,
-              cutlass::float_e5m2_t,
-              cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-        } else {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              true,
-              true,
-              cutlass::float_e4m3_t,
-              cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-        }
+  if (bf16_bias) {
+    if (use_fast_accum) {
+      if (use_e5m2) {
+        return f8f8bf16_rowwise_impl<
+            TB_M,
+            TB_N,
+            TB_K,
+            TBS_M,
+            TBS_N,
+            TBS_K,
+            ARCH,
+            PONG,
+            COOP,
+            true,
+            cutlass::float_e5m2_t,
+            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
       } else {
-        if (use_e5m2) {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              false,
-              true,
-              cutlass::float_e5m2_t,
-              cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-        } else {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              false,
-              true,
-              cutlass::float_e4m3_t,
-              cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
-        }
+        return f8f8bf16_rowwise_impl<
+            TB_M,
+            TB_N,
+            TB_K,
+            TBS_M,
+            TBS_N,
+            TBS_K,
+            ARCH,
+            PONG,
+            COOP,
+            true,
+            cutlass::float_e4m3_t,
+            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
       }
     } else {
-      if (use_fast_accum) {
-        if (use_e5m2) {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              true,
-              true,
-              cutlass::float_e5m2_t,
-              float>(XQ, WQ, x_scale, w_scale, bias, output);
-        } else {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              true,
-              true,
-              cutlass::float_e4m3_t,
-              float>(XQ, WQ, x_scale, w_scale, bias, output);
-        }
+      if (use_e5m2) {
+        return f8f8bf16_rowwise_impl<
+            TB_M,
+            TB_N,
+            TB_K,
+            TBS_M,
+            TBS_N,
+            TBS_K,
+            ARCH,
+            PONG,
+            COOP,
+            false,
+            cutlass::float_e5m2_t,
+            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
       } else {
-        if (use_e5m2) {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              false,
-              true,
-              cutlass::float_e5m2_t,
-              float>(XQ, WQ, x_scale, w_scale, bias, output);
-        } else {
-          return f8f8bf16_rowwise_impl<
-              TB_M,
-              TB_N,
-              TB_K,
-              TBS_M,
-              TBS_N,
-              TBS_K,
-              ARCH,
-              PONG,
-              COOP,
-              false,
-              true,
-              cutlass::float_e4m3_t,
-              float>(XQ, WQ, x_scale, w_scale, bias, output);
-        }
+        return f8f8bf16_rowwise_impl<
+            TB_M,
+            TB_N,
+            TB_K,
+            TBS_M,
+            TBS_N,
+            TBS_K,
+            ARCH,
+            PONG,
+            COOP,
+            false,
+            cutlass::float_e4m3_t,
+            cutlass::bfloat16_t>(XQ, WQ, x_scale, w_scale, bias, output);
       }
     }
   } else {
@@ -521,7 +429,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
             PONG,
             COOP,
             true,
-            false,
             cutlass::float_e5m2_t,
             float>(XQ, WQ, x_scale, w_scale, bias, output);
       } else {
@@ -536,7 +443,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
             PONG,
             COOP,
             true,
-            false,
             cutlass::float_e4m3_t,
             float>(XQ, WQ, x_scale, w_scale, bias, output);
       }
@@ -553,7 +459,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
             PONG,
             COOP,
             false,
-            false,
             cutlass::float_e5m2_t,
             float>(XQ, WQ, x_scale, w_scale, bias, output);
       } else {
@@ -568,7 +473,6 @@ at::Tensor f8f8bf16_rowwise_wrapper(
             PONG,
             COOP,
             false,
-            false,
             cutlass::float_e4m3_t,
             float>(XQ, WQ, x_scale, w_scale, bias, output);
       }
