/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <ATen/ATen.h>

#include <functional>
#include <optional>

#include "kernels/fp8fp8fp16_rowwise_kernel_manifest.h"

namespace fbgemm_gpu {
namespace {
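
// Signature shared by every kernel instance in the manifest:
// (XQ, WQ, x_scale, w_scale, Y) -> Y.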
using RowwiseKernel = std::function<
    at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;

RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
  // Apply shape heuristics to find a suitable kernel implementation.

  // Fallback for irregular shapes: some instances require K to be a multiple
  // of the K tile.
  // TODO: Need a systematic solution for the various restrictions imposed by
  // different instances.
  if (!((N % 8 == 0) && (K % 16 == 0))) {
    return fp8fp8f16_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
  }

  if (M < 64 && N < 2048 && K < 2048) {
    // Kernel that generally works well on small shapes.
    return fp8fp8f16_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1;
  } else if (M < 64 && K < 2048) {
    // Kernel that works well for small batch size and small K.
    return fp8fp8f16_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1;
  } else if (M < 64 && N < 2048) {
    // Kernel that works well for small batch size and small N.
    return fp8fp8f16_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1;
  } else if (M < 64 && N > 2048 && K > 2048) {
    // Kernel that works well for small M but larger N and K.
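    // For example, a decode-shaped GEMM with M = 1, N = 4096, K = 4096 passes
    // the divisibility check above and lands in this branch.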
    return fp8fp8f16_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
  } else if (M < 64) {
    // Fallback to a generic small-batch kernel if we can't find a good match.
    return fp8fp8f16_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1;
  } else if (
      ((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) ||
       (K <= 2048 && N <= 8192)) &&
      K >= 1024) {
    // Kernel that is optimized for larger batch sizes but otherwise small
    // tensors.
    return fp8fp8f16_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5;
  } else if (K < 1024) {
    // Special case for small K.
    return fp8fp8f16_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1;
  } else if (M < 1024) {
    // Kernel for generic medium batch sizes.
    return fp8fp8f16_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3;
  } else if (M >= 1024 && N >= 1024 && K >= 1024) {
    // Kernel for very large GEMMs.
    return fp8fp8f16_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
  } else {
    // Fallback large kernel.
    return fp8fp8f16_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3;
  }
}

at::Tensor f8f8_rowwise_wrapper(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum,
    std::optional<at::Tensor> output = std::nullopt) {
  // Check that input datatypes are valid.
  TORCH_CHECK(
      (XQ.dtype() == at::kFloat8_e4m3fnuz) &&
          (WQ.dtype() == at::kFloat8_e4m3fnuz),
      "Inputs must be type float8_e4m3fnuz.");
  TORCH_CHECK(
      (x_scale.dtype() == at::kFloat) && (w_scale.dtype() == at::kFloat),
      "Scales must be float32.");
  TORCH_CHECK(use_fast_accum, "AMD does not support disabling use_fast_accum.");
  TORCH_CHECK(!bias.has_value(), "AMD does not support fused bias.");

  // Check that inputs are in the expected format.
  TORCH_CHECK(
      XQ.is_cuda() && XQ.is_contiguous(),
      "XQ must be a contiguous tensor on the GPU.");
  TORCH_CHECK(
      WQ.is_cuda() && WQ.is_contiguous(),
      "WQ must be a contiguous tensor on the GPU.");

  // XQ: M x K
  // WQ: N x K
  // output: M x N
  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
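  // (XQ may have more than two dimensions; M is the product of all but the
  // last dimension.)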
  int N = WQ.size(0);
  int K = WQ.size(1);
  // Compute target output sizes.
  auto out_sizes = XQ.sizes().vec();
  out_sizes.back() = N;
  // Handle the case where an input dimension is zero.
  if (M == 0 || N == 0 || K == 0) {
    // Return zeros rather than an empty tensor so that the K == 0 case still
    // produces the correct (all-zero) output.
    return at::zeros(out_sizes, XQ.options().dtype(at::kHalf));
  }

  // Prepare output tensor if needed.
  at::Tensor Y;
  if (output.has_value()) {
    Y = output.value();
    // Make sure the provided output has the proper shape and dtype.
    int Y_M = size_to_dim_(Y.dim() - 1, Y.sizes());
    TORCH_CHECK(
        Y_M == M && Y.sizes().back() == N,
        "Preallocated output must have shape [M, N].");
    TORCH_CHECK(Y.dtype() == at::kHalf, "Preallocated output must be float16.");
  } else {
    Y = at::empty(out_sizes, XQ.options().dtype(at::kHalf));
  }

  RowwiseKernel rowwise_impl = rowwise_heuristic_dispatch(M, N, K);
  return rowwise_impl(XQ, WQ, x_scale, w_scale, Y);
}
} // namespace

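// Computes a rowwise-scaled FP8 GEMM of XQ (M x K) against WQ (N x K) and
// returns a new float16 tensor of shape M x N. A minimal calling sketch
// (assumes the quantized inputs and per-row scales are produced elsewhere;
// only the shapes, dtypes, and argument order below come from this file):
//
//   at::Tensor XQ = ...;       // M x K, at::kFloat8_e4m3fnuz, contiguous, on GPU
//   at::Tensor WQ = ...;       // N x K, at::kFloat8_e4m3fnuz, contiguous, on GPU
//   at::Tensor x_scale = ...;  // float32 rowwise scales for XQ
//   at::Tensor w_scale = ...;  // float32 rowwise scales for WQ
//   at::Tensor Y = fbgemm_gpu::f8f8f16_rowwise(
//       XQ, WQ, x_scale, w_scale, /*bias=*/std::nullopt, /*use_fast_accum=*/true);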
at::Tensor f8f8f16_rowwise(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    std::optional<at::Tensor> bias,
    bool use_fast_accum) {
  // Invoke f8f8f16 rowwise without preallocated output.
  return f8f8_rowwise_wrapper(XQ, WQ, x_scale, w_scale, bias, use_fast_accum);
}

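// Same as f8f8f16_rowwise, but writes into a caller-provided `output` tensor.
// The wrapper checks that `output` is float16 and that its leading dimensions
// flatten to M with a last dimension of N.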
void f8f8f16_rowwise_out(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor output,
    std::optional<at::Tensor> bias,
    bool use_fast_accum) {
  // Invoke f8f8f16 rowwise with preallocated output.
  f8f8_rowwise_wrapper(XQ, WQ, x_scale, w_scale, bias, use_fast_accum, output);
}

} // namespace fbgemm_gpu