
Commit e835e29

jiawenliu64 authored and facebook-github-bot committed
Minor improvement of FP4 GEMM for Llama4 shapes (#4459)
Summary:
Pull Request resolved: #4459
X-link: facebookresearch/FBGEMM#1519

This adds minor performance improvements to FP4 GEMM for some Llama4 shapes; the previous general heuristics already provide the best tiling/cluster configuration for most Llama4 shapes.

Reviewed By: cthi

Differential Revision: D77984623

fbshipit-source-id: 3d057dab8968270584d45f14e9c10581999efadf
1 parent b4137fc, commit e835e29

22 files changed: +913 -1 lines

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f4f4bf16.cu

Lines changed: 113 additions & 1 deletion
@@ -27,8 +27,8 @@ at::Tensor dispatch_f4f4bf16_kernel(
     std::optional<at::Tensor> global_scale,
     bool use_mx = true) {
   auto M = XQ.size(0);
-  auto K = XQ.size(1);
   auto N = WQ.size(0);
+  auto K = XQ.size(1) * 2; // Since K is packed
   auto BLOCK_SIZE = 16;
   TORCH_CHECK(
       N % BLOCK_SIZE == 0 && K % BLOCK_SIZE == 0,
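A note on the K change in this hunk: FP4 (e2m1) values are stored two per byte, so the quantized activation XQ carries only K / 2 elements along its second dimension, and the logical K must be recovered as XQ.size(1) * 2 before the divisibility check. A minimal sketch of that packed layout, assuming a plain byte tensor stands in for real FP4 data (the helper name is hypothetical, not part of FBGEMM):

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical helper, for illustration only: allocates a byte tensor shaped
// the way the dispatcher expects packed FP4 activations, i.e. two e2m1 values
// per byte, so a logical [M, K] matrix is stored as [M, K / 2] bytes.
at::Tensor make_packed_fp4_placeholder(int64_t M, int64_t K) {
  TORCH_CHECK(K % 2 == 0, "K must be even to pack two FP4 values per byte");
  auto XQ = at::empty(
      {M, K / 2}, at::TensorOptions().dtype(at::kByte).device(at::kCUDA));
  // The dispatcher recovers the logical K from this packed layout:
  //   auto K = XQ.size(1) * 2;
  return XQ;
}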
@@ -45,6 +45,62 @@ at::Tensor dispatch_f4f4bf16_kernel(
       return f4f4bf16_128_128_4_1_1_t(XQ, WQ, x_scale, w_scale, global_scale);
     }
   } else if (M <= 2048) {
+    if (M <= 256) {
+      if (N == 896) {
+        return f4f4bf16_128_128_2_2_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5120) {
+        if (K == 640 || K == 5120) {
+          return f4f4bf16_128_128_4_1_1_t(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        } else if ((K == 8192) || (K == 16384)) {
+          return f4f4bf16_256_128_2_2_1_t(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        }
+      } else if (N == 5632) {
+        return f4f4bf16_128_192_2_2_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 8192) {
+        return f4f4bf16_256_128_2_2_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 512) {
+      if (N == 896) {
+        return f4f4bf16_128_128_2_2_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5120) {
+        return f4f4bf16_256_192_4_1_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5632) {
+        return f4f4bf16_256_128_2_4_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 8192) {
+        return f4f4bf16_256_128_2_2_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 1024) {
+      if (N == 896) {
+        return f4f4bf16_256_128_2_4_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5120) {
+        if (K == 640) {
+          return f4f4bf16_128_128_1_4_1_t(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        } else if (K == 5120) {
+          return f4f4bf16_128_192_4_2_1_t(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        } else if (K == 5120 || K == 16384) {
+          return f4f4bf16_256_128_2_4_1_t(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        }
+      } else if (N == 5632) {
+        return f4f4bf16_256_128_2_4_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 8192) {
+        return f4f4bf16_256_256_4_1_1_t(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    }
     if (N <= 2048) {
       return f4f4bf16_256_128_2_2_1_t(XQ, WQ, x_scale, w_scale, global_scale);
     } else if (N <= 8192) {
@@ -111,6 +167,62 @@ at::Tensor dispatch_f4f4bf16_kernel(
       return f4f4bf16_128_128_4_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
     }
   } else if (M <= 2048) {
+    if (M <= 256) {
+      if (N == 896) {
+        return f4f4bf16_128_128_2_2_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5120) {
+        if (K == 640 || K == 5120) {
+          return f4f4bf16_128_128_4_1_1_f(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        } else if ((K == 8192) || (K == 16384)) {
+          return f4f4bf16_256_128_2_2_1_f(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        }
+      } else if (N == 5632) {
+        return f4f4bf16_128_192_2_2_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 8192 || N == 16384) {
+        return f4f4bf16_256_128_2_2_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 512) {
+      if (N == 896) {
+        return f4f4bf16_128_128_2_2_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5120) {
+        return f4f4bf16_256_192_4_1_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5632) {
+        return f4f4bf16_256_128_2_4_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 8192) {
+        return f4f4bf16_256_128_2_2_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    } else if (M <= 1024) {
+      if (N == 896) {
+        return f4f4bf16_256_128_2_4_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 5120) {
+        if (K == 640) {
+          return f4f4bf16_128_128_1_4_1_f(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        } else if (K == 5120) {
+          return f4f4bf16_128_192_4_2_1_f(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        } else if (K == 5120 || K == 16384) {
+          return f4f4bf16_256_128_2_4_1_f(
+              XQ, WQ, x_scale, w_scale, global_scale);
+        }
+      } else if (N == 5632) {
+        return f4f4bf16_256_128_2_4_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      } else if (N == 8192) {
+        return f4f4bf16_256_256_4_1_1_f(
+            XQ, WQ, x_scale, w_scale, global_scale);
+      }
+    }
     if (N <= 2048) {
       return f4f4bf16_256_128_2_2_1_f(XQ, WQ, x_scale, w_scale, global_scale);
     } else if (N <= 8192) {
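Read together, the two hunks above add the same Llama4-specific (M, N, K) buckets to both the MX (_t) and NV (_f) FP4 paths; shapes that do not hit a bucket still fall through to the pre-existing N-based heuristic. The numeric suffix of each kernel name corresponds to the five integer template parameters of _f4f4bf16, i.e. the tiling/cluster configuration mentioned in the summary. As a reading aid, here is a simplified sketch of a few representative MX-path branches; it is not code from this commit and it assumes M is large enough to land in the dispatcher's M <= 2048 branch:

#include <cstdint>
#include <string>

// Illustrative restatement of a few of the new MX-path branches: returns the
// kernel-name suffix that would be selected, or "general" when the shape
// falls through to the pre-existing N-based selection.
std::string pick_f4f4bf16_tile(int64_t M, int64_t N, int64_t K) {
  if (M <= 256 && N == 5120 && (K == 640 || K == 5120)) {
    return "128_128_4_1_1";
  }
  if (M <= 256 && N == 5120 && (K == 8192 || K == 16384)) {
    return "256_128_2_2_1";
  }
  if (M > 256 && M <= 512 && N == 5120) {
    return "256_192_4_1_1";
  }
  if (M > 512 && M <= 1024 && N == 8192) {
    return "256_256_4_1_1";
  }
  return "general";  // fall through to the pre-existing heuristic
}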
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_1_1_1_f(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      1,
      1,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_1_1_1_t(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::mx_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      1,
      1,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
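The two files above are a matched pair: f4f4bf16_128_128_1_1_1_f instantiates _f4f4bf16 with cutlass::nv_float4_t and f4f4bf16_128_128_1_1_1_t with cutlass::mx_float4_t, using identical 128/128/1/1/1 template parameters, presumably corresponding to the use_mx flag of dispatch_f4f4bf16_kernel. A minimal sketch of how a caller could route between the pair (the wrapper name is hypothetical; the two kernel entry points are the ones added here):

// Hypothetical wrapper, for illustration only: selects the MX (_t) or NV (_f)
// FP4 instantiation of the 128_128_1_1_1 configuration based on use_mx, the
// same way the dispatcher keeps two parallel heuristic blocks.
at::Tensor run_f4f4bf16_128_128_1_1_1(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale,
    bool use_mx) {
  return use_mx
      ? f4f4bf16_128_128_1_1_1_t(XQ, WQ, x_scale, w_scale, global_scale)
      : f4f4bf16_128_128_1_1_1_f(XQ, WQ, x_scale, w_scale, global_scale);
}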
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_1_2_1_f(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      1,
      2,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_1_2_1_t(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::mx_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      1,
      2,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_1_4_1_f(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      1,
      4,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_1_4_1_t(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::mx_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      1,
      4,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "f4f4bf16_common.cuh"

namespace fbgemm_gpu {

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12080)

at::Tensor f4f4bf16_128_128_2_2_1_f(
    at::Tensor XQ, // FP4
    at::Tensor WQ, // FP4
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> global_scale = std::nullopt) {
  // Dispatch this kernel to the correct underlying implementation.
  return _f4f4bf16<
      cutlass::nv_float4_t<cutlass::float_e2m1_t>,
      128,
      128,
      2,
      2,
      1>(XQ, WQ, x_scale, w_scale, global_scale);
}

#endif

} // namespace fbgemm_gpu
