Description
Report needed documentation
I am trying to write a GEMM for BF16xFP4 using CUTLASS, and I want to know whether CUTLASS supports this operation on the RTX 5090.
Describe the documentation you'd like
A document that records which mixed-precision combinations are supported and unsupported on each platform/architecture.
Steps taken to search for needed documentation
I am trying to write a GEMM for BF16xFP4 using CUTLASS. Here is some of my code:
  template <typename Config, typename OutType>
  struct Bf16_Fp4_GemmSm120 {
    // A operand: BF16, row-major
    // using ElementA = cutlass::bfloat16_t;
    using ElementA = __nv_bfloat16;
    using LayoutATag = cutlass::layout::RowMajor;
    static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;

    // B operand: block-scaled NVFP4 (FP4 e2m1 data plus scale factors), column-major
    using ElementB = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
    using LayoutBTag = cutlass::layout::ColumnMajor;
    static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;

    // C/D operands: output type, row-major
    using ElementD = OutType;
    using ElementC = OutType;
    using LayoutCTag = cutlass::layout::RowMajor;
    using LayoutDTag = cutlass::layout::RowMajor;
    static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
    static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;

    using ElementAccumulator = float;

    // Target: SM120 block-scaled tensor cores
    using ArchTag = cutlass::arch::Sm120;
    using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;

    using MmaTileShape = typename Config::MmaTileShape;
    using ClusterShape = typename Config::ClusterShape;
    using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;

    using CollectiveEpilogue =
        typename cutlass::epilogue::collective::CollectiveBuilder<
            ArchTag, OperatorClass, PerSmTileShape_MNK, ClusterShape,
            cutlass::epilogue::collective::EpilogueTileAuto,
            ElementAccumulator, ElementAccumulator,
            ElementC, LayoutCTag, AlignmentC,
            ElementD, LayoutDTag, AlignmentD,
            cutlass::epilogue::collective::EpilogueScheduleAuto>::CollectiveOp;

    using CollectiveMainloop =
        typename cutlass::gemm::collective::CollectiveBuilder<
            ArchTag, OperatorClass,
            ElementA, LayoutATag, AlignmentA,
            ElementB, LayoutBTag, AlignmentB,
            ElementAccumulator, MmaTileShape, ClusterShape,
            cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
                sizeof(typename CollectiveEpilogue::SharedStorage))>,
            cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;

    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
        Shape<int, int, int, int>,  // problem shape (M, N, K, L)
        CollectiveMainloop, CollectiveEpilogue, void>;

    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
  };
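For reference, here is a minimal sketch of how I would launch the kernel once the collective builds, following what I understand to be the standard CUTLASS 3.x GemmUniversalAdapter pattern. MyConfig and run_gemm are placeholder names I introduced for this sketch, and the construction of Gemm::Arguments (including the scale-factor tensors required by the block-scaled B operand) is omitted, since that is exactly the part I am unsure is supported:

  #include "cutlass/cutlass.h"
  #include "cutlass/util/device_memory.h"

  // Placeholder config/output types; not part of the struct above.
  using GemmRunner = Bf16_Fp4_GemmSm120<MyConfig, cutlass::bfloat16_t>;
  using Gemm = typename GemmRunner::Gemm;

  cutlass::Status run_gemm(typename Gemm::Arguments const &args) {
    Gemm gemm_op;

    // Allocate whatever scratch workspace this kernel configuration needs.
    size_t workspace_size = Gemm::get_workspace_size(args);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Check that this problem/argument combination is implementable.
    cutlass::Status status = gemm_op.can_implement(args);
    if (status != cutlass::Status::kSuccess) {
      return status;
    }

    status = gemm_op.initialize(args, workspace.get());
    if (status != cutlass::Status::kSuccess) {
      return status;
    }

    // Launch on the default CUDA stream.
    return gemm_op.run();
  }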
When I compile this code, I get the following error:
.../cutlass/gemm/collective/collective_builder_decl.hpp(94): error: static assertion failed with "Could not build a collective for given parameters." static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
Here is my compile command:
/home/cheny/miniconda3/envs/venv_chitu_2/bin/nvcc --generate-dependencies-with-compile --dependency-output /home/cheny/cinfer/build/temp.linux-x86_64-cpython-312/home/cheny/cinfer/csrc/cuda/hard_fp4/nvfp4_scaled_mm_kernels.o.d -I/home/cheny/cinfer/csrc/../third_party/spdlog/include -I/home/cheny/cinfer/csrc/cuda/common -I/home/cheny/miniconda3/envs/venv_chitu_2/lib/python3.12/site-packages/torch/include -I/home/cheny/miniconda3/envs/venv_chitu_2/lib/python3.12/site-packages/torch/include/torch/csrc/api/include -I/home/cheny/miniconda3/envs/venv_chitu_2/include -I/home/cheny/miniconda3/envs/venv_chitu_2/include/python3.12 -c -c /home/cheny/cinfer/csrc/cuda/hard_fp4/nvfp4_scaled_mm_kernels.cu -o /home/cheny/cinfer/build/temp.linux-x86_64-cpython-312/home/cheny/cinfer/csrc/cuda/hard_fp4/nvfp4_scaled_mm_kernels.o -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ -D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -DENABLE_NVFP4 -gencode=arch=compute_120a,code=sm_120a -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1016"' -DTORCH_EXTENSION_NAME=chitu_backend -D_GLIBCXX_USE_CXX11_ABI=1 -ccbin /home/cheny/miniconda3/envs/venv_chitu_2/bin/x86_64-conda-linux-gnu-cc -std=c++17
I want to know whether this is a programming error on my part, or whether CUTLASS currently does not support this mixed-precision combination.
This is my environment information:
GPU: NVIDIA GeForce RTX 5090
nvcc: release 12.8, V12.8.93
cmake: 4.0.3