
Commit 76c16e5

jwfromm authored and facebook-github-bot committed
Make more kernels support 3D inputs. (#4476)
Summary:

X-link: facebookresearch/FBGEMM#1533

Pull Request resolved: #4476

Some of our kernels still require X to be 2D, but there are occasionally cases where we'd like it to include a batch dimension. This diff patches up support for such cases across a few kernels.

Reviewed By: jerryzh168, jiawenliu64

Differential Revision: D78171092

fbshipit-source-id: 9d3cdd71da2b0b0b9ff3fa6e67700dc331c82c34
1 parent bb85def commit 76c16e5
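The recurring pattern in this diff is to fold any leading batch dimensions of XQ into the GEMM's M and to build the output shape by copying XQ's sizes and replacing the last dimension with N. Below is a minimal standalone C++ sketch of that arithmetic, assuming size_to_dim_(k, sizes) returns the product of sizes[0..k); product_up_to is a stand-in name used for illustration only.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Stand-in for what size_to_dim_(k, sizes) is assumed to compute:
// the product of sizes[0..k).
int64_t product_up_to(std::size_t k, const std::vector<int64_t>& sizes) {
  return std::accumulate(
      sizes.begin(), sizes.begin() + k, int64_t{1}, std::multiplies<int64_t>());
}

int main() {
  std::vector<int64_t> xq_sizes = {4, 128, 256}; // XQ: {b, M, K}
  std::vector<int64_t> wq_sizes = {512, 256};    // WQ: {N, K}

  // Leading dims of XQ collapse into the GEMM's M; K is XQ's last dim.
  int64_t M = product_up_to(xq_sizes.size() - 1, xq_sizes); // 4 * 128 = 512
  int64_t K = xq_sizes.back();                               // 256
  int64_t N = product_up_to(wq_sizes.size() - 1, wq_sizes);  // 512

  // The output keeps XQ's shape with the last dim replaced by N:
  // {4, 128, 256} -> {4, 128, 512}. A 2D XQ {M, K} yields {M, N} the same way.
  std::vector<int64_t> out_sizes = xq_sizes;
  out_sizes.back() = N;

  (void)M; (void)K; (void)out_sizes; // values are shown in the comments above
  return 0;
}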


7 files changed (+103, -38 lines)


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_blockwise_gemm.hip

Lines changed: 8 additions & 4 deletions
@@ -59,12 +59,16 @@ at::Tensor f8f8bf16_blockwise_impl(
   TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

   // Get input information.
-  int M = XQ.size(0);
-  int N = WQ.size(0);
-  int K = XQ.size(1);
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
+  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
+  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
+  auto out_sizes = XQ.sizes().vec();
+  out_sizes.back() = N;

   // Create output tensor.
-  auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
+  auto Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));
   // If inputs are empty return an empty tensor.
   if (M == 0 || N == 0 || K == 0) {
     return Y;
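Why folding the batch into M is safe here: the impl already requires XQ.is_contiguous(), and a contiguous {b, M, K} tensor has exactly the same memory layout as a {b*M, K} row-major matrix, so the underlying 2D GEMM can run unchanged. A small standalone sketch of that offset equivalence, with illustrative shapes:

#include <cassert>
#include <cstdint>

int main() {
  // Illustrative shapes for a contiguous {b, M, K} tensor.
  const int64_t b = 4, M = 128, K = 256;

  for (int64_t i = 0; i < b; ++i) {
    for (int64_t m = 0; m < M; ++m) {
      for (int64_t k = 0; k < K; ++k) {
        // Row-major strides of the 3D tensor are {M * K, K, 1}.
        int64_t offset_3d = i * (M * K) + m * K + k;
        // Folding (i, m) into a single row index gives a {b * M, K} matrix
        // with strides {K, 1}; the same element lands at the same offset.
        int64_t row = i * M + m;
        int64_t offset_2d = row * K + k;
        assert(offset_3d == offset_2d);
      }
    }
  }
  return 0;
}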

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_tensorwise_gemm.hip

Lines changed: 16 additions & 7 deletions
@@ -62,18 +62,27 @@ template <
 at::Tensor
 f8f8bf16_tensorwise_impl(at::Tensor XQ, at::Tensor WQ, double scale) {
   // Get input information.
-  int M = XQ.size(0);
-  int N = WQ.size(0);
-  int K = XQ.size(1);
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
+  // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
+  // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
+  auto out_sizes = XQ.sizes().vec();
+  out_sizes.back() = N;

   int StrideA = K;
   int StrideB = K;
   int StrideC = N;

+  // Handle case where inputs are empty.
+  if (M == 0 || N == 0 || K == 0) {
+    return at::zeros(out_sizes, XQ.options().dtype(at::kBFloat16));
+  }
+
   TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous());
   TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous());

-  auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
+  auto Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));

   using ADataType = ck::f8_t;
   using BDataType = ck::f8_t;
@@ -185,9 +194,9 @@ f8f8bf16_tensorwise_impl(at::Tensor XQ, at::Tensor WQ, double scale) {
 enum class KernelMode { Small, Large, Default };

 std::tuple<KernelMode, bool> get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
-  auto M = XQ.size(0);
-  auto K = XQ.size(1);
-  auto N = WQ.size(0);
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   // Use small kernel when input matrices are small.
   bool use_small_kernel = (M <= 512 && N <= 512) || (M <= 128) || (N <= 128);
   // For larger workloads, specialize to large gemm.
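One consequence of the change above: get_kernel_mode now sees the flattened M (batch times rows per batch), so a batched input is dispatched by its total row count rather than its per-slice size. A standalone sketch of the use_small_kernel condition from the hunk above under that flattened M (the large-kernel branch is not reproduced here):

#include <cstdint>
#include <cstdio>

// Mirrors the use_small_kernel condition shown in the diff above.
bool use_small_kernel(int64_t M, int64_t N) {
  return (M <= 512 && N <= 512) || (M <= 128) || (N <= 128);
}

int main() {
  const int64_t b = 8, rows = 64, N = 4096;
  // Per-slice rows = 64 would qualify as "small" (rows <= 128), but the
  // kernel sees the flattened M = b * rows = 512, and with N = 4096 the
  // small-kernel condition no longer holds.
  const int64_t M = b * rows;
  std::printf("flattened M = %lld, small kernel = %d\n",
              static_cast<long long>(M), use_small_kernel(M, N));
  return 0;
}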

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16.cu

Lines changed: 2 additions & 2 deletions
@@ -40,8 +40,8 @@ at::Tensor f8f8bf16_impl(
   // WQ: N x K
   // output: M x N
   int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
   // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
   auto out_sizes = XQ.sizes().vec();

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_tensorwise.cu

Lines changed: 2 additions & 2 deletions
@@ -42,8 +42,8 @@ at::Tensor f8f8bf16_tensorwise_impl(
   // WQ: N x K
   // output: M x N
   int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   // 1. If the input tensor is {M, K}, the output tensor is {M, N}.
   // 2. If the input tensor is {b, M, K}, the output tensor is {b, M, N}.
   auto out_sizes = XQ.sizes().vec();

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/include/kernel_mode.h

Lines changed: 3 additions & 3 deletions
@@ -15,9 +15,9 @@ namespace fbgemm_gpu {
 enum class KernelMode { Small, Medium, Large, Default };

 inline KernelMode get_kernel_mode(at::Tensor XQ, at::Tensor WQ) {
-  auto M = XQ.size(0);
-  auto K = XQ.size(1);
-  auto N = WQ.size(0);
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   // Use a large kernel if at least two shapes are large....
   bool use_large_kernel =
       ((M >= 2048 && K >= 2048) || (M >= 2048 && N >= 2048) ||

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp

Lines changed: 71 additions & 19 deletions
@@ -489,9 +489,22 @@ at::Tensor f8f8bf16_blockwise_meta(
     int64_t /* block_m = 128*/,
     int64_t /* block_n = 128*/,
     int64_t /* block_k = 128*/) {
-  const at::SymInt M = XQ.sym_size(0);
-  const at::SymInt N = WQ.sym_size(0);
-  auto Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+  int64_t x_dims = XQ.dim();
+  int64_t w_dims = WQ.dim();
+  TORCH_CHECK(
+      (x_dims == 2 || x_dims == 3) && (w_dims == 2),
+      "The dim of XQ must be 2 or 3, and dim of WQ must be 2");
+  at::Tensor Y;
+  if (x_dims == 2) {
+    const at::SymInt M = XQ.sym_size(0);
+    const at::SymInt N = WQ.sym_size(0);
+    Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+  } else {
+    const at::SymInt B = XQ.sym_size(0);
+    const at::SymInt M = XQ.sym_size(1);
+    const at::SymInt N = WQ.sym_size(0);
+    Y = at::empty_symint({B, M, N}, XQ.options().dtype(at::kBFloat16));
+  }
   return Y;
 }

@@ -575,13 +588,26 @@ at::Tensor fp8fp8bf16_fast_gemv_meta(
 }

 at::Tensor f8f8bf16_tensorwise_meta(
-    at::Tensor X,
-    at::Tensor W,
-    double scale,
-    bool use_fast_accum = true) {
-  const at::SymInt M = X.sym_size(0);
-  const at::SymInt N = W.sym_size(0);
-  auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
+    at::Tensor XQ,
+    at::Tensor WQ,
+    double /* scale */,
+    bool /* use_fast_accum = true */) {
+  int64_t x_dims = XQ.dim();
+  int64_t w_dims = WQ.dim();
+  TORCH_CHECK(
+      (x_dims == 2 || x_dims == 3) && (w_dims == 2),
+      "The dim of XQ must be 2 or 3, and dim of WQ must be 2");
+  at::Tensor Y;
+  if (x_dims == 2) {
+    const at::SymInt M = XQ.sym_size(0);
+    const at::SymInt N = WQ.sym_size(0);
+    Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+  } else {
+    const at::SymInt B = XQ.sym_size(0);
+    const at::SymInt M = XQ.sym_size(1);
+    const at::SymInt N = WQ.sym_size(0);
+    Y = at::empty_symint({B, M, N}, XQ.options().dtype(at::kBFloat16));
+  }
   return Y;
 }

@@ -595,12 +621,25 @@ at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
 at::Tensor f8i4bf16_rowwise_meta(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // INT4
-    at::Tensor x_scale,
-    at::Tensor w_scale,
-    at::Tensor w_zp) {
-  const at::SymInt M = XQ.sym_size(0);
-  const at::SymInt N = WQ.sym_size(0);
-  auto Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+    at::Tensor /* x_scale */,
+    at::Tensor /* w_scale */,
+    at::Tensor /* w_zp */) {
+  int64_t x_dims = XQ.dim();
+  int64_t w_dims = WQ.dim();
+  TORCH_CHECK(
+      (x_dims == 2 || x_dims == 3) && (w_dims == 2),
+      "The dim of X must be 2 or 3, and dim of W must be 2");
+  at::Tensor Y;
+  if (x_dims == 2) {
+    const at::SymInt M = XQ.sym_size(0);
+    const at::SymInt N = WQ.sym_size(0);
+    Y = at::empty_symint({M, N}, XQ.options().dtype(at::kBFloat16));
+  } else {
+    const at::SymInt B = XQ.sym_size(0);
+    const at::SymInt M = XQ.sym_size(1);
+    const at::SymInt N = WQ.sym_size(0);
+    Y = at::empty_symint({B, M, N}, XQ.options().dtype(at::kBFloat16));
+  }
   return Y;
 }

@@ -632,9 +671,22 @@ at::Tensor bf16i4bf16_rowwise_meta(
     at::Tensor /* w_scale_group */,
     at::Tensor /* w_zero_group */
 ) {
-  const at::SymInt M = X.sym_size(0);
-  const at::SymInt N = W.sym_size(0);
-  auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
+  int64_t x_dims = X.dim();
+  int64_t w_dims = W.dim();
+  TORCH_CHECK(
+      (x_dims == 2 || x_dims == 3) && (w_dims == 2),
+      "The dim of XQ must be 2 or 3, and dim of WQ must be 2");
+  at::Tensor Y;
+  if (x_dims == 2) {
+    const at::SymInt M = X.sym_size(0);
+    const at::SymInt N = W.sym_size(0);
+    Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
+  } else {
+    const at::SymInt B = X.sym_size(0);
+    const at::SymInt M = X.sym_size(1);
+    const at::SymInt N = W.sym_size(0);
+    Y = at::empty_symint({B, M, N}, X.options().dtype(at::kBFloat16));
+  }
   return Y;
 }
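The meta kernels above mirror the same shape logic so that shape inference (for example under meta-device tracing) yields {B, M, N} for a 3D XQ and {M, N} for a 2D one, and rejects other ranks up front. A compact standalone sketch of that contract, using an illustrative helper name and plain int64_t sizes in place of SymInt:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative helper: given XQ's and WQ's sizes, return the output sizes
// under the "XQ is 2D or 3D, WQ is 2D" contract enforced above.
std::vector<int64_t> infer_out_sizes(
    const std::vector<int64_t>& xq_sizes,
    const std::vector<int64_t>& wq_sizes) {
  if ((xq_sizes.size() != 2 && xq_sizes.size() != 3) || wq_sizes.size() != 2) {
    throw std::invalid_argument(
        "The dim of XQ must be 2 or 3, and dim of WQ must be 2");
  }
  const int64_t N = wq_sizes[0]; // WQ is {N, K}
  if (xq_sizes.size() == 2) {
    return {xq_sizes[0], N}; // {M, N}
  }
  return {xq_sizes[0], xq_sizes[1], N}; // {B, M, N}
}

int main() {
  auto y2d = infer_out_sizes({128, 256}, {512, 256});    // -> {128, 512}
  auto y3d = infer_out_sizes({4, 128, 256}, {512, 256}); // -> {4, 128, 512}
  (void)y2d; (void)y3d;
  return 0;
}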

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 1 addition & 1 deletion
@@ -300,7 +300,7 @@ def test_quantize_fp8_matmul(
         if torch.version.hip:
             UseFastAccum = True
         # Setup input shapes.
-        if InputMultiDim and not torch.version.hip:
+        if InputMultiDim:
             x = (
                 torch.randn(
                     size=(3, B_T, D),
