
Commit 0dbd1bc

cthi authored and facebook-github-bot committed
Split bf16 and fp16 out for CK fp8_rowwise (#4419)
Summary:
X-link: facebookresearch/FBGEMM#1490
Pull Request resolved: #4419

While integrating AMD fp8 rowwise into torch, I noticed the library size is much larger than in FBGEMM 1.2.0. The cause is that we recently added support for fp16 output (in addition to bf16) for fp8 rowwise in D74770197, which roughly doubles the size. You can also see this in the nightly wheel. This diff makes two changes:

- Split the fp16-output kernels into their own files while sharing the common CK kernel implementation template `fp8_rowwise_common.h` (sketched below). This lets us control which kernels to build, e.g. only the bf16-output kernels for torch.
  - For bf16 output we use the `fp8fp8bf16` filename prefix.
  - For fp16 output we use the `fp8fp8fp16` filename prefix.
- The fp16 output is for a future rec-sys use case, so we can remove the Llama-specific tuned/optimized kernels for now (also reducing the size of that kernel set). We can add new tuning for the fp16-output rec-sys use case afterwards.

Reviewed By: jiawenliu64

Differential Revision: D77544503

fbshipit-source-id: 75a5d42b202df9b43c63fafddb0c53e3ffd97f19
1 parent f6100fc commit 0dbd1bc
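To illustrate the split, here is a minimal sketch of the shared-template pattern. Only `fp8_rowwise_common.h` and the filename prefixes are from this commit; the template name `fp8_rowwise_impl`, its parameters, and the function below are illustrative assumptions, not copied from the diff.

// Hypothetical bf16-output instance file (fp8fp8bf16 prefix). An fp16
// instance would be identical except for passing ck::half_t and living in
// a file with the fp8fp8fp16 prefix, so each output dtype can be compiled
// in or out independently.
#include "fp8_rowwise_common.h"

at::Tensor fp8fp8bf16_rowwise_256x128x128x128_example(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor Y) {
  // The output element type is fixed per instance file via the template
  // argument; tile sizes come from the rest of the parameter list.
  return fp8_rowwise_impl<ck::bhalf_t, 256, 128, 128, 128>(
      XQ, WQ, x_scale, w_scale, Y);
}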

File tree

102 files changed: +1569 −891 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip

Lines changed: 0 additions & 590 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8fp8bf16_rowwise_gemm.hip

Lines changed: 579 additions & 0 deletions
fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8fp8fp16_rowwise_gemm.hip

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>

#include <functional>
#include <optional>

#include "kernels/fp8fp8fp16_rowwise_kernel_manifest.h"

namespace fbgemm_gpu {
namespace {

using RowwiseKernel = std::function<
    at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;

RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
  // Apply shape heuristics to find a suitable kernel implementation.

  // Fallback for irregular data types: some instances require K to be a
  // multiple of K Tile.
  // TODO: Need a systematic solution for the various restrictions from
  // different instances.
  if (!((N % 8 == 0) && (K % 16 == 0))) {
    return fp8fp8f16_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
  }

  if (M < 64 && N < 2048 && K < 2048) {
    // Kernel that generally works well on small shapes.
    return fp8fp8f16_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1;
  } else if (M < 64 && K < 2048) {
    // Kernel that works well for small batch size and small K.
    return fp8fp8f16_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1;
  } else if (M < 64 && N < 2048) {
    // Kernel that works well for small batch size and small N.
    return fp8fp8f16_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1;
  } else if (M < 64 && N > 2048 && K > 2048) {
    // Kernel that works well for small M but larger N and K.
    return fp8fp8f16_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
  } else if (M < 64) {
    // Fall back to a generic small-batch kernel if we can't find a good match.
    return fp8fp8f16_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1;
  } else if (
      ((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) ||
       (K <= 2048 && N <= 8192)) &&
      K >= 1024) {
    // Kernel that is optimized for larger batch sizes but otherwise small
    // tensors.
    return fp8fp8f16_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5;
  } else if (K < 1024) {
    // Special case for small K.
    return fp8fp8f16_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
  } else if (M < 1024) {
    // Kernel for generic medium batch sizes.
    return fp8fp8f16_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
    // Kernel for very large gemms.
    return fp8fp8f16_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
  } else {
    // Fallback large kernel.
    return fp8fp8f16_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
  }
}
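
// Worked example of the heuristic above (illustrative shapes, not from this
// commit): M = 32, N = 1024, K = 1024 passes the alignment check
// (1024 % 8 == 0 and 1024 % 16 == 0) and then hits the first branch
// (M < 64, N < 2048, K < 2048), selecting the small-shape 64x16x16x128
// interwave kernel.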

at::Tensor f8f8_rowwise_wrapper(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    std::optional<at::Tensor> output = std::nullopt) {
  // Check that input datatypes are valid.
  TORCH_CHECK(
      (XQ.dtype() == at::kFloat8_e4m3fnuz) &&
          (WQ.dtype() == at::kFloat8_e4m3fnuz),
      "Inputs must be type float8_e4m3fnuz.");
  TORCH_CHECK(
      (x_scale.dtype() == at::kFloat) && (w_scale.dtype() == at::kFloat),
      "Scales must be float32.");
  TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum.");
  TORCH_CHECK(!bias.has_value(), "AMD does not support fused bias.");

  // Check inputs are in expected format.
  TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
  TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

  // XQ: M x K
  // WQ: N x K
  // output: M x N
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
  int N = WQ.size(0);
  int K = WQ.size(1);
  // Compute target output sizes.
  auto out_sizes = XQ.sizes().vec();
  out_sizes.back() = N;
  // Handle case where an input dimension is zero.
  if (M == 0 || N == 0 || K == 0) {
    // Return a tensor of zeros to handle the case where K is 0.
    return at::zeros(out_sizes, XQ.options().dtype(at::kHalf));
  }
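
  // Note (editorial illustration, not in the original diff): size_to_dim_
  // folds all leading dimensions of XQ into M, so a 3D activation of shape
  // [B, S, K] dispatches as a (B*S) x K GEMM, while out_sizes keeps the
  // original leading dimensions with the last set to N.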

  // Prepare output tensor if needed.
  at::Tensor Y;
  if (output.has_value()) {
    Y = output.value();
    // Make sure the provided output has the proper shape and dtype.
    int Y_M = size_to_dim_(Y.dim() - 1, Y.sizes());
    TORCH_CHECK(Y_M == M && Y.sizes().vec().back() == N);
    TORCH_CHECK(Y.dtype() == at::kHalf);
  } else {
    Y = at::empty(out_sizes, XQ.options().dtype(at::kHalf));
  }

  RowwiseKernel rowwise_impl = rowwise_heuristic_dispatch(M, N, K);
  return rowwise_impl(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

at::Tensor f8f8f16_rowwise(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum) {
  // Invoke f8f8f16 rowwise without preallocated output.
  return f8f8_rowwise_wrapper(XQ, WQ, x_scale, w_scale, bias, use_fast_accum);
}

void f8f8f16_rowwise_out(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> bias,
    bool use_fast_accum) {
  // Invoke f8f8f16 rowwise with preallocated output.
  f8f8_rowwise_wrapper(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, output);
}

} // namespace fbgemm_gpu
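
For reference, a host-side call into the new fp16-output entry point might look like the following. This is a sketch, not part of the commit: it assumes a ROCm build where at::kFloat8_e4m3fnuz tensors can be created via Tensor::to, and uses all-ones scales purely for illustration.

#include <ATen/ATen.h>

// Sketch: quantized inputs XQ (M x K) and WQ (N x K) with per-row float32
// scales, producing an M x N fp16 output.
at::Tensor example_f8f8f16_call() {
  const int64_t M = 32, N = 1024, K = 1024;
  auto opts = at::TensorOptions().device(at::kCUDA);
  at::Tensor XQ = at::randn({M, K}, opts).to(at::kFloat8_e4m3fnuz);
  at::Tensor WQ = at::randn({N, K}, opts).to(at::kFloat8_e4m3fnuz);
  at::Tensor x_scale = at::ones({M}, opts.dtype(at::kFloat));
  at::Tensor w_scale = at::ones({N}, opts.dtype(at::kFloat));
  // AMD requires fast accumulation and does not support a fused bias.
  return fbgemm_gpu::f8f8f16_rowwise(
      XQ, WQ, x_scale, w_scale, /*bias=*/std::nullopt, /*use_fast_accum=*/true);
}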
